diff --git a/example/mnist/linear.go b/example/mnist/linear.go index 2bcff1c..cceabb8 100644 --- a/example/mnist/linear.go +++ b/example/mnist/linear.go @@ -13,7 +13,7 @@ const ( Label int64 = 10 MnistDir string = "../../data/mnist" - epochs = 200 + epochs = 500 ) func runLinear() { @@ -28,22 +28,23 @@ func runLinear() { for epoch := 0; epoch < epochs; epoch++ { - logits := ds.TrainImages.MustMm(ws).MustAdd(bs) - loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels) + logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true) + loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, true) ws.ZeroGrad() bs.ZeroGrad() loss.MustBackward() ts.NoGrad(func() { - ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))) - bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))) + ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0), true)) + bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0), true)) }) - testLogits := ds.TestImages.MustMm(ws).MustAdd(bs) - testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0}) + testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true) + testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0}) fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100) + loss.MustDrop() } } diff --git a/example/mnist/nn.go b/example/mnist/nn.go index c7ccb60..264ae6c 100644 --- a/example/mnist/nn.go +++ b/example/mnist/nn.go @@ -16,8 +16,7 @@ const ( LabelNN int64 = 10 MnistDirNN string = "../../data/mnist" - epochsNN = 50 - batchSizeNN = 256 + epochsNN = 200 LrNN = 1e-3 ) @@ -27,12 +26,10 @@ var l nn.Linear func netInit(vs nn.Path) ts.Module { n := nn.Seq() - l = nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()) + n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())) - n.Add(l) - - n.AddFn(nn.ForwardWith(func(xs ts.Tensor) ts.Tensor { - return xs.MustRelu() + n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor { + return xs.MustRelu(true) })) n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig())) @@ -45,9 +42,11 @@ func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer opt.BackwardStep(loss) - testAccuracy := m.Forward(testX).AccuracyForLogits(testY).Values()[0] - fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100) + testAccuracy := m.Forward(testX).AccuracyForLogits(testY) + fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy.Values()[0]*100) + loss.MustDrop() + testAccuracy.MustDrop() } func runNN() { @@ -62,9 +61,7 @@ func runNN() { } for epoch := 0; epoch < epochsNN; epoch++ { - train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch) - } } diff --git a/example/tensor-memory/main.go b/example/tensor-memory/main.go new file mode 100644 index 0000000..51e0e61 --- /dev/null +++ b/example/tensor-memory/main.go @@ -0,0 +1,74 @@ +package main + +import ( + "fmt" + // "runtime" + + ts "github.com/sugarme/gotch/tensor" +) + +func createTensors(samples int) []ts.Tensor { + n := int(10e6) + var data []float64 + for i := 0; i < n; i++ { + data = append(data, float64(i)) + } + + var tensors []ts.Tensor + s := ts.FloatScalar(float64(0.23)) + + // for i := 0; i < samples; i++ { + for i := 0; i < 1; i++ { + t := ts.MustOfSlice(data).MustMul1(s, true) + + // t1.MustDrop() + // t.MustDrop() + // t1 = ts.Tensor{} + // t = ts.Tensor{} + // runtime.GC() + + // fmt.Printf("t values: %v", t.Values()) + // fmt.Printf("t1 values: %v", t1.Values()) + tensors = append(tensors, t) + } + + return tensors +} + +func dropTensors(tensors []ts.Tensor) { + for _, t := range tensors { + t.MustDrop() + } +} + +func main() { + + var si *SI + si = Get() + fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024) + fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024) + + startRAM := si.TotalRam - si.FreeRam + + epochs := 50 + // var m runtime.MemStats + + for i := 0; i < epochs; i++ { + // runtime.ReadMemStats(&m) + // t0 := float64(m.Sys) / 1024 / 1024 + + tensors := createTensors(10000) + + // runtime.ReadMemStats(&m) + // t1 := float64(m.Sys) / 1024 / 1024 + + dropTensors(tensors) + + // runtime.ReadMemStats(&m) + // t2 := float64(m.Sys) / 1024 / 1024 + + // fmt.Printf("Epoch: %v \t Start Mem [%.3f MiB] \t Alloc Mem [%.3f MiB] \t Free Mem [%.3f MiB]\n", i, t0, t1, t2) + si = Get() + fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", i, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024) + } +} diff --git a/example/tensor-memory/sysinfo.go b/example/tensor-memory/sysinfo.go new file mode 100644 index 0000000..dad1412 --- /dev/null +++ b/example/tensor-memory/sysinfo.go @@ -0,0 +1,124 @@ +// A wrapper around the linux syscall sysinfo(2). +package main + +import ( + "fmt" + "sync" + "syscall" + "time" +) + +// Go-ized http://man7.org/linux/man-pages/man2/sysinfo.2.html +type SI struct { + Uptime time.Duration // time since boot + Loads [3]float64 // 1, 5, and 15 minute load averages, see e.g. UPTIME(1) + Procs uint64 // number of current processes + TotalRam uint64 // total usable main memory size [kB] + FreeRam uint64 // available memory size [kB] + SharedRam uint64 // amount of shared memory [kB] + BufferRam uint64 // memory used by buffers [kB] + TotalSwap uint64 // total swap space size [kB] + FreeSwap uint64 // swap space still available [kB] + TotalHighRam uint64 // total high memory size [kB] + FreeHighRam uint64 // available high memory size [kB] + mu sync.Mutex // ensures atomic writes; protects the following fields +} + +var sis = &SI{} + +// Get the linux sysinfo data structure. +// +// Useful links in the wild web: +// http://man7.org/linux/man-pages/man2/sysinfo.2.html +// http://man7.org/linux/man-pages/man1/uptime.1.html +// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050 +// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528 +// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502 +func Get() *SI { + + /* + // Note: uint64 is uint32 on 32 bit CPUs + type Sysinfo_t struct { + Uptime int64 // Seconds since boot + Loads [3]uint64 // 1, 5, and 15 minute load averages + Totalram uint64 // Total usable main memory size + Freeram uint64 // Available memory size + Sharedram uint64 // Amount of shared memory + Bufferram uint64 // Memory used by buffers + Totalswap uint64 // Total swap space size + Freeswap uint64 // swap space still available + Procs uint16 // Number of current processes + Pad uint16 + Pad_cgo_0 [4]byte + Totalhigh uint64 // Total high memory size + Freehigh uint64 // Available high memory size + Unit uint32 // Memory unit size in bytes + X_f [0]byte + Pad_cgo_1 [4]byte // Padding to 64 bytes + } + */ + + // ~1kB garbage + si := &syscall.Sysinfo_t{} + + // XXX is a raw syscall thread safe? + err := syscall.Sysinfo(si) + if err != nil { + panic("Commander, we have a problem. syscall.Sysinfo:" + err.Error()) + } + scale := 65536.0 // magic + + defer sis.mu.Unlock() + sis.mu.Lock() + + unit := uint64(si.Unit) * 1024 // kB + + sis.Uptime = time.Duration(si.Uptime) * time.Second + sis.Loads[0] = float64(si.Loads[0]) / scale + sis.Loads[1] = float64(si.Loads[1]) / scale + sis.Loads[2] = float64(si.Loads[2]) / scale + sis.Procs = uint64(si.Procs) + + sis.TotalRam = uint64(si.Totalram) / unit + sis.FreeRam = uint64(si.Freeram) / unit + sis.BufferRam = uint64(si.Bufferram) / unit + sis.TotalSwap = uint64(si.Totalswap) / unit + sis.FreeSwap = uint64(si.Freeswap) / unit + sis.TotalHighRam = uint64(si.Totalhigh) / unit + sis.FreeHighRam = uint64(si.Freehigh) / unit + + return sis +} + +// Make the "fmt" Stringer interface happy. +func (si SI) String() string { + // XXX: Is the copy of SI done atomic? Not sure. + // Without an outer lock this may print a junk. + return fmt.Sprintf("uptime\t\t%v\nload\t\t%2.2f %2.2f %2.2f\nprocs\t\t%d\n"+ + "ram total\t%d kB\nram free\t%d kB\nram buffer\t%d kB\n"+ + "swap total\t%d kB\nswap free\t%d kB", + //"high ram total\t%d kB\nhigh ram free\t%d kB\n" + si.Uptime, si.Loads[0], si.Loads[1], si.Loads[2], si.Procs, + si.TotalRam, si.FreeRam, si.BufferRam, + si.TotalSwap, si.FreeSwap, + // archaic si.TotalHighRam, si.FreeHighRam + ) +} + +/* +Convert to string in a thread safe way. + Output: + uptime 279h6m21s + load 0.12 0.04 0.05 + procs 143 + ram total 383752 kB + ram free 254980 kB + ram buffer 7640 kB + swap total 887800 kB + swap free 879356 kB +*/ +func (si *SI) ToString() string { + defer si.mu.Unlock() + si.mu.Lock() + return si.String() +} diff --git a/libtch/tensor.go b/libtch/tensor.go index d48678e..b360b3f 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -126,29 +126,28 @@ func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) i // int at_requires_grad(tensor); func AtRequiresGrad(ts Ctensor) bool { - retVal := C.at_requires_grad((C.tensor)(ts)) + retVal := C.at_requires_grad(ts) return *(*bool)(unsafe.Pointer(&retVal)) } // int at_defined(tensor); func AtDefined(ts Ctensor) bool { - retVal := C.at_defined((C.tensor)(ts)) + retVal := C.at_defined(ts) return *(*bool)(unsafe.Pointer(&retVal)) } // int at_is_sparse(tensor); func AtIsSparse(ts Ctensor) bool { - retVal := C.at_is_sparse((C.tensor)(ts)) + retVal := C.at_is_sparse(ts) return *(*bool)(unsafe.Pointer(&retVal)) } // void at_backward(tensor, int, int); func AtBackward(ts Ctensor, keepGraph int, createGraph int) { - ctensor := (C.tensor)(ts) ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph)) ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph)) - C.at_backward(ctensor, ckeepGraph, ccreateGraph) + C.at_backward(ts, ckeepGraph, ccreateGraph) } /* @@ -169,39 +168,33 @@ func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninput } // void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes); -func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) { - ctensor := (C.tensor)(tensor) +func AtCopyData(ts Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) { cnumel := *(*C.size_t)(unsafe.Pointer(&numel)) celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes)) - C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes) + C.at_copy_data(ts, vs, cnumel, celement_size_in_bytes) } // tensor at_shallow_clone(tensor); func AtShallowClone(ts Ctensor) Ctensor { - ctensor := (C.tensor)(ts) - return C.at_shallow_clone(ctensor) + return C.at_shallow_clone(ts) } // tensor at_get(tensor, int index); func AtGet(ts Ctensor, index int) Ctensor { - ctensor := (C.tensor)(ts) cindex := *(*C.int)(unsafe.Pointer(&index)) - return C.at_get(ctensor, cindex) + return C.at_get(ts, cindex) } // void at_copy_(tensor dst, tensor src); func AtCopy_(dst Ctensor, src Ctensor) { - cdst := (C.tensor)(dst) - csrc := (C.tensor)(src) - C.at_copy_(cdst, csrc) + C.at_copy_(dst, src) } // void at_save(tensor, char *filename); func AtSave(ts Ctensor, path string) { - ctensor := (C.tensor)(ts) cstringPtr := C.CString(path) defer C.free(unsafe.Pointer(cstringPtr)) - C.at_save(ctensor, cstringPtr) + C.at_save(ts, cstringPtr) } // tensor at_load(char *filename); @@ -300,9 +293,8 @@ func AtLoadCallbackWithDevice(filename string, dataPtr unsafe.Pointer, device in * } * */ func AtToString(ts Ctensor, lineSize int64) string { - ctensor := (C.tensor)(ts) clineSize := *(*C.int)(unsafe.Pointer(&lineSize)) - charPtr := C.at_to_string(ctensor, clineSize) + charPtr := C.at_to_string(ts, clineSize) goString := C.GoString(charPtr) return goString @@ -310,8 +302,7 @@ func AtToString(ts Ctensor, lineSize int64) string { // void at_free(tensor); func AtFree(ts Ctensor) { - ctensor := (C.tensor)(ts) - C.at_free(ctensor) + C.at_free(ts) } //int at_grad_set_enabled(int b); diff --git a/nn/linear.go b/nn/linear.go index 4c647ff..8fa27d1 100644 --- a/nn/linear.go +++ b/nn/linear.go @@ -90,6 +90,6 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear { // 1 1 1 // 1 1 1 ] func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) { - - return xs.MustMatMul(l.Ws.MustT()).MustAdd(l.Bs) + clone := l.Ws.MustShallowClone().MustT(true) + return xs.MustMm(clone, false).MustAdd(l.Bs, true) } diff --git a/nn/sequential.go b/nn/sequential.go index b4f50cc..ad0c21c 100644 --- a/nn/sequential.go +++ b/nn/sequential.go @@ -78,12 +78,11 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) { } // forward sequentially - var currTs ts.Tensor = xs for i := 0; i < len(s.layers); i++ { - currTs = s.layers[i].Forward(currTs) + xs = s.layers[i].Forward(xs) } - return currTs + return xs } // SequentialT is a sequential layer combining new layers with support for a training mode. @@ -136,7 +135,6 @@ func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { // Add appends a layer after all the current layers. func (s *SequentialT) Add(l ts.ModuleT) { - s.layers = append(s.layers, l) } diff --git a/nn/varstore.go b/nn/varstore.go index 4f53359..7577279 100644 --- a/nn/varstore.go +++ b/nn/varstore.go @@ -274,7 +274,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) { for k, v := range vs.Vars.NamedVariables { srcTs, _ := srcNamedVariables[k] - srcDevTs, err := srcTs.To(device) + srcDevTs, err := srcTs.To(device, false) if err != nil { return err } diff --git a/tensor/data.go b/tensor/data.go index ac3d168..173e35f 100644 --- a/tensor/data.go +++ b/tensor/data.go @@ -84,8 +84,8 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) { func (it Iter2) Shuffle() (retVal Iter2) { index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU) - it.xs = it.xs.MustIndexSelect(0, index) - it.ys = it.ys.MustIndexSelect(0, index) + it.xs = it.xs.MustIndexSelect(0, index, true) + it.ys = it.ys.MustIndexSelect(0, index, true) return it } diff --git a/tensor/index.go b/tensor/index.go index 5007623..459c8b4 100644 --- a/tensor/index.go +++ b/tensor/index.go @@ -252,14 +252,14 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) { switch reflect.TypeOf(spec).Name() { case "InsertNewAxis": - nextTensor, err = currTensor.Unsqueeze(currIdx) + nextTensor, err = currTensor.Unsqueeze(currIdx, true) if err != nil { return retVal, err } nextIdx = currIdx + 1 case "Select": // 1 field: `Index` index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64) - nextTensor, err = currTensor.Select(currIdx, index) // TODO: double-check is `*index` or `index` + nextTensor, err = currTensor.Select(currIdx, index, true) // TODO: double-check is `*index` or `index` if err != nil { return retVal, err } @@ -269,7 +269,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) { // NOTE: for now, just implement (Included(start), Excluded(end))` case start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64) end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64) - nextTensor, err = currTensor.Narrow(currIdx, start, end-start) + nextTensor, err = currTensor.Narrow(currIdx, start, end-start, true) if err != nil { return retVal, err } @@ -280,11 +280,11 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) { if err != nil { return retVal, err } - indexTensor, err = indexTensor.To(device) + indexTensor, err = indexTensor.To(device, true) if err != nil { return retVal, err } - nextTensor, err = currTensor.IndexSelect(currIdx, indexTensor) + nextTensor, err = currTensor.IndexSelect(currIdx, indexTensor, true) if err != nil { return retVal, err } diff --git a/tensor/module.go b/tensor/module.go index 4dd8b2b..a139ef4 100644 --- a/tensor/module.go +++ b/tensor/module.go @@ -72,7 +72,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize break } - acc := m.ForwardT(item.Data.MustTo(d), false).AccuracyForLogits(item.Label.MustTo(d)).MustView([]int64{-1}).MustFloat64Value([]int64{0}) + acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}).MustFloat64Value([]int64{0}) size := float64(item.Data.MustSize()[0]) sumAccuracy += acc * size sampleCount += size diff --git a/tensor/other.go b/tensor/other.go index 0c0f9a3..119d258 100644 --- a/tensor/other.go +++ b/tensor/other.go @@ -8,13 +8,13 @@ import ( // CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets. func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) { - return ts.MustLogSoftmax(-1, gotch.Float.CInt()).MustNllLoss(targets) + return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true) } // AccuracyForLogits returns the average accuracy for some given logits assuming that // targets represent ground-truth. func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) { - return ts.MustArgmax(-1, false).MustEq1(targets).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()) + return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true) } // TODO: continue diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go index 33da307..0edb6d6 100644 --- a/tensor/tensor-generated-sample.go +++ b/tensor/tensor-generated-sample.go @@ -12,12 +12,15 @@ import ( lib "github.com/sugarme/gotch/libtch" ) -func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) { +func (ts Tensor) To(device gotch.Device, del bool) (retVal Tensor, err error) { // TODO: how to get pointer to CUDA memory??? // C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + + if del { + defer ts.MustDrop() + } lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt())) @@ -28,9 +31,9 @@ func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } -func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) { +func (ts Tensor) MustTo(device gotch.Device, del bool) (retVal Tensor) { var err error - retVal, err = ts.To(device) + retVal, err = ts.To(device, del) if err != nil { log.Fatal(err) } @@ -38,9 +41,11 @@ func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) { return retVal } -func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) { +func (ts Tensor) Matmul(other Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgMatmul(ptr, ts.ctensor, other.ctensor) if err = TorchErr(); err != nil { @@ -50,8 +55,8 @@ func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } -func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) { - retVal, err := ts.Matmul(other) +func (ts Tensor) MustMatMul(other Tensor, del bool) (retVal Tensor) { + retVal, err := ts.Matmul(other, del) if err != nil { log.Fatal(err) } @@ -61,7 +66,6 @@ func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) { func (ts Tensor) Grad() (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgGrad(ptr, ts.ctensor) if err = TorchErr(); err != nil { @@ -82,7 +86,6 @@ func (ts Tensor) MustGrad() (retVal Tensor) { func (ts Tensor) Detach_() { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgDetach_(ptr, ts.ctensor) if err := TorchErr(); err != nil { @@ -92,7 +95,6 @@ func (ts Tensor) Detach_() { func (ts Tensor) Zero_() { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgZero_(ptr, ts.ctensor) if err := TorchErr(); err != nil { @@ -107,7 +109,6 @@ func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) { } ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgSetRequiresGrad(ptr, ts.ctensor, r) @@ -127,9 +128,11 @@ func (ts Tensor) MustSetRequiresGrad(rb bool) (retVal Tensor) { return retVal } -func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) { +func (ts Tensor) Mul(other Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgMul(ptr, ts.ctensor, other.ctensor) if err = TorchErr(); err != nil { @@ -139,8 +142,8 @@ func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } -func (ts Tensor) MustMul(other Tensor) (retVal Tensor) { - retVal, err := ts.Mul(other) +func (ts Tensor) MustMul(other Tensor, del bool) (retVal Tensor) { + retVal, err := ts.Mul(other, del) if err != nil { log.Fatal(err) } @@ -148,9 +151,11 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) { return retVal } -func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) { +func (ts Tensor) Mul1(other Scalar, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgMul1(ptr, ts.ctensor, other.cscalar) if err = TorchErr(); err != nil { @@ -160,8 +165,8 @@ func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } -func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) { - retVal, err := ts.Mul1(other) +func (ts Tensor) MustMul1(other Scalar, del bool) (retVal Tensor) { + retVal, err := ts.Mul1(other, del) if err != nil { log.Fatal(err) } @@ -171,7 +176,6 @@ func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) { func (ts Tensor) Mul_(other Tensor) (err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgMul_(ptr, ts.ctensor, other.ctensor) if err = TorchErr(); err != nil { @@ -188,9 +192,11 @@ func (ts Tensor) MustMul_(other Tensor) { } } -func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) { +func (ts Tensor) Add(other Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgAdd(ptr, ts.ctensor, other.ctensor) if err = TorchErr(); err != nil { @@ -200,8 +206,8 @@ func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } -func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) { - retVal, err := ts.Add(other) +func (ts Tensor) MustAdd(other Tensor, del bool) (retVal Tensor) { + retVal, err := ts.Add(other, del) if err != nil { log.Fatal(err) } @@ -211,7 +217,6 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) { func (ts Tensor) Add_(other Tensor) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgAdd_(ptr, ts.ctensor, other.ctensor) if err := TorchErr(); err != nil { @@ -219,9 +224,11 @@ func (ts Tensor) Add_(other Tensor) { } } -func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) { +func (ts Tensor) Add1(other Scalar, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgAdd1(ptr, ts.ctensor, other.cscalar) if err = TorchErr(); err != nil { @@ -231,8 +238,8 @@ func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } -func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) { - retVal, err := ts.Add1(other) +func (ts Tensor) MustAdd1(other Scalar, del bool) (retVal Tensor) { + retVal, err := ts.Add1(other, del) if err != nil { log.Fatal(err) @@ -243,7 +250,6 @@ func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) { func (ts Tensor) AddG(other Tensor) (err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgAdd(ptr, ts.ctensor, other.ctensor) if err = TorchErr(); err != nil { @@ -263,9 +269,11 @@ func (ts Tensor) MustAddG(other Tensor) { } // Totype casts type of tensor to a new tensor with specified DType -func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) { +func (ts Tensor) Totype(dtype gotch.DType, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } cint, err := gotch.DType2CInt(dtype) if err != nil { return retVal, err @@ -283,8 +291,8 @@ func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) { // Totype casts type of tensor to a new tensor with specified DType. It will // panic if error -func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) { - retVal, err := ts.Totype(dtype) +func (ts Tensor) MustTotype(dtype gotch.DType, del bool) (retVal Tensor) { + retVal, err := ts.Totype(dtype, del) if err != nil { log.Fatal(err) } @@ -293,10 +301,11 @@ func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) { } // Unsqueeze unsqueezes tensor to specified dimension. -func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) { +func (ts Tensor) Unsqueeze(dim int64, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) - + if del { + defer ts.MustDrop() + } lib.AtgUnsqueeze(ptr, ts.ctensor, dim) if err = TorchErr(); err != nil { return retVal, err @@ -308,10 +317,11 @@ func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) { } // Select creates a new tensor from current tensor given dim and index. -func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) { +func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) - + if del { + defer ts.MustDrop() + } lib.AtgSelect(ptr, ts.ctensor, dim, index) if err = TorchErr(); err != nil { return retVal, err @@ -324,9 +334,11 @@ func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) { // Narrow creates a new tensor from current tensor given dim and start index // and length. -func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, err error) { +func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgNarrow(ptr, ts.ctensor, dim, start, length) if err = TorchErr(); err != nil { @@ -340,9 +352,11 @@ func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, er // IndexSelect creates a new tensor from current tensor given dim and index // tensor. -func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) { +func (ts Tensor) IndexSelect(dim int64, index Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgIndexSelect(ptr, ts.ctensor, dim, index.ctensor) if err = TorchErr(); err != nil { @@ -353,8 +367,8 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) return retVal, nil } -func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) { - retVal, err := ts.IndexSelect(dim, index) +func (ts Tensor) MustIndexSelect(dim int64, index Tensor, del bool) (retVal Tensor) { + retVal, err := ts.IndexSelect(dim, index, del) if err != nil { log.Fatal(err) } @@ -364,7 +378,6 @@ func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) { func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgZeros(ptr, size, len(size), optionsKind, optionsDevice) if err = TorchErr(); err != nil { @@ -386,7 +399,6 @@ func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) { func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgOnes(ptr, size, len(size), optionsKind, optionsDevice) if err = TorchErr(); err != nil { @@ -410,7 +422,6 @@ func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) { func (ts Tensor) Uniform_(from float64, to float64) { var err error ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgUniform_(ptr, ts.ctensor, from, to) if err = TorchErr(); err != nil { @@ -418,10 +429,11 @@ func (ts Tensor) Uniform_(from float64, to float64) { } } -func (ts Tensor) ZerosLike() (retVal Tensor, err error) { +func (ts Tensor) ZerosLike(del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) - + if del { + defer ts.MustDrop() + } lib.AtgZerosLike(ptr, ts.ctensor) if err = TorchErr(); err != nil { return retVal, err @@ -435,7 +447,6 @@ func (ts Tensor) ZerosLike() (retVal Tensor, err error) { func (ts Tensor) Fill_(value Scalar) { var err error ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgFill_(ptr, ts.ctensor, value.cscalar) if err = TorchErr(); err != nil { @@ -443,10 +454,11 @@ func (ts Tensor) Fill_(value Scalar) { } } -func (ts Tensor) RandnLike() (retVal Tensor, err error) { +func (ts Tensor) RandnLike(del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) - + if del { + defer ts.MustDrop() + } lib.AtgRandnLike(ptr, ts.ctensor) if err = TorchErr(); err != nil { return retVal, err @@ -457,10 +469,11 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) { +func (ts Tensor) Permute(dims []int64, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) - + if del { + defer ts.MustDrop() + } lib.AtgPermute(ptr, ts.ctensor, dims, len(dims)) if err = TorchErr(); err != nil { @@ -472,9 +485,11 @@ func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) { +func (ts Tensor) Squeeze1(dim int64, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgSqueeze1(ptr, ts.ctensor, dim) @@ -487,9 +502,9 @@ func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) { +func (ts Tensor) MustSqueeze1(dim int64, del bool) (retVal Tensor) { var err error - retVal, err = ts.Squeeze1(dim) + retVal, err = ts.Squeeze1(dim, del) if err != nil { log.Fatal(err) } @@ -527,9 +542,11 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) { +func (ts Tensor) Mm(mat2 Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgMm(ptr, ts.ctensor, mat2.ctensor) if err = TorchErr(); err != nil { @@ -541,8 +558,8 @@ func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) { - retVal, err := ts.Mm(mat2) +func (ts Tensor) MustMm(mat2 Tensor, del bool) (retVal Tensor) { + retVal, err := ts.Mm(mat2, del) if err != nil { log.Fatal(err) } @@ -550,9 +567,11 @@ func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) { return retVal } -func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) { +func (ts Tensor) LogSoftmax(dim int64, dtype int32, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype) if err = TorchErr(); err != nil { @@ -564,8 +583,8 @@ func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) { - retVal, err := ts.LogSoftmax(dim, dtype) +func (ts Tensor) MustLogSoftmax(dim int64, dtype int32, del bool) (retVal Tensor) { + retVal, err := ts.LogSoftmax(dim, dtype, del) if err != nil { log.Fatal(err) } @@ -573,10 +592,11 @@ func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) { return retVal } -func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) { +func (ts Tensor) NllLoss(target Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - // NOTE: uncomment this causes panic - // defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } weight := NewTensor() @@ -594,8 +614,8 @@ func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) { - retVal, err := ts.NllLoss(target) +func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) { + retVal, err := ts.NllLoss(target, del) if err != nil { log.Fatal(err) } @@ -603,9 +623,11 @@ func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) { return retVal } -func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) { +func (ts Tensor) Argmax(dim int64, keepDim bool, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } var ckeepDim int = 0 if keepDim { @@ -622,8 +644,8 @@ func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) { - retVal, err := ts.Argmax(dim, keepDim) +func (ts Tensor) MustArgmax(dim int64, keepDim bool, del bool) (retVal Tensor) { + retVal, err := ts.Argmax(dim, keepDim, del) if err != nil { log.Fatal(err) } @@ -631,9 +653,11 @@ func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) { return retVal } -func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) { +func (ts Tensor) Mean(dtype int32, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgMean(ptr, ts.ctensor, dtype) if err = TorchErr(); err != nil { @@ -645,8 +669,8 @@ func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustMean(dtype int32) (retVal Tensor) { - retVal, err := ts.Mean(dtype) +func (ts Tensor) MustMean(dtype int32, del bool) (retVal Tensor) { + retVal, err := ts.Mean(dtype, del) if err != nil { log.Fatal(err) } @@ -677,9 +701,11 @@ func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) { return retVal } -func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) { +func (ts Tensor) Div1(other Scalar, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgDiv1(ptr, ts.ctensor, other.cscalar) if err = TorchErr(); err != nil { @@ -691,8 +717,8 @@ func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) { - retVal, err := ts.Div1(other) +func (ts Tensor) MustDiv1(other Scalar, del bool) (retVal Tensor) { + retVal, err := ts.Div1(other, del) if err != nil { log.Fatal(err) } @@ -702,7 +728,6 @@ func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) { func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgRandperm(ptr, n, optionKind.CInt(), optionDevice.CInt()) if err = TorchErr(); err != nil { @@ -725,7 +750,6 @@ func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (r func (ts Tensor) Clamp_(min Scalar, max Scalar) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar) if err := TorchErr(); err != nil { @@ -735,7 +759,6 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) { func (ts Tensor) Relu_() { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) lib.AtgRelu_(ptr, ts.ctensor) if err := TorchErr(); err != nil { @@ -743,9 +766,11 @@ func (ts Tensor) Relu_() { } } -func (ts Tensor) Relu() (retVal Tensor, err error) { +func (ts Tensor) Relu(del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgRelu(ptr, ts.ctensor) err = TorchErr() @@ -758,8 +783,8 @@ func (ts Tensor) Relu() (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustRelu() (retVal Tensor) { - retVal, err := ts.Relu() +func (ts Tensor) MustRelu(del bool) (retVal Tensor) { + retVal, err := ts.Relu(del) if err != nil { log.Fatal(err) } @@ -767,9 +792,11 @@ func (ts Tensor) MustRelu() (retVal Tensor) { return retVal } -func (ts Tensor) T() (retVal Tensor, err error) { +func (ts Tensor) T(del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgT(ptr, ts.ctensor) err = TorchErr() @@ -782,8 +809,8 @@ func (ts Tensor) T() (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustT() (retVal Tensor) { - retVal, err := ts.T() +func (ts Tensor) MustT(del bool) (retVal Tensor) { + retVal, err := ts.T(del) if err != nil { log.Fatal(err) } @@ -802,10 +829,12 @@ func (ts Tensor) T_() { } } -func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error) { +func (ts Tensor) MseLoss(target Tensor, reduction int, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgMseLoss(ptr, ts.ctensor, target.ctensor, reduction) err = TorchErr() @@ -818,8 +847,8 @@ func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error return retVal, nil } -func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) { - retVal, err := ts.MseLoss(target, reduction) +func (ts Tensor) MustMseLoss(target Tensor, reduction int, del bool) (retVal Tensor) { + retVal, err := ts.MseLoss(target, reduction, del) if err != nil { log.Fatal(err) @@ -828,9 +857,11 @@ func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) { return retVal } -func (ts Tensor) Exp() (retVal Tensor, err error) { +func (ts Tensor) Exp(del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgExp(ptr, ts.ctensor) err = TorchErr() @@ -843,8 +874,8 @@ func (ts Tensor) Exp() (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustExp() (retVal Tensor) { - retVal, err := ts.Exp() +func (ts Tensor) MustExp(del bool) (retVal Tensor) { + retVal, err := ts.Exp(del) if err != nil { log.Fatal(err) @@ -864,9 +895,11 @@ func (ts Tensor) Exp_() { } } -func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) { +func (ts Tensor) Pow(exponent Scalar, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgPow(ptr, ts.ctensor, exponent.cscalar) err = TorchErr() @@ -879,8 +912,8 @@ func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) { - retVal, err := ts.Pow(exponent) +func (ts Tensor) MustPow(exponent Scalar, del bool) (retVal Tensor) { + retVal, err := ts.Pow(exponent, del) if err != nil { log.Fatal(err) @@ -889,9 +922,11 @@ func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) { return retVal } -func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) { +func (ts Tensor) Sum(dtype int32, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgSum(ptr, ts.ctensor, dtype) err = TorchErr() @@ -904,8 +939,8 @@ func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustSum(dtype int32) (retVal Tensor) { - retVal, err := ts.Sum(dtype) +func (ts Tensor) MustSum(dtype int32, del bool) (retVal Tensor) { + retVal, err := ts.Sum(dtype, del) if err != nil { log.Fatal(err) @@ -914,9 +949,11 @@ func (ts Tensor) MustSum(dtype int32) (retVal Tensor) { return retVal } -func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) { +func (ts Tensor) Sub(other Tensor, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgSub(ptr, ts.ctensor, other.ctensor) err = TorchErr() @@ -929,8 +966,8 @@ func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustSub(other Tensor) (retVal Tensor) { - retVal, err := ts.Sub(other) +func (ts Tensor) MustSub(other Tensor, del bool) (retVal Tensor) { + retVal, err := ts.Sub(other, del) if err != nil { log.Fatal(err) @@ -939,9 +976,11 @@ func (ts Tensor) MustSub(other Tensor) (retVal Tensor) { return retVal } -func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) { +func (ts Tensor) Sub1(other Scalar, del bool) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - defer C.free(unsafe.Pointer(ptr)) + if del { + defer ts.MustDrop() + } lib.AtgSub1(ptr, ts.ctensor, other.cscalar) err = TorchErr() @@ -954,8 +993,8 @@ func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustSub1(other Scalar) (retVal Tensor) { - retVal, err := ts.Sub1(other) +func (ts Tensor) MustSub1(other Scalar, del bool) (retVal Tensor) { + retVal, err := ts.Sub1(other, del) if err != nil { log.Fatal(err) diff --git a/tensor/tensor.go b/tensor/tensor.go index 020dfb1..05fe028 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -2,6 +2,7 @@ package tensor //#include "stdlib.h" //#include "stdbool.h" +//#include import "C" import ( diff --git a/vision/image.go b/vision/image.go index 16f8734..a1ab850 100644 --- a/vision/image.go +++ b/vision/image.go @@ -16,7 +16,7 @@ import ( // (height, width, channel) -> (channel, height, width) func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) { var err error - retVal, err = tensor.Permute([]int64{2, 0, 1}) + retVal, err = tensor.Permute([]int64{2, 0, 1}, true) if err != nil { log.Fatalf("hwcToCHW error: %v\n", err) } @@ -25,7 +25,7 @@ func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) { func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) { var err error - retVal, err = tensor.Permute([]int64{1, 2, 0}) + retVal, err = tensor.Permute([]int64{1, 2, 0}, true) if err != nil { log.Fatalf("hwcToCHW error: %v\n", err) } @@ -54,7 +54,7 @@ func Load(path string) (retVal ts.Tensor, err error) { // The tensor input should be of kind UInt8 with values ranging from // 0 to 255. func Save(tensor ts.Tensor, path string) (err error) { - t, err := tensor.Totype(gotch.Uint8) + t, err := tensor.Totype(gotch.Uint8, true) if err != nil { err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err) return err @@ -68,9 +68,9 @@ func Save(tensor ts.Tensor, path string) (err error) { switch { case len(shape) == 4 && shape[0] == 1: - return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0)).MustTo(gotch.CPU)), path) + return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0), true).MustTo(gotch.CPU, true)), path) case len(shape) == 3: - return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path) + return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU, true)), path) default: err = fmt.Errorf("Unexpected size (%v) for image tensor.\n", len(shape)) return err @@ -125,7 +125,7 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t var tensorW ts.Tensor var tensorH ts.Tensor if resizeW != outW { - tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW) + tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW, true) if err != nil { err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err) return retVal, err @@ -135,7 +135,7 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t if resizeH == outH { retVal = tensorW } else { - tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH) + tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH, true) if err != nil { err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err) return retVal, err diff --git a/vision/mnist.go b/vision/mnist.go index 2f711e4..c326a6c 100644 --- a/vision/mnist.go +++ b/vision/mnist.go @@ -82,7 +82,7 @@ func readLabels(filename string) (retVal ts.Tensor) { log.Fatal(err) } - retVal = labelsTs.MustTotype(gotch.Int64) + retVal = labelsTs.MustTotype(gotch.Int64, true) return retVal } @@ -125,7 +125,7 @@ func readImages(filename string) (retVal ts.Tensor) { err = fmt.Errorf("create images tensor err.") log.Fatal(err) } - retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0)) + retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true) return retVal }