fix(tensor/module): fixed and moved BatchAccuracyForLogits to nn/sequential; chore(example): clean-up
This commit is contained in:
parent
44ef7776e5
commit
8b05753eb4
|
@ -1,87 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
// Training various models on the CIFAR-10 dataset.
|
|
||||||
//
|
|
||||||
// The dataset can be downloaded from https:www.cs.toronto.edu/~kriz/cifar.html, files
|
|
||||||
// should be placed in the data/ directory.
|
|
||||||
//
|
|
||||||
// The resnet model reaches 95.4% accuracy.
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
// "log"
|
|
||||||
// "os/exec"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
"github.com/sugarme/gotch/vision"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
dir := "../../data/cifar10"
|
|
||||||
ds := vision.CFLoadDir(dir)
|
|
||||||
|
|
||||||
fmt.Printf("TrainImages shape: %v\n", ds.TrainImages.MustSize())
|
|
||||||
fmt.Printf("TrainLabel shape: %v\n", ds.TrainLabels.MustSize())
|
|
||||||
fmt.Printf("TestImages shape: %v\n", ds.TestImages.MustSize())
|
|
||||||
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
|
|
||||||
fmt.Printf("Number of labels: %v\n", ds.Labels)
|
|
||||||
|
|
||||||
// cuda := gotch.CudaBuilder(0)
|
|
||||||
// device := cuda.CudaIfAvailable()
|
|
||||||
device := gotch.CPU
|
|
||||||
|
|
||||||
var si *gotch.SI
|
|
||||||
si = gotch.GetSysInfo()
|
|
||||||
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
|
|
||||||
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
|
|
||||||
|
|
||||||
startRAM := si.TotalRam - si.FreeRam
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(device)
|
|
||||||
|
|
||||||
for epoch := 0; epoch < 150; epoch++ {
|
|
||||||
|
|
||||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
|
||||||
iter.Shuffle()
|
|
||||||
|
|
||||||
for {
|
|
||||||
item, ok := iter.Next()
|
|
||||||
if !ok {
|
|
||||||
item.Data.MustDrop()
|
|
||||||
item.Label.MustDrop()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
devicedData := item.Data.MustTo(vs.Device(), true)
|
|
||||||
devicedLabel := item.Label.MustTo(vs.Device(), true)
|
|
||||||
bimages := vision.Augmentation(devicedData, true, 4, 8)
|
|
||||||
|
|
||||||
devicedData.MustDrop()
|
|
||||||
devicedLabel.MustDrop()
|
|
||||||
bimages.MustDrop()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
iter.Drop()
|
|
||||||
|
|
||||||
si = gotch.GetSysInfo()
|
|
||||||
memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
|
|
||||||
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\n", epoch, memUsed)
|
|
||||||
|
|
||||||
/*
|
|
||||||
* // Print out GPU used
|
|
||||||
* nvidia := "nvidia-smi"
|
|
||||||
* cmd := exec.Command(nvidia)
|
|
||||||
* stdout, err := cmd.Output()
|
|
||||||
*
|
|
||||||
* if err != nil {
|
|
||||||
* log.Fatal(err.Error())
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* fmt.Println(string(stdout))
|
|
||||||
* */
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -10,7 +10,6 @@ package main
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
// "os/exec"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
|
@ -80,17 +79,6 @@ func fastResnet(p nn.Path) (retVal nn.SequentialT) {
|
||||||
return seq
|
return seq
|
||||||
}
|
}
|
||||||
|
|
||||||
func learningRate(epoch int) (retVal float64) {
|
|
||||||
switch {
|
|
||||||
case epoch < 50:
|
|
||||||
return 0.1
|
|
||||||
case epoch < 100:
|
|
||||||
return 0.01
|
|
||||||
default:
|
|
||||||
return 0.001
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
dir := "../../data/cifar10"
|
dir := "../../data/cifar10"
|
||||||
ds := vision.CFLoadDir(dir)
|
ds := vision.CFLoadDir(dir)
|
||||||
|
@ -103,50 +91,42 @@ func main() {
|
||||||
|
|
||||||
cuda := gotch.CudaBuilder(0)
|
cuda := gotch.CudaBuilder(0)
|
||||||
device := cuda.CudaIfAvailable()
|
device := cuda.CudaIfAvailable()
|
||||||
// device := gotch.CPU
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(device)
|
vs := nn.NewVarStore(device)
|
||||||
|
|
||||||
net := fastResnet(vs.Root())
|
net := fastResnet(vs.Root())
|
||||||
|
|
||||||
// optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
|
||||||
// opt, err := optConfig.Build(vs, 0.01)
|
|
||||||
// if err != nil {
|
|
||||||
// log.Fatal(err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
var lossVal float64
|
var lossVal float64
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
var bestAccuracy float64
|
var bestAccuracy float64
|
||||||
|
|
||||||
for epoch := 0; epoch < 350; epoch++ {
|
for epoch := 0; epoch < 150; epoch++ {
|
||||||
// opt.SetLR(learningRate(epoch))
|
|
||||||
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||||
var opt nn.Optimizer
|
var (
|
||||||
var err error
|
opt nn.Optimizer
|
||||||
|
err error
|
||||||
|
)
|
||||||
switch {
|
switch {
|
||||||
case epoch < 150:
|
case epoch < 50:
|
||||||
opt, err = optConfig.Build(vs, 0.1)
|
opt, err = optConfig.Build(vs, 0.1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
case epoch < 250:
|
case epoch < 100:
|
||||||
opt, err = optConfig.Build(vs, 0.01)
|
opt, err = optConfig.Build(vs, 0.01)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
case epoch >= 250:
|
case epoch >= 100:
|
||||||
opt, err = optConfig.Build(vs, 0.001)
|
opt, err = optConfig.Build(vs, 0.001)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(128))
|
|
||||||
iter.Shuffle()
|
iter.Shuffle()
|
||||||
// iter = iter.ToDevice(device)
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
item, ok := iter.Next()
|
item, ok := iter.Next()
|
||||||
|
@ -171,63 +151,14 @@ func main() {
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
vs.Freeze()
|
testAcc := nn.BatchAccuracyForLogits(vs, net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||||
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 100)
|
|
||||||
vs.Unfreeze()
|
|
||||||
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||||
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
|
||||||
if testAcc > bestAccuracy {
|
if testAcc > bestAccuracy {
|
||||||
bestAccuracy = testAcc
|
bestAccuracy = testAcc
|
||||||
}
|
}
|
||||||
iter.Drop()
|
iter.Drop()
|
||||||
|
|
||||||
/*
|
|
||||||
* // Print out GPU used
|
|
||||||
* nvidia := "nvidia-smi"
|
|
||||||
* cmd := exec.Command(nvidia)
|
|
||||||
* stdout, err := cmd.Output()
|
|
||||||
*
|
|
||||||
* if err != nil {
|
|
||||||
* log.Fatal(err.Error())
|
|
||||||
* }
|
|
||||||
*
|
|
||||||
* fmt.Println(string(stdout))
|
|
||||||
* */
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
|
||||||
fmt.Printf("Best Accuracy: %10.2f%%\n", bestAccuracy*100.0)
|
fmt.Printf("Best Accuracy: %10.2f%%\n", bestAccuracy*100.0)
|
||||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func batchAccuracyForLogits(m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
sumAccuracy float64 = 0.0
|
|
||||||
sampleCount float64 = 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
|
||||||
for {
|
|
||||||
item, ok := iter2.Next()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
size := float64(item.Data.MustSize()[0])
|
|
||||||
bImages := item.Data.MustTo(d, true)
|
|
||||||
bLabels := item.Label.MustTo(d, true)
|
|
||||||
|
|
||||||
logits := m.ForwardT(bImages, false)
|
|
||||||
acc := logits.AccuracyForLogits(bLabels)
|
|
||||||
sumAccuracy += acc.Values()[0] * size
|
|
||||||
sampleCount += size
|
|
||||||
|
|
||||||
bImages.MustDrop()
|
|
||||||
bLabels.MustDrop()
|
|
||||||
acc.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
return sumAccuracy / sampleCount
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,16 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
var d gotch.Cuda
|
|
||||||
fmt.Printf("Cuda device count: %v\n", d.DeviceCount())
|
|
||||||
fmt.Printf("Cuda is available: %v\n", d.IsAvailable())
|
|
||||||
fmt.Printf("Cudnn is available: %v\n", d.CudnnIsAvailable())
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,8 +1,8 @@
|
||||||
// A wrapper around the linux syscall sysinfo(2).
|
|
||||||
package main
|
package main
|
||||||
|
|
||||||
|
// helper to debug memory blow-up
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
@ -34,7 +34,7 @@ var sis = &SI{}
|
||||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050
|
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050
|
||||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528
|
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528
|
||||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502
|
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502
|
||||||
func Get() *SI {
|
func CPUInfo() *SI {
|
||||||
|
|
||||||
/*
|
/*
|
||||||
// Note: uint64 is uint32 on 32 bit CPUs
|
// Note: uint64 is uint32 on 32 bit CPUs
|
||||||
|
@ -89,36 +89,3 @@ func Get() *SI {
|
||||||
|
|
||||||
return sis
|
return sis
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make the "fmt" Stringer interface happy.
|
|
||||||
func (si SI) String() string {
|
|
||||||
// XXX: Is the copy of SI done atomic? Not sure.
|
|
||||||
// Without an outer lock this may print a junk.
|
|
||||||
return fmt.Sprintf("uptime\t\t%v\nload\t\t%2.2f %2.2f %2.2f\nprocs\t\t%d\n"+
|
|
||||||
"ram total\t%d kB\nram free\t%d kB\nram buffer\t%d kB\n"+
|
|
||||||
"swap total\t%d kB\nswap free\t%d kB",
|
|
||||||
//"high ram total\t%d kB\nhigh ram free\t%d kB\n"
|
|
||||||
si.Uptime, si.Loads[0], si.Loads[1], si.Loads[2], si.Procs,
|
|
||||||
si.TotalRam, si.FreeRam, si.BufferRam,
|
|
||||||
si.TotalSwap, si.FreeSwap,
|
|
||||||
// archaic si.TotalHighRam, si.FreeHighRam
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Convert to string in a thread safe way.
|
|
||||||
Output:
|
|
||||||
uptime 279h6m21s
|
|
||||||
load 0.12 0.04 0.05
|
|
||||||
procs 143
|
|
||||||
ram total 383752 kB
|
|
||||||
ram free 254980 kB
|
|
||||||
ram buffer 7640 kB
|
|
||||||
swap total 887800 kB
|
|
||||||
swap free 879356 kB
|
|
||||||
*/
|
|
||||||
func (si *SI) ToString() string {
|
|
||||||
defer si.mu.Unlock()
|
|
||||||
si.mu.Lock()
|
|
||||||
return si.String()
|
|
||||||
}
|
|
20
example/debug-memory/gpu.go
Normal file
20
example/debug-memory/gpu.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GPUInfo() {
|
||||||
|
// Print out GPU used
|
||||||
|
nvidia := "nvidia-smi"
|
||||||
|
cmd := exec.Command(nvidia)
|
||||||
|
stdout, err := cmd.Output()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(string(stdout))
|
||||||
|
}
|
90
example/debug-memory/main.go
Normal file
90
example/debug-memory/main.go
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
var device string
|
||||||
|
|
||||||
|
func createTensors(samples int) []ts.Tensor {
|
||||||
|
n := int(10e6)
|
||||||
|
var data []float64
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
data = append(data, float64(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tensors []ts.Tensor
|
||||||
|
s := ts.FloatScalar(float64(0.23))
|
||||||
|
|
||||||
|
for i := 0; i < 1; i++ {
|
||||||
|
t := ts.MustOfSlice(data).MustMul1(s, true)
|
||||||
|
|
||||||
|
tensors = append(tensors, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tensors
|
||||||
|
}
|
||||||
|
|
||||||
|
func dropTensors(tensors []ts.Tensor) {
|
||||||
|
for _, t := range tensors {
|
||||||
|
t.MustDrop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.StringVar(&device, "device", "CPU", "Select CPU or GPU to use")
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// TODO: create flags to load tensor to device(CPU, GPU) and get CPU or GPU
|
||||||
|
// infor accordingly
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
switch device {
|
||||||
|
case "CPU":
|
||||||
|
var si *SI
|
||||||
|
si = CPUInfo()
|
||||||
|
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
|
||||||
|
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
|
||||||
|
startRAM := si.TotalRam - si.FreeRam
|
||||||
|
epochs := 50
|
||||||
|
for i := 0; i < epochs; i++ {
|
||||||
|
tensors := createTensors(10000)
|
||||||
|
dropTensors(tensors)
|
||||||
|
|
||||||
|
si = CPUInfo()
|
||||||
|
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", i, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "GPU":
|
||||||
|
cuda := gotch.CudaBuilder(0)
|
||||||
|
gpu := cuda.CudaIfAvailable()
|
||||||
|
|
||||||
|
epochs := 50
|
||||||
|
for i := 0; i < epochs; i++ {
|
||||||
|
|
||||||
|
tensors := createTensors(10000)
|
||||||
|
var gpuTensors []ts.Tensor
|
||||||
|
for _, t := range tensors {
|
||||||
|
gpuTensors = append(gpuTensors, t.MustTo(gpu, true))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, t := range gpuTensors {
|
||||||
|
t.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Epoch %v\n", i)
|
||||||
|
GPUInfo()
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Fatalf("Invalid device flag (%v). It should be either CPU or GPU.", device)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,53 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
// Try to compare 2 tensor with incompatible dimensions
|
|
||||||
// and check this returns an error
|
|
||||||
dx := []int32{1, 2, 3}
|
|
||||||
dy := []int32{1, 2, 3, 4}
|
|
||||||
// dy := []int32{1, 2, 5}
|
|
||||||
|
|
||||||
xs, err := tensor.OfSlice(dx)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
ys, err := tensor.OfSlice(dy)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
xs.Print()
|
|
||||||
ys.Print()
|
|
||||||
|
|
||||||
fmt.Printf("xs num of dimensions: %v\n", xs.Dim())
|
|
||||||
fmt.Printf("ys num of dimensions: %v\n", ys.Dim())
|
|
||||||
|
|
||||||
xsize, err := xs.Size()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ysize, err := ys.Size()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("xs shape: %v\n", xsize)
|
|
||||||
fmt.Printf("ys shape: %v\n", ysize)
|
|
||||||
|
|
||||||
res, err := xs.Eq1(ys)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Print()
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,48 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
// mockup data
|
|
||||||
var (
|
|
||||||
n int = 20
|
|
||||||
xvals []float32
|
|
||||||
yvals []float32
|
|
||||||
epochs = 10
|
|
||||||
)
|
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
xvals = append(xvals, float32(i))
|
|
||||||
yvals = append(yvals, float32(2*i+1))
|
|
||||||
}
|
|
||||||
|
|
||||||
xtrain, err := ts.NewTensorFromData(xvals, []int64{int64(n), 1})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
ytrain, err := ts.NewTensorFromData(yvals, []int64{int64(n), 1})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ws := ts.MustZeros([]int64{1, int64(n)}, gotch.Float.CInt(), gotch.CPU.CInt())
|
|
||||||
bs := ts.MustZeros([]int64{1, int64(n)}, gotch.Float.CInt(), gotch.CPU.CInt())
|
|
||||||
|
|
||||||
for epoch := 0; epoch < epochs; epoch++ {
|
|
||||||
|
|
||||||
logit := ws.MustMatMul(xtrain).MustAdd(bs)
|
|
||||||
loss := ts.NewTensor().MustLogSoftmax(-1, gotch.Float.CInt())
|
|
||||||
|
|
||||||
ws.MustGrad()
|
|
||||||
bs.MustGrad()
|
|
||||||
loss.MustBackward()
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,17 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(gotch.CPU)
|
|
||||||
|
|
||||||
path := vs.Root()
|
|
||||||
|
|
||||||
l := nn.NewLinear(path, 4, 3, nn.DefaultLinearConfig())
|
|
||||||
|
|
||||||
l.Bs.Print()
|
|
||||||
}
|
|
|
@ -13,21 +13,21 @@
|
||||||
- Run with `go clean -cache -testcache && go run . -model="linear"`
|
- Run with `go clean -cache -testcache && go run . -model="linear"`
|
||||||
|
|
||||||
|
|
||||||
- Accuraccy should be about **91.68%**.
|
- Accuracy should be about **91.68%**.
|
||||||
|
|
||||||
|
|
||||||
## Neural Network (NN)
|
## Neural Network (NN)
|
||||||
|
|
||||||
- Run with `go clean -cache -testcache && go run . -model="nn"`
|
- Run with `go clean -cache -testcache && go run . -model="nn"`
|
||||||
|
|
||||||
- Accuraccy should be about **TODO: update%**.
|
- Accuracy should be about **94%**.
|
||||||
|
|
||||||
|
|
||||||
## Convolutional Neural Network (CNN)
|
## Convolutional Neural Network (CNN)
|
||||||
|
|
||||||
- Run with `go clean -cache -testcache && go run . -model="cnn"`
|
- Run with `go clean -cache -testcache && go run . -model="cnn"`
|
||||||
|
|
||||||
- Accuraccy should be about **TODO: update%**.
|
- Accuracy should be about **99.3%**.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,26 +46,21 @@ func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
defer outView1.MustDrop()
|
defer outView1.MustDrop()
|
||||||
|
|
||||||
outC1 := outView1.Apply(n.conv1)
|
outC1 := outView1.Apply(n.conv1)
|
||||||
// defer outC1.MustDrop()
|
|
||||||
|
|
||||||
outMP1 := outC1.MaxPool2DDefault(2, true)
|
outMP1 := outC1.MaxPool2DDefault(2, true)
|
||||||
defer outMP1.MustDrop()
|
defer outMP1.MustDrop()
|
||||||
|
|
||||||
outC2 := outMP1.Apply(n.conv2)
|
outC2 := outMP1.Apply(n.conv2)
|
||||||
// defer outC2.MustDrop()
|
|
||||||
|
|
||||||
outMP2 := outC2.MaxPool2DDefault(2, true)
|
outMP2 := outC2.MaxPool2DDefault(2, true)
|
||||||
// defer outMP2.MustDrop()
|
|
||||||
|
|
||||||
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
|
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
|
||||||
defer outView2.MustDrop()
|
defer outView2.MustDrop()
|
||||||
|
|
||||||
outFC1 := outView2.Apply(&n.fc1)
|
outFC1 := outView2.Apply(&n.fc1)
|
||||||
// defer outFC1.MustDrop()
|
|
||||||
|
|
||||||
outRelu := outFC1.MustRelu(true)
|
outRelu := outFC1.MustRelu(true)
|
||||||
defer outRelu.MustDrop()
|
defer outRelu.MustDrop()
|
||||||
// outRelu.Dropout_(0.5, train)
|
|
||||||
outDropout := ts.MustDropout(outRelu, 0.5, train)
|
outDropout := ts.MustDropout(outRelu, 0.5, train)
|
||||||
defer outDropout.MustDrop()
|
defer outDropout.MustDrop()
|
||||||
|
|
||||||
|
@ -83,12 +78,14 @@ func runCNN1() {
|
||||||
cuda := gotch.CudaBuilder(0)
|
cuda := gotch.CudaBuilder(0)
|
||||||
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
||||||
// vs := nn.NewVarStore(gotch.CPU)
|
// vs := nn.NewVarStore(gotch.CPU)
|
||||||
|
|
||||||
net := newNet(vs.Root())
|
net := newNet(vs.Root())
|
||||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var bestAccuracy float64 = 0.0
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
for epoch := 0; epoch < epochsCNN; epoch++ {
|
for epoch := 0; epoch < epochsCNN; epoch++ {
|
||||||
|
@ -102,20 +99,16 @@ func runCNN1() {
|
||||||
batches := samples / batchSize
|
batches := samples / batchSize
|
||||||
batchIndex := 0
|
batchIndex := 0
|
||||||
var epocLoss ts.Tensor
|
var epocLoss ts.Tensor
|
||||||
// var loss ts.Tensor
|
|
||||||
for i := 0; i < batches; i++ {
|
for i := 0; i < batches; i++ {
|
||||||
start := batchIndex * batchSize
|
start := batchIndex * batchSize
|
||||||
size := batchSize
|
size := batchSize
|
||||||
if samples-start < batchSize {
|
if samples-start < batchSize {
|
||||||
// size = samples - start
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
batchIndex += 1
|
batchIndex += 1
|
||||||
|
|
||||||
// Indexing
|
// Indexing
|
||||||
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||||
// bImages := ds.TrainImages.Idx(narrowIndex)
|
|
||||||
// bLabels := ds.TrainLabels.Idx(narrowIndex)
|
|
||||||
bImages := imagesTs.Idx(narrowIndex)
|
bImages := imagesTs.Idx(narrowIndex)
|
||||||
bLabels := labelsTs.Idx(narrowIndex)
|
bLabels := labelsTs.Idx(narrowIndex)
|
||||||
|
|
||||||
|
@ -126,7 +119,6 @@ func runCNN1() {
|
||||||
loss := logits.CrossEntropyForLogits(bLabels)
|
loss := logits.CrossEntropyForLogits(bLabels)
|
||||||
|
|
||||||
// loss = loss.MustSetRequiresGrad(true)
|
// loss = loss.MustSetRequiresGrad(true)
|
||||||
|
|
||||||
opt.BackwardStep(loss)
|
opt.BackwardStep(loss)
|
||||||
|
|
||||||
epocLoss = loss.MustShallowClone()
|
epocLoss = loss.MustShallowClone()
|
||||||
|
@ -136,112 +128,21 @@ func runCNN1() {
|
||||||
|
|
||||||
bImages.MustDrop()
|
bImages.MustDrop()
|
||||||
bLabels.MustDrop()
|
bLabels.MustDrop()
|
||||||
// logits.MustDrop()
|
|
||||||
// loss.MustDrop()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vs.Freeze()
|
vs.Freeze()
|
||||||
testAccuracy := batchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
|
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
||||||
vs.Unfreeze()
|
vs.Unfreeze()
|
||||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100.0)
|
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100.0)
|
||||||
|
if testAccuracy > bestAccuracy {
|
||||||
|
bestAccuracy = testAccuracy
|
||||||
|
}
|
||||||
|
|
||||||
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
|
|
||||||
epocLoss.MustDrop()
|
epocLoss.MustDrop()
|
||||||
imagesTs.MustDrop()
|
imagesTs.MustDrop()
|
||||||
labelsTs.MustDrop()
|
labelsTs.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
|
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
|
||||||
fmt.Printf("Test accuracy: %.2f%%\n", testAccuracy*100)
|
|
||||||
|
|
||||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||||
}
|
}
|
||||||
|
|
||||||
func runCNN2() {
|
|
||||||
|
|
||||||
var ds vision.Dataset
|
|
||||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
|
||||||
|
|
||||||
cuda := gotch.CudaBuilder(0)
|
|
||||||
vs := nn.NewVarStore(cuda.CudaIfAvailable())
|
|
||||||
net := newNet(vs.Root())
|
|
||||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
var lossVal float64
|
|
||||||
for epoch := 0; epoch < epochsCNN; epoch++ {
|
|
||||||
|
|
||||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, batchCNN)
|
|
||||||
// iter.Shuffle()
|
|
||||||
|
|
||||||
for {
|
|
||||||
item, ok := iter.Next()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
bImages := item.Data.MustTo(vs.Device(), true)
|
|
||||||
bLabels := item.Label.MustTo(vs.Device(), true)
|
|
||||||
|
|
||||||
// _ = ts.MustGradSetEnabled(true)
|
|
||||||
|
|
||||||
logits := net.ForwardT(bImages, true)
|
|
||||||
loss := logits.CrossEntropyForLogits(bLabels)
|
|
||||||
|
|
||||||
opt.BackwardStep(loss)
|
|
||||||
|
|
||||||
lossVal = loss.Values()[0]
|
|
||||||
|
|
||||||
bImages.MustDrop()
|
|
||||||
bLabels.MustDrop()
|
|
||||||
loss.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
|
|
||||||
|
|
||||||
vs.Freeze()
|
|
||||||
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
|
||||||
vs.Unfreeze()
|
|
||||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100.0)
|
|
||||||
}
|
|
||||||
|
|
||||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
|
||||||
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
|
||||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
|
||||||
}
|
|
||||||
|
|
||||||
func batchAccuracyForLogits(m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
|
||||||
|
|
||||||
var (
|
|
||||||
sumAccuracy float64 = 0.0
|
|
||||||
sampleCount float64 = 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
|
||||||
for {
|
|
||||||
item, ok := iter2.Next()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
size := float64(item.Data.MustSize()[0])
|
|
||||||
bImages := item.Data.MustTo(d, true)
|
|
||||||
bLabels := item.Label.MustTo(d, true)
|
|
||||||
|
|
||||||
logits := m.ForwardT(bImages, false)
|
|
||||||
acc := logits.AccuracyForLogits(bLabels)
|
|
||||||
sumAccuracy += acc.Values()[0] * size
|
|
||||||
sampleCount += size
|
|
||||||
|
|
||||||
bImages.MustDrop()
|
|
||||||
bLabels.MustDrop()
|
|
||||||
acc.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
return sumAccuracy / sampleCount
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,64 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func testOptimizer() {
|
|
||||||
|
|
||||||
var data []float64
|
|
||||||
for i := 0; i < 15; i++ {
|
|
||||||
data = append(data, float64(i))
|
|
||||||
}
|
|
||||||
xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ys := xs.MustMul1(ts.FloatScalar(0.42)).MustAdd1(ts.FloatScalar(1.337))
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(gotch.CPU)
|
|
||||||
|
|
||||||
cfg := nn.LinearConfig{
|
|
||||||
WsInit: nn.NewConstInit(0.001),
|
|
||||||
BsInit: nn.NewConstInit(0.001),
|
|
||||||
Bias: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
// fmt.Printf("Number of trainable variables: %v\n", vs.Len())
|
|
||||||
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
|
||||||
// fmt.Printf("Trainable variables at app: %v\n", vs.TrainableVariable())
|
|
||||||
|
|
||||||
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
|
|
||||||
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
|
||||||
fmt.Printf("Initial Loss: %.3f\n", initialLoss)
|
|
||||||
|
|
||||||
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal("Failed building SGD optimizer")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
// loss = xs.Apply(linear)
|
|
||||||
loss = linear.Forward(xs)
|
|
||||||
loss = loss.MustMseLoss(ys, ts.ReductionMean.ToInt())
|
|
||||||
|
|
||||||
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
|
||||||
|
|
||||||
opt.BackwardStep(loss)
|
|
||||||
|
|
||||||
fmt.Printf("Bs: %.3f - Bs Grad: %.3f\n", linear.Bs.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Bs.MustGrad().MustFloat64Value([]int64{0}))
|
|
||||||
fmt.Printf("Ws: %.3f - Ws Grad: %.3f\n", linear.Ws.MustView([]int64{-1}).MustFloat64Value([]int64{0}), linear.Ws.MustGrad().MustFloat64Value([]int64{0}))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
testOptimizer()
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
s := tensor.FloatScalar(float64(1.23))
|
|
||||||
fmt.Printf("scalar value: %v\n", s)
|
|
||||||
|
|
||||||
intVal, err := s.ToInt()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
floatVal, err := s.ToFloat()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
strVal, err := s.ToString()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("scalar to int64 value: %v\n", intVal)
|
|
||||||
fmt.Printf("scalar to float64 value: %v\n", floatVal)
|
|
||||||
fmt.Printf("scalar to string value: %v\n", strVal)
|
|
||||||
|
|
||||||
s.Drop() // will set scalar to zero
|
|
||||||
fmt.Printf("scalar value: %v\n", s)
|
|
||||||
|
|
||||||
zeroVal, err := s.ToInt()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Panic: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Won't expect this val: %v\n", zeroVal)
|
|
||||||
}
|
|
|
@ -1,205 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
"github.com/sugarme/gotch/vision"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// noSeq()
|
|
||||||
withSeq()
|
|
||||||
// noSeq2Layers()
|
|
||||||
|
|
||||||
// seqNoVarStore()
|
|
||||||
}
|
|
||||||
|
|
||||||
func noSeq() {
|
|
||||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
|
||||||
|
|
||||||
wsInit := nn.NewKaimingUniformInit()
|
|
||||||
ws := wsInit.InitTensor([]int64{10, 784}, gotch.CPU).MustT(true)
|
|
||||||
|
|
||||||
bound := 1.0 / math.Sqrt(float64(784))
|
|
||||||
bsInit := nn.NewUniformInit(-bound, bound)
|
|
||||||
bs := bsInit.InitTensor([]int64{10}, gotch.CPU)
|
|
||||||
|
|
||||||
for i := 0; i < 2000; i++ {
|
|
||||||
mul := ds.TrainImages.MustMatMul(ws, false)
|
|
||||||
logits := mul.MustAdd(bs, true)
|
|
||||||
loss := logits.AccuracyForLogits(ds.TrainLabels)
|
|
||||||
|
|
||||||
fmt.Printf("Epoch %v\t Loss: %.3f\n", i, loss.Values()[0])
|
|
||||||
loss.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func withSeq() {
|
|
||||||
seq := nn.Seq()
|
|
||||||
vs := nn.NewVarStore(gotch.CPU)
|
|
||||||
// seq.Add(nn.NewLinear(vs.Root(), 784, 10, *nn.DefaultLinearConfig()))
|
|
||||||
seq.Add(nn.NewLinear(vs.Root(), 784, 128, *nn.DefaultLinearConfig()))
|
|
||||||
seq.Add(nn.NewLinear(vs.Root(), 128, 10, *nn.DefaultLinearConfig()))
|
|
||||||
|
|
||||||
opt, err := nn.DefaultAdamConfig().Build(vs, 1e-2)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
|
||||||
|
|
||||||
for i := 0; i < 2000; i++ {
|
|
||||||
logits := seq.Forward(ds.TrainImages)
|
|
||||||
loss := logits.CrossEntropyForLogits(ds.TrainLabels)
|
|
||||||
opt.BackwardStep(loss)
|
|
||||||
|
|
||||||
testLogits := seq.Forward(ds.TestImages)
|
|
||||||
testAccuracy := testLogits.AccuracyForLogits(ds.TestLabels)
|
|
||||||
|
|
||||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", i, loss.Values()[0], testAccuracy.Values()[0]*100)
|
|
||||||
|
|
||||||
loss.MustDrop()
|
|
||||||
testAccuracy.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func noSeq2Layers() {
|
|
||||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
|
||||||
|
|
||||||
wsInit := nn.NewKaimingUniformInit()
|
|
||||||
ws1 := wsInit.InitTensor([]int64{1024, 784}, gotch.CPU).MustT(true)
|
|
||||||
ws2 := wsInit.InitTensor([]int64{10, 1024}, gotch.CPU).MustT(true)
|
|
||||||
|
|
||||||
bound1 := 1.0 / math.Sqrt(float64(784))
|
|
||||||
bsInit1 := nn.NewUniformInit(-bound1, bound1)
|
|
||||||
bs1 := bsInit1.InitTensor([]int64{1024}, gotch.CPU)
|
|
||||||
|
|
||||||
bound2 := 1.0 / math.Sqrt(float64(1024))
|
|
||||||
bsInit2 := nn.NewUniformInit(-bound2, bound2)
|
|
||||||
bs2 := bsInit2.InitTensor([]int64{10}, gotch.CPU)
|
|
||||||
|
|
||||||
for i := 0; i < 2000; i++ {
|
|
||||||
mul1 := ds.TrainImages.MustMatMul(ws1, false)
|
|
||||||
out1 := mul1.MustAdd(bs1, true)
|
|
||||||
|
|
||||||
mul2 := out1.MustMatMul(ws2, true)
|
|
||||||
logits := mul2.MustAdd(bs2, true)
|
|
||||||
|
|
||||||
loss := logits.AccuracyForLogits(ds.TrainLabels)
|
|
||||||
|
|
||||||
fmt.Printf("Epoch %v\t Loss: %.3f\n", i, loss.Values()[0])
|
|
||||||
loss.MustDrop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func seqNoVarStore() {
|
|
||||||
|
|
||||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
|
||||||
|
|
||||||
wsInit := nn.NewKaimingUniformInit()
|
|
||||||
ws1 := wsInit.InitTensor([]int64{1024, 784}, gotch.CPU).MustT(true)
|
|
||||||
ws2 := wsInit.InitTensor([]int64{10, 1024}, gotch.CPU).MustT(true)
|
|
||||||
|
|
||||||
bound1 := 1.0 / math.Sqrt(float64(784))
|
|
||||||
bsInit1 := nn.NewUniformInit(-bound1, bound1)
|
|
||||||
bs1 := bsInit1.InitTensor([]int64{1024}, gotch.CPU)
|
|
||||||
|
|
||||||
bound2 := 1.0 / math.Sqrt(float64(1024))
|
|
||||||
bsInit2 := nn.NewUniformInit(-bound2, bound2)
|
|
||||||
bs2 := bsInit2.InitTensor([]int64{10}, gotch.CPU)
|
|
||||||
|
|
||||||
l1 := Linear{&ws1, &bs1}
|
|
||||||
l2 := Linear{&ws2, &bs2}
|
|
||||||
|
|
||||||
seq := Seq()
|
|
||||||
seq.Add(l1)
|
|
||||||
seq.Add(l2)
|
|
||||||
// seq.Add1(l1)
|
|
||||||
// seq.Add2(l2)
|
|
||||||
|
|
||||||
for i := 0; i < 2000; i++ {
|
|
||||||
logits := seq.Forward(ds.TrainImages)
|
|
||||||
|
|
||||||
logits.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
type Linear struct {
|
|
||||||
Ws *ts.Tensor
|
|
||||||
Bs *ts.Tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l Linear) Forward(xs ts.Tensor) ts.Tensor {
|
|
||||||
mul := xs.MustMatMul(*l.Ws, false)
|
|
||||||
return mul.MustAdd(*l.Bs, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Sequential struct {
|
|
||||||
layers []ts.Module
|
|
||||||
l1 ts.Module
|
|
||||||
l2 ts.Module
|
|
||||||
}
|
|
||||||
|
|
||||||
func Seq() Sequential {
|
|
||||||
return Sequential{layers: make([]ts.Module, 0)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns number of sub-layers embedded in this layer
|
|
||||||
func (s *Sequential) Len() (retVal int64) {
|
|
||||||
return int64(len(s.layers))
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEmpty returns true if this layer does not have any sub-layers.
|
|
||||||
func (s *Sequential) IsEmpty() (retVal bool) {
|
|
||||||
return len(s.layers) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add appends a layer after all the current layers.
|
|
||||||
func (s *Sequential) Add(l ts.Module) {
|
|
||||||
|
|
||||||
s.layers = append(s.layers, l)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sequential) Add1(l ts.Module) {
|
|
||||||
s.l1 = l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sequential) Add2(l ts.Module) {
|
|
||||||
s.l2 = l
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
|
||||||
if s.IsEmpty() {
|
|
||||||
return xs.MustShallowClone()
|
|
||||||
}
|
|
||||||
|
|
||||||
// forward sequentially
|
|
||||||
outs := make([]ts.Tensor, len(s.layers))
|
|
||||||
for i := 0; i < len(s.layers); i++ {
|
|
||||||
if i == 0 {
|
|
||||||
outs[0] = s.layers[i].Forward(xs)
|
|
||||||
defer outs[0].MustDrop()
|
|
||||||
} else if i == len(s.layers)-1 {
|
|
||||||
return s.layers[i].Forward(outs[i-1])
|
|
||||||
} else {
|
|
||||||
outs[i+1] = s.layers[i].Forward(outs[i-1])
|
|
||||||
defer outs[i+1].MustDrop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
// out1 := s.l1.Forward(xs)
|
|
||||||
// defer out1.MustDrop()
|
|
||||||
//
|
|
||||||
// return s.l2.Forward(out1)
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func myModule(p nn.Path, dim int64) ts.Module {
|
|
||||||
x1 := p.Zeros("x1", []int64{dim})
|
|
||||||
x2 := p.Zeros("x1", []int64{dim})
|
|
||||||
|
|
||||||
return nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
|
||||||
return xs.MustMul(x1).MustAdd(xs.MustExp().MustMul(x2))
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(gotch.CPU)
|
|
||||||
|
|
||||||
m := myModule(vs.Root(), 7)
|
|
||||||
|
|
||||||
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
xs := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
|
|
||||||
ys := ts.MustZeros([]int64{7}, gotch.Float.CInt(), gotch.CPU.CInt())
|
|
||||||
|
|
||||||
loss := m.Forward(xs).MustSub(ys).MustPow(ts.IntScalar(2)).MustSum(gotch.Float.CInt())
|
|
||||||
|
|
||||||
opt.BackwardStep(loss)
|
|
||||||
|
|
||||||
fmt.Printf("Loss: %v\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,39 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
// TODO: Check Go type of data and tensor DType
|
|
||||||
// For. if data is []int and DType is Bool
|
|
||||||
// It is still running but get wrong result.
|
|
||||||
data := [][]int64{
|
|
||||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
|
||||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
|
||||||
}
|
|
||||||
shape := []int64{2, 8}
|
|
||||||
|
|
||||||
// data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
|
|
||||||
// shape := []int64{1, 8}
|
|
||||||
|
|
||||||
ts, err := tensor.NewTensorFromData(data, shape)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ts.Print()
|
|
||||||
|
|
||||||
numel := uint(6)
|
|
||||||
// dst := make([]uint8, numel)
|
|
||||||
var dst = make([]int64, 6)
|
|
||||||
|
|
||||||
ts.MustCopyData(dst, numel)
|
|
||||||
|
|
||||||
fmt.Println(dst)
|
|
||||||
|
|
||||||
}
|
|
|
@ -15,8 +15,8 @@ func main() {
|
||||||
xy := tensor.TensorFrom([]float64{2.0})
|
xy := tensor.TensorFrom([]float64{2.0})
|
||||||
xz := tensor.TensorFrom([]float64{3.0})
|
xz := tensor.TensorFrom([]float64{3.0})
|
||||||
|
|
||||||
y := x.MustMul(xy)
|
y := x.MustMul(xy, false)
|
||||||
z := x.MustMul(xz)
|
z := x.MustMul(xz, false)
|
||||||
|
|
||||||
y.Backward()
|
y.Backward()
|
||||||
xgrad := x.MustGrad()
|
xgrad := x.MustGrad()
|
||||||
|
@ -31,14 +31,3 @@ func main() {
|
||||||
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
fmt.Printf("Previous GradMode enabled state: %v\n", isGradEnabled)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* // Compute a second order derivative using run_backward.
|
|
||||||
* let mut x = Tensor::from(42.0).set_requires_grad(true);
|
|
||||||
* let y = &x * &x * &x + &x + &x * &x;
|
|
||||||
* x.zero_grad();
|
|
||||||
* let dy_over_dx = Tensor::run_backward(&[y], &[&x], true, true);
|
|
||||||
* assert_eq!(dy_over_dx.len(), 1);
|
|
||||||
* let dy_over_dx = &dy_over_dx[0];
|
|
||||||
* dy_over_dx.backward();
|
|
||||||
* let dy_over_dx2 = x.grad();
|
|
||||||
* assert_eq!(f64::from(&dy_over_dx2), 254.0); */
|
|
||||||
|
|
|
@ -1,41 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
// TODO: Check Go type of data and tensor DType
|
|
||||||
// For. if data is []int and DType is Bool
|
|
||||||
// It is still running but get wrong result.
|
|
||||||
data := [][]int64{
|
|
||||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
|
||||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
|
||||||
}
|
|
||||||
shape := []int64{2, 8}
|
|
||||||
|
|
||||||
ts, err := tensor.NewTensorFromData(data, shape)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ts, err = ts.To(gotch.CPU)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Tensor value BEFORE: %v\n", ts)
|
|
||||||
ts.Print()
|
|
||||||
|
|
||||||
scalarVal := tensor.IntScalar(int64(5))
|
|
||||||
|
|
||||||
ts.Fill_(scalarVal)
|
|
||||||
|
|
||||||
fmt.Printf("Tensor value AFTER: %v\n", ts)
|
|
||||||
ts.Print()
|
|
||||||
}
|
|
|
@ -1,60 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
data := [][]int64{
|
|
||||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
|
||||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
|
||||||
}
|
|
||||||
shape := []int64{2, 8}
|
|
||||||
// shape := []int64{2, 2, 4}
|
|
||||||
|
|
||||||
ts, err := tensor.NewTensorFromData(data, shape)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ts.Print()
|
|
||||||
|
|
||||||
// Select
|
|
||||||
s := tensor.NewSelect(7)
|
|
||||||
// selectedTs := ts.Idx(s)
|
|
||||||
// selectedTs.Print()
|
|
||||||
|
|
||||||
// Narrow (start inclusive, end exclusive)
|
|
||||||
n := tensor.NewNarrow(0, 1)
|
|
||||||
// narrowedTs := ts.Idx(n)
|
|
||||||
// narrowedTs.Print()
|
|
||||||
|
|
||||||
// InsertNewAxis
|
|
||||||
// i := tensor.NewInsertNewAxis()
|
|
||||||
// newAxisTs := ts.Idx(i)
|
|
||||||
// newAxisTs.Print()
|
|
||||||
|
|
||||||
// IndexSelect
|
|
||||||
// idxTensor := tensor.MustOfSlice([]int64{0, 1})
|
|
||||||
// is := tensor.NewIndexSelect(idxTensor)
|
|
||||||
// isTs := ts.Idx(is)
|
|
||||||
// isTs.Print()
|
|
||||||
|
|
||||||
// Combined
|
|
||||||
var tsIndexes []tensor.TensorIndexer = []tensor.TensorIndexer{n, s}
|
|
||||||
combinedTs := ts.Idx(tsIndexes)
|
|
||||||
|
|
||||||
combinedTs.Print()
|
|
||||||
|
|
||||||
// Copy to index
|
|
||||||
desTs := tensor.MustZeros([]int64{5}, gotch.Float.CInt(), gotch.CPU.CInt())
|
|
||||||
srcTs := tensor.MustOnes([]int64{1}, gotch.Float.CInt(), gotch.CPU.CInt())
|
|
||||||
idx := tensor.NewNarrow(0, 3)
|
|
||||||
|
|
||||||
// NOTE: indexing operations return view on the same memory
|
|
||||||
desTs.Print()
|
|
||||||
desTs.Idx(idx).MustView([]int64{-1}, false).Copy_(srcTs)
|
|
||||||
desTs.Print()
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,25 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
|
||||||
|
|
||||||
var idxs []ts.TensorIndexer = []ts.TensorIndexer{
|
|
||||||
// ts.NewNarrow(0, tensor.MustSize()[0]),
|
|
||||||
// ts.NewNarrow(0, tensor.MustSize()[1]),
|
|
||||||
ts.NewInsertNewAxis(),
|
|
||||||
}
|
|
||||||
|
|
||||||
result := tensor.Idx(idxs)
|
|
||||||
|
|
||||||
fmt.Printf("Original Ts shape: %v\n", tensor.MustSize())
|
|
||||||
fmt.Printf("Result Ts shape: %v\n", result.MustSize())
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,32 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
data := [][]int64{
|
|
||||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
|
||||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
|
||||||
}
|
|
||||||
shape := []int64{16}
|
|
||||||
|
|
||||||
ts, err := tensor.NewTensorFromData(data, shape)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
it, err := ts.Iter(reflect.Float64)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < int(it.Len); i++ {
|
|
||||||
v := it.Next()
|
|
||||||
fmt.Println(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
// "runtime"
|
|
||||||
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func createTensors(samples int) []ts.Tensor {
|
|
||||||
n := int(10e6)
|
|
||||||
var data []float64
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
data = append(data, float64(i))
|
|
||||||
}
|
|
||||||
|
|
||||||
var tensors []ts.Tensor
|
|
||||||
s := ts.FloatScalar(float64(0.23))
|
|
||||||
|
|
||||||
// for i := 0; i < samples; i++ {
|
|
||||||
for i := 0; i < 1; i++ {
|
|
||||||
t := ts.MustOfSlice(data).MustMul1(s, true)
|
|
||||||
|
|
||||||
// t1.MustDrop()
|
|
||||||
// t.MustDrop()
|
|
||||||
// t1 = ts.Tensor{}
|
|
||||||
// t = ts.Tensor{}
|
|
||||||
// runtime.GC()
|
|
||||||
|
|
||||||
// fmt.Printf("t values: %v", t.Values())
|
|
||||||
// fmt.Printf("t1 values: %v", t1.Values())
|
|
||||||
tensors = append(tensors, t)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tensors
|
|
||||||
}
|
|
||||||
|
|
||||||
func dropTensors(tensors []ts.Tensor) {
|
|
||||||
for _, t := range tensors {
|
|
||||||
t.MustDrop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
var si *SI
|
|
||||||
si = Get()
|
|
||||||
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
|
|
||||||
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
|
|
||||||
|
|
||||||
startRAM := si.TotalRam - si.FreeRam
|
|
||||||
|
|
||||||
epochs := 50
|
|
||||||
// var m runtime.MemStats
|
|
||||||
|
|
||||||
for i := 0; i < epochs; i++ {
|
|
||||||
// runtime.ReadMemStats(&m)
|
|
||||||
// t0 := float64(m.Sys) / 1024 / 1024
|
|
||||||
|
|
||||||
tensors := createTensors(10000)
|
|
||||||
|
|
||||||
// runtime.ReadMemStats(&m)
|
|
||||||
// t1 := float64(m.Sys) / 1024 / 1024
|
|
||||||
|
|
||||||
dropTensors(tensors)
|
|
||||||
|
|
||||||
// runtime.ReadMemStats(&m)
|
|
||||||
// t2 := float64(m.Sys) / 1024 / 1024
|
|
||||||
|
|
||||||
// fmt.Printf("Epoch: %v \t Start Mem [%.3f MiB] \t Alloc Mem [%.3f MiB] \t Free Mem [%.3f MiB]\n", i, t0, t1, t2)
|
|
||||||
si = Get()
|
|
||||||
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", i, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,54 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
x := tensor.TensorFrom([]float64{2.0})
|
|
||||||
x = x.MustSetRequiresGrad(true)
|
|
||||||
x.ZeroGrad()
|
|
||||||
|
|
||||||
xmul := tensor.TensorFrom([]float64{3.0})
|
|
||||||
xadd := tensor.TensorFrom([]float64{5.0})
|
|
||||||
|
|
||||||
x1 := x.MustMul(xmul)
|
|
||||||
x2 := x1.MustMul(xmul)
|
|
||||||
x3 := x2.MustMul(xmul)
|
|
||||||
|
|
||||||
y := x3.MustAdd(xadd)
|
|
||||||
|
|
||||||
inputs := []tensor.Tensor{x}
|
|
||||||
|
|
||||||
dy_over_dx, err := tensor.RunBackward([]tensor.Tensor{y}, inputs, true, true)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("dy_over_dx length: %v\n", len(dy_over_dx))
|
|
||||||
|
|
||||||
// dy_over_dx1 := dy_over_dx[0]
|
|
||||||
// err = dy_over_dx1.Backward()
|
|
||||||
// if err != nil {
|
|
||||||
// log.Fatalf("Errors:\n, %v", err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
dy_over_dx[0].MustBackward()
|
|
||||||
|
|
||||||
x.MustGrad().Print()
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/* // Compute a second order derivative using run_backward.
|
|
||||||
* let mut x = Tensor::from(42.0).set_requires_grad(true);
|
|
||||||
* let y = &x * &x * &x + &x + &x * &x;
|
|
||||||
* x.zero_grad();
|
|
||||||
* let dy_over_dx = Tensor::run_backward(&[y], &[&x], true, true);
|
|
||||||
* assert_eq!(dy_over_dx.len(), 1);
|
|
||||||
* let dy_over_dx = &dy_over_dx[0];
|
|
||||||
* dy_over_dx.backward();
|
|
||||||
* let dy_over_dx2 = x.grad();
|
|
||||||
* assert_eq!(f64::from(&dy_over_dx2), 254.0); */
|
|
|
@ -1,89 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
// TODO: Check Go type of data and tensor DType
|
|
||||||
// For. if data is []int and DType is Bool
|
|
||||||
// It is still running but get wrong result.
|
|
||||||
data := [][]int64{
|
|
||||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
|
||||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
|
||||||
}
|
|
||||||
shape := []int64{2, 8}
|
|
||||||
// shape := []int64{2, 2, 4}
|
|
||||||
|
|
||||||
// dtype := gotch.Int
|
|
||||||
// ts := tensor.NewTensor()
|
|
||||||
// sliceTensor, err := ts.FOfSlice(data, dtype)
|
|
||||||
// if err != nil {
|
|
||||||
// log.Fatal(err)
|
|
||||||
// }
|
|
||||||
|
|
||||||
ts, err := tensor.NewTensorFromData(data, shape)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ts.Print()
|
|
||||||
|
|
||||||
sz, err := ts.Size2()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
fmt.Printf("Shape: %v\n", sz)
|
|
||||||
|
|
||||||
fmt.Printf("DType: %v\n", ts.DType())
|
|
||||||
|
|
||||||
dx := [][]float64{
|
|
||||||
{1, 1},
|
|
||||||
{1, 1},
|
|
||||||
{1, 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
dy := [][]float64{
|
|
||||||
{1, 2, 3},
|
|
||||||
{1, 1, 1},
|
|
||||||
}
|
|
||||||
|
|
||||||
xs, err := tensor.NewTensorFromData(dx, []int64{3, 2})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
ys, err := tensor.NewTensorFromData(dy, []int64{2, 3})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CPU
|
|
||||||
startCPUTime := time.Now()
|
|
||||||
for i := 1; i < 100000; i++ {
|
|
||||||
xs.Matmul(ys)
|
|
||||||
}
|
|
||||||
fmt.Printf("CPU time: %v\n", time.Since(startCPUTime))
|
|
||||||
|
|
||||||
// Cuda
|
|
||||||
device := gotch.NewCuda()
|
|
||||||
startGPUTime := time.Now()
|
|
||||||
for i := 1; i < 100000; i++ {
|
|
||||||
cx, err := xs.To(device)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
cy, err := ys.To(device)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
cx.Matmul(cy)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("GPU time: %v\n", time.Since(startGPUTime))
|
|
||||||
}
|
|
|
@ -1,69 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
ts, err := tensor.OfSlice([]float64{1.3, 29.7})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := ts.Float64Value([]int64{1})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println(res)
|
|
||||||
|
|
||||||
resInt64, err := ts.Int64Value([]int64{1})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println(resInt64)
|
|
||||||
|
|
||||||
grad, err := ts.RequiresGrad()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Requires Grad: %v\n", grad)
|
|
||||||
|
|
||||||
ele1, err := ts.DataPtr()
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
fmt.Printf("First element address: %v\n", ele1)
|
|
||||||
|
|
||||||
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
|
|
||||||
|
|
||||||
clone := ts.MustShallowClone()
|
|
||||||
clone.Print()
|
|
||||||
|
|
||||||
atGet := ts.MustGet(1)
|
|
||||||
atGet.Print() // 29.7
|
|
||||||
|
|
||||||
atGet = ts.MustGet(0)
|
|
||||||
atGet.Print() // 1.3
|
|
||||||
|
|
||||||
dst, err := tensor.NewTensorFromData([]int64{1, 2}, []int64{1, 2})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst = dst.MustTotype(ts.DType())
|
|
||||||
|
|
||||||
tensor.MustCopy_(dst, ts)
|
|
||||||
dst.Print()
|
|
||||||
|
|
||||||
ts.MustDrop()
|
|
||||||
// The below statement will be panic as `ts` has been dropped.
|
|
||||||
// ts.Print()
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,26 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/nn"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
vs := nn.NewVarStore(gotch.CPU)
|
|
||||||
|
|
||||||
fmt.Printf("Is VarStore emptry? %v\n ", vs.IsEmpty())
|
|
||||||
|
|
||||||
path := vs.Root()
|
|
||||||
|
|
||||||
init := nn.NewKaimingUniformInit()
|
|
||||||
|
|
||||||
init.InitTensor([]int64{1, 4}, gotch.CPU).Print()
|
|
||||||
|
|
||||||
path.NewVar("layer1", []int64{1, 10}, nn.NewKaimingUniformInit())
|
|
||||||
|
|
||||||
fmt.Printf("Is VarStore emptry? %v\n ", vs.IsEmpty())
|
|
||||||
|
|
||||||
}
|
|
102
nn/sequential.go
102
nn/sequential.go
|
@ -3,6 +3,7 @@ package nn
|
||||||
// A sequential layer used to chain multiple layers and closures.
|
// A sequential layer used to chain multiple layers and closures.
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
// "reflect"
|
// "reflect"
|
||||||
)
|
)
|
||||||
|
@ -224,3 +225,104 @@ type ForwardTWith func(ts.Tensor, bool) ts.Tensor
|
||||||
func (fw ForwardTWith) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
func (fw ForwardTWith) ForwardT(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return fw(xs, train)
|
return fw(xs, train)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BatchAccuracyForLogits calculates average accuracy of test batches.
|
||||||
|
//
|
||||||
|
// NOTE: Pytorch uses `NoGradGuard` which is a thread local scope and
|
||||||
|
// it sets a global flag that is checked by the backend whenever an op is done on a variable.
|
||||||
|
// The guard itself saved the current status and set it to false in the constructor.
|
||||||
|
// And restore the saved status in it’s destructor. That way it is similar to a with torch.no_grad(): block in python.
|
||||||
|
// This seems not working in Go.
|
||||||
|
// There 2 ways to get around. One is freeze VarStore, the other is
|
||||||
|
// set manually set AutoGrad at `loss` tensor. I.e., `loss = loss.MustSetRequiresGrad(true)`
|
||||||
|
func BatchAccuracyForLogits(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||||
|
|
||||||
|
var (
|
||||||
|
sumAccuracy float64 = 0.0
|
||||||
|
sampleCount float64 = 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
vs.Freeze()
|
||||||
|
defer vs.Unfreeze()
|
||||||
|
|
||||||
|
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
||||||
|
for {
|
||||||
|
item, ok := iter2.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
size := float64(item.Data.MustSize()[0])
|
||||||
|
bImages := item.Data.MustTo(d, true)
|
||||||
|
bLabels := item.Label.MustTo(d, true)
|
||||||
|
|
||||||
|
logits := m.ForwardT(bImages, false)
|
||||||
|
acc := logits.AccuracyForLogits(bLabels)
|
||||||
|
sumAccuracy += acc.Values()[0] * size
|
||||||
|
sampleCount += size
|
||||||
|
|
||||||
|
bImages.MustDrop()
|
||||||
|
bLabels.MustDrop()
|
||||||
|
acc.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
return sumAccuracy / sampleCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
|
||||||
|
// calculate accuracy for specified batch on module weight. It uses tensor
|
||||||
|
// indexing instead of Iter2
|
||||||
|
func BatchAccuracyForLogitsIdx(vs VarStore, m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||||
|
var (
|
||||||
|
sumAccuracy float64 = 0.0
|
||||||
|
sampleCount float64 = 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
totalSize := xs.MustSize()[0]
|
||||||
|
samples := int(totalSize)
|
||||||
|
|
||||||
|
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||||
|
imagesTs := xs.MustIndexSelect(0, index, false)
|
||||||
|
labelsTs := ys.MustIndexSelect(0, index, false)
|
||||||
|
|
||||||
|
batches := samples / batchSize
|
||||||
|
batchIndex := 0
|
||||||
|
|
||||||
|
vs.Freeze()
|
||||||
|
defer vs.Unfreeze()
|
||||||
|
|
||||||
|
for i := 0; i < batches; i++ {
|
||||||
|
start := batchIndex * batchSize
|
||||||
|
size := batchSize
|
||||||
|
if samples-start < batchSize {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
batchIndex += 1
|
||||||
|
|
||||||
|
// Indexing
|
||||||
|
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||||
|
bImages := imagesTs.Idx(narrowIndex)
|
||||||
|
bLabels := labelsTs.Idx(narrowIndex)
|
||||||
|
|
||||||
|
bImages = bImages.MustTo(d, true)
|
||||||
|
bLabels = bLabels.MustTo(d, true)
|
||||||
|
|
||||||
|
logits := m.ForwardT(bImages, true)
|
||||||
|
bAccuracy := logits.AccuracyForLogits(bLabels)
|
||||||
|
|
||||||
|
accuVal := bAccuracy.Values()[0]
|
||||||
|
bSamples := float64(xs.MustSize()[0])
|
||||||
|
sumAccuracy += accuVal * bSamples
|
||||||
|
sampleCount += bSamples
|
||||||
|
|
||||||
|
// Free up tensors on C memory
|
||||||
|
bImages.MustDrop()
|
||||||
|
bLabels.MustDrop()
|
||||||
|
bAccuracy.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
imagesTs.MustDrop()
|
||||||
|
labelsTs.MustDrop()
|
||||||
|
|
||||||
|
return sumAccuracy / sampleCount
|
||||||
|
}
|
||||||
|
|
126
sysinfo.go
126
sysinfo.go
|
@ -1,126 +0,0 @@
|
||||||
// A wrapper around the linux syscall sysinfo(2).
|
|
||||||
package gotch
|
|
||||||
|
|
||||||
// helper to debug memory blow-up
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Go-ized http://man7.org/linux/man-pages/man2/sysinfo.2.html
|
|
||||||
type SI struct {
|
|
||||||
Uptime time.Duration // time since boot
|
|
||||||
Loads [3]float64 // 1, 5, and 15 minute load averages, see e.g. UPTIME(1)
|
|
||||||
Procs uint64 // number of current processes
|
|
||||||
TotalRam uint64 // total usable main memory size [kB]
|
|
||||||
FreeRam uint64 // available memory size [kB]
|
|
||||||
SharedRam uint64 // amount of shared memory [kB]
|
|
||||||
BufferRam uint64 // memory used by buffers [kB]
|
|
||||||
TotalSwap uint64 // total swap space size [kB]
|
|
||||||
FreeSwap uint64 // swap space still available [kB]
|
|
||||||
TotalHighRam uint64 // total high memory size [kB]
|
|
||||||
FreeHighRam uint64 // available high memory size [kB]
|
|
||||||
mu sync.Mutex // ensures atomic writes; protects the following fields
|
|
||||||
}
|
|
||||||
|
|
||||||
var sis = &SI{}
|
|
||||||
|
|
||||||
// Get the linux sysinfo data structure.
|
|
||||||
//
|
|
||||||
// Useful links in the wild web:
|
|
||||||
// http://man7.org/linux/man-pages/man2/sysinfo.2.html
|
|
||||||
// http://man7.org/linux/man-pages/man1/uptime.1.html
|
|
||||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050
|
|
||||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528
|
|
||||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502
|
|
||||||
func GetSysInfo() *SI {
|
|
||||||
|
|
||||||
/*
|
|
||||||
// Note: uint64 is uint32 on 32 bit CPUs
|
|
||||||
type Sysinfo_t struct {
|
|
||||||
Uptime int64 // Seconds since boot
|
|
||||||
Loads [3]uint64 // 1, 5, and 15 minute load averages
|
|
||||||
Totalram uint64 // Total usable main memory size
|
|
||||||
Freeram uint64 // Available memory size
|
|
||||||
Sharedram uint64 // Amount of shared memory
|
|
||||||
Bufferram uint64 // Memory used by buffers
|
|
||||||
Totalswap uint64 // Total swap space size
|
|
||||||
Freeswap uint64 // swap space still available
|
|
||||||
Procs uint16 // Number of current processes
|
|
||||||
Pad uint16
|
|
||||||
Pad_cgo_0 [4]byte
|
|
||||||
Totalhigh uint64 // Total high memory size
|
|
||||||
Freehigh uint64 // Available high memory size
|
|
||||||
Unit uint32 // Memory unit size in bytes
|
|
||||||
X_f [0]byte
|
|
||||||
Pad_cgo_1 [4]byte // Padding to 64 bytes
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// ~1kB garbage
|
|
||||||
si := &syscall.Sysinfo_t{}
|
|
||||||
|
|
||||||
// XXX is a raw syscall thread safe?
|
|
||||||
err := syscall.Sysinfo(si)
|
|
||||||
if err != nil {
|
|
||||||
panic("Commander, we have a problem. syscall.Sysinfo:" + err.Error())
|
|
||||||
}
|
|
||||||
scale := 65536.0 // magic
|
|
||||||
|
|
||||||
defer sis.mu.Unlock()
|
|
||||||
sis.mu.Lock()
|
|
||||||
|
|
||||||
unit := uint64(si.Unit) * 1024 // kB
|
|
||||||
|
|
||||||
sis.Uptime = time.Duration(si.Uptime) * time.Second
|
|
||||||
sis.Loads[0] = float64(si.Loads[0]) / scale
|
|
||||||
sis.Loads[1] = float64(si.Loads[1]) / scale
|
|
||||||
sis.Loads[2] = float64(si.Loads[2]) / scale
|
|
||||||
sis.Procs = uint64(si.Procs)
|
|
||||||
|
|
||||||
sis.TotalRam = uint64(si.Totalram) / unit
|
|
||||||
sis.FreeRam = uint64(si.Freeram) / unit
|
|
||||||
sis.BufferRam = uint64(si.Bufferram) / unit
|
|
||||||
sis.TotalSwap = uint64(si.Totalswap) / unit
|
|
||||||
sis.FreeSwap = uint64(si.Freeswap) / unit
|
|
||||||
sis.TotalHighRam = uint64(si.Totalhigh) / unit
|
|
||||||
sis.FreeHighRam = uint64(si.Freehigh) / unit
|
|
||||||
|
|
||||||
return sis
|
|
||||||
}
|
|
||||||
|
|
||||||
// Make the "fmt" Stringer interface happy.
|
|
||||||
func (si SI) String() string {
|
|
||||||
// XXX: Is the copy of SI done atomic? Not sure.
|
|
||||||
// Without an outer lock this may print a junk.
|
|
||||||
return fmt.Sprintf("uptime\t\t%v\nload\t\t%2.2f %2.2f %2.2f\nprocs\t\t%d\n"+
|
|
||||||
"ram total\t%d kB\nram free\t%d kB\nram buffer\t%d kB\n"+
|
|
||||||
"swap total\t%d kB\nswap free\t%d kB",
|
|
||||||
//"high ram total\t%d kB\nhigh ram free\t%d kB\n"
|
|
||||||
si.Uptime, si.Loads[0], si.Loads[1], si.Loads[2], si.Procs,
|
|
||||||
si.TotalRam, si.FreeRam, si.BufferRam,
|
|
||||||
si.TotalSwap, si.FreeSwap,
|
|
||||||
// archaic si.TotalHighRam, si.FreeHighRam
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
Convert to string in a thread safe way.
|
|
||||||
Output:
|
|
||||||
uptime 279h6m21s
|
|
||||||
load 0.12 0.04 0.05
|
|
||||||
procs 143
|
|
||||||
ram total 383752 kB
|
|
||||||
ram free 254980 kB
|
|
||||||
ram buffer 7640 kB
|
|
||||||
swap total 887800 kB
|
|
||||||
swap free 879356 kB
|
|
||||||
*/
|
|
||||||
func (si *SI) ToString() string {
|
|
||||||
defer si.mu.Unlock()
|
|
||||||
si.mu.Lock()
|
|
||||||
return si.String()
|
|
||||||
}
|
|
148
tensor/module.go
148
tensor/module.go
|
@ -1,7 +1,5 @@
|
||||||
package tensor
|
package tensor
|
||||||
|
|
||||||
import "github.com/sugarme/gotch"
|
|
||||||
|
|
||||||
// Module interface is a container with only one method `Forward`
|
// Module interface is a container with only one method `Forward`
|
||||||
//
|
//
|
||||||
// The following is `module` concept from Pytorch documenation:
|
// The following is `module` concept from Pytorch documenation:
|
||||||
|
@ -52,108 +50,50 @@ type ModuleT interface {
|
||||||
* }
|
* }
|
||||||
* */
|
* */
|
||||||
|
|
||||||
// BatchAccuracyForLigits calculate accuracy in batch.
|
// NOTE: this func has been moved to `nn/sequential` as `NoGradGuard`
|
||||||
//
|
// seem not working in Go and the function needs to add varstore variable
|
||||||
// TODO: It would be nice if it is one method an object that implements ModuleT
|
// parameter. Hence, it is moved to `nn` to avoid cycle reference.
|
||||||
// interface.
|
/*
|
||||||
func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
* // BatchAccuracyForLigits calculate accuracy in batch.
|
||||||
|
* //
|
||||||
var (
|
* // TODO: It would be nice if it is one method an object that implements ModuleT
|
||||||
sumAccuracy float64 = 0.0
|
* // interface.
|
||||||
sampleCount float64 = 0.0
|
* func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||||
)
|
*
|
||||||
|
* var (
|
||||||
_ = MustGradSetEnabled(false)
|
* sumAccuracy float64 = 0.0
|
||||||
|
* sampleCount float64 = 0.0
|
||||||
iter2 := MustNewIter2(xs, ys, int64(batchSize))
|
* )
|
||||||
for {
|
*
|
||||||
item, ok := iter2.Next()
|
* _ = MustGradSetEnabled(false)
|
||||||
if !ok {
|
*
|
||||||
break
|
* iter2 := MustNewIter2(xs, ys, int64(batchSize))
|
||||||
}
|
* for {
|
||||||
|
* item, ok := iter2.Next()
|
||||||
size := float64(item.Data.MustSize()[0])
|
* if !ok {
|
||||||
bImages := item.Data.MustTo(d, true)
|
* break
|
||||||
bLabels := item.Label.MustTo(d, true)
|
* }
|
||||||
|
*
|
||||||
logits := m.ForwardT(bImages, false)
|
* size := float64(item.Data.MustSize()[0])
|
||||||
acc := logits.AccuracyForLogits(bLabels)
|
* bImages := item.Data.MustTo(d, true)
|
||||||
sumAccuracy += acc.Values()[0] * size
|
* bLabels := item.Label.MustTo(d, true)
|
||||||
sampleCount += size
|
*
|
||||||
|
* logits := m.ForwardT(bImages, false)
|
||||||
bImages.MustDrop()
|
* acc := logits.AccuracyForLogits(bLabels)
|
||||||
bLabels.MustDrop()
|
* sumAccuracy += acc.Values()[0] * size
|
||||||
acc.MustDrop()
|
* sampleCount += size
|
||||||
}
|
*
|
||||||
|
* bImages.MustDrop()
|
||||||
_ = MustGradSetEnabled(true)
|
* bLabels.MustDrop()
|
||||||
|
* acc.MustDrop()
|
||||||
return sumAccuracy / sampleCount
|
* }
|
||||||
|
*
|
||||||
}
|
* _ = MustGradSetEnabled(true)
|
||||||
|
*
|
||||||
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
|
* return sumAccuracy / sampleCount
|
||||||
// calculate accuracy for specified batch on module weight. It uses tensor
|
*
|
||||||
// indexing instead of Iter2
|
* }
|
||||||
func BatchAccuracyForLogitsIdx(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
* */
|
||||||
var (
|
|
||||||
sumAccuracy float64 = 0.0
|
|
||||||
sampleCount float64 = 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
// Switch Grad off
|
|
||||||
_ = NewNoGradGuard()
|
|
||||||
|
|
||||||
totalSize := xs.MustSize()[0]
|
|
||||||
samples := int(totalSize)
|
|
||||||
|
|
||||||
index := MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
|
||||||
imagesTs := xs.MustIndexSelect(0, index, false)
|
|
||||||
labelsTs := ys.MustIndexSelect(0, index, false)
|
|
||||||
|
|
||||||
batches := samples / batchSize
|
|
||||||
batchIndex := 0
|
|
||||||
|
|
||||||
for i := 0; i < batches; i++ {
|
|
||||||
start := batchIndex * batchSize
|
|
||||||
size := batchSize
|
|
||||||
if samples-start < batchSize {
|
|
||||||
// size = samples - start
|
|
||||||
break
|
|
||||||
}
|
|
||||||
batchIndex += 1
|
|
||||||
|
|
||||||
// Indexing
|
|
||||||
narrowIndex := NewNarrow(int64(start), int64(start+size))
|
|
||||||
bImages := imagesTs.Idx(narrowIndex)
|
|
||||||
bLabels := labelsTs.Idx(narrowIndex)
|
|
||||||
|
|
||||||
bImages = bImages.MustTo(d, true)
|
|
||||||
bLabels = bLabels.MustTo(d, true)
|
|
||||||
|
|
||||||
logits := m.ForwardT(bImages, true)
|
|
||||||
bAccuracy := logits.AccuracyForLogits(bLabels)
|
|
||||||
|
|
||||||
accuVal := bAccuracy.Values()[0]
|
|
||||||
bSamples := float64(xs.MustSize()[0])
|
|
||||||
sumAccuracy += accuVal * bSamples
|
|
||||||
sampleCount += bSamples
|
|
||||||
|
|
||||||
// Free up tensors on C memory
|
|
||||||
bImages.MustDrop()
|
|
||||||
bLabels.MustDrop()
|
|
||||||
// logits.MustDrop()
|
|
||||||
bAccuracy.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
imagesTs.MustDrop()
|
|
||||||
labelsTs.MustDrop()
|
|
||||||
|
|
||||||
// Switch Grad on
|
|
||||||
// _ = MustGradSetEnabled(true)
|
|
||||||
|
|
||||||
return sumAccuracy / sampleCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// Tensor methods for Module and ModuleT:
|
// Tensor methods for Module and ModuleT:
|
||||||
// ======================================
|
// ======================================
|
||||||
|
|
Loading…
Reference in New Issue
Block a user