feat(tensor): add option to delete tensor after applying operator to free up C memory

This commit is contained in:
sugarme 2020-06-21 23:37:42 +10:00
parent 4ffe5feb7a
commit f36d2482a1
16 changed files with 414 additions and 189 deletions

View File

@ -13,7 +13,7 @@ const (
Label int64 = 10
MnistDir string = "../../data/mnist"
epochs = 200
epochs = 500
)
func runLinear() {
@ -28,22 +28,23 @@ func runLinear() {
for epoch := 0; epoch < epochs; epoch++ {
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, true)
ws.ZeroGrad()
bs.ZeroGrad()
loss.MustBackward()
ts.NoGrad(func() {
ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0), true))
bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0), true))
})
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
loss.MustDrop()
}
}

View File

@ -16,8 +16,7 @@ const (
LabelNN int64 = 10
MnistDirNN string = "../../data/mnist"
epochsNN = 50
batchSizeNN = 256
epochsNN = 200
LrNN = 1e-3
)
@ -27,12 +26,10 @@ var l nn.Linear
func netInit(vs nn.Path) ts.Module {
n := nn.Seq()
l = nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
n.Add(l)
n.AddFn(nn.ForwardWith(func(xs ts.Tensor) ts.Tensor {
return xs.MustRelu()
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
return xs.MustRelu(true)
}))
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
@ -45,9 +42,11 @@ func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer
opt.BackwardStep(loss)
testAccuracy := m.Forward(testX).AccuracyForLogits(testY).Values()[0]
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
testAccuracy := m.Forward(testX).AccuracyForLogits(testY)
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy.Values()[0]*100)
loss.MustDrop()
testAccuracy.MustDrop()
}
func runNN() {
@ -62,9 +61,7 @@ func runNN() {
}
for epoch := 0; epoch < epochsNN; epoch++ {
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
}
}

View File

@ -0,0 +1,74 @@
package main
import (
"fmt"
// "runtime"
ts "github.com/sugarme/gotch/tensor"
)
// createTensors builds tensors for the memory-leak experiment. Each tensor
// holds 10e6 sequential float64 values multiplied by the scalar 0.23, with
// the intermediate tensor dropped eagerly (del=true).
//
// NOTE: the samples argument is currently ignored — the full loop is
// disabled for the experiment and exactly one tensor is produced per call.
func createTensors(samples int) []ts.Tensor {
	const n = 10000000 // 10e6 elements
	data := make([]float64, 0, n)
	for v := 0; v < n; v++ {
		data = append(data, float64(v))
	}

	scalar := ts.FloatScalar(0.23)

	var out []ts.Tensor
	for count := 0; count < 1; count++ {
		out = append(out, ts.MustOfSlice(data).MustMul1(scalar, true))
	}
	return out
}
// dropTensors releases the C-allocated memory backing every tensor in the
// slice by calling MustDrop on each element.
func dropTensors(tensors []ts.Tensor) {
	for i := range tensors {
		tensors[i].MustDrop()
	}
}
func main() {
var si *SI
si = Get()
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
startRAM := si.TotalRam - si.FreeRam
epochs := 50
// var m runtime.MemStats
for i := 0; i < epochs; i++ {
// runtime.ReadMemStats(&m)
// t0 := float64(m.Sys) / 1024 / 1024
tensors := createTensors(10000)
// runtime.ReadMemStats(&m)
// t1 := float64(m.Sys) / 1024 / 1024
dropTensors(tensors)
// runtime.ReadMemStats(&m)
// t2 := float64(m.Sys) / 1024 / 1024
// fmt.Printf("Epoch: %v \t Start Mem [%.3f MiB] \t Alloc Mem [%.3f MiB] \t Free Mem [%.3f MiB]\n", i, t0, t1, t2)
si = Get()
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", i, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
}
}

View File

@ -0,0 +1,124 @@
// A wrapper around the linux syscall sysinfo(2).
package main
import (
"fmt"
"sync"
"syscall"
"time"
)
// SI is a Go-ized view of the linux sysinfo(2) structure
// (http://man7.org/linux/man-pages/man2/sysinfo.2.html): uptime as a
// time.Duration, load averages decoded from the kernel's fixed-point
// representation, and memory figures converted to kB by Get.
type SI struct {
	Uptime       time.Duration // time since boot
	Loads        [3]float64    // 1, 5, and 15 minute load averages, see e.g. UPTIME(1)
	Procs        uint64        // number of current processes
	TotalRam     uint64        // total usable main memory size [kB]
	FreeRam      uint64        // available memory size [kB]
	SharedRam    uint64        // amount of shared memory [kB]
	BufferRam    uint64        // memory used by buffers [kB]
	TotalSwap    uint64        // total swap space size [kB]
	FreeSwap     uint64        // swap space still available [kB]
	TotalHighRam uint64        // total high memory size [kB]
	FreeHighRam  uint64        // available high memory size [kB]
	mu           sync.Mutex    // guards writes to the fields above (taken by Get and ToString)
}

// sis is the shared package-level snapshot that Get refreshes and returns.
var sis = &SI{}
// Get the linux sysinfo data structure.
//
// Useful links in the wild web:
// http://man7.org/linux/man-pages/man2/sysinfo.2.html
// http://man7.org/linux/man-pages/man1/uptime.1.html
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502
// Get refreshes the package-level SI snapshot from the linux sysinfo(2)
// syscall and returns a pointer to it.
//
// Fields are written under sis.mu, but the returned pointer is the shared
// instance, so callers reading it without a lock may observe a torn update
// (see ToString for a synchronized read).
func Get() *SI {
	// Raw kernel structure for reference (uint64 fields are uint32 on
	// 32-bit CPUs):
	//
	//	type Sysinfo_t struct {
	//		Uptime    int64     // Seconds since boot
	//		Loads     [3]uint64 // 1, 5, and 15 minute load averages
	//		Totalram  uint64    // Total usable main memory size
	//		Freeram   uint64    // Available memory size
	//		Sharedram uint64    // Amount of shared memory
	//		Bufferram uint64    // Memory used by buffers
	//		Totalswap uint64    // Total swap space size
	//		Freeswap  uint64    // Swap space still available
	//		Procs     uint16    // Number of current processes
	//		Totalhigh uint64    // Total high memory size
	//		Freehigh  uint64    // Available high memory size
	//		Unit      uint32    // Memory unit size in bytes
	//	}

	// ~1kB garbage per call.
	si := &syscall.Sysinfo_t{}
	// XXX is a raw syscall thread safe?
	err := syscall.Sysinfo(si)
	if err != nil {
		panic("Commander, we have a problem. syscall.Sysinfo:" + err.Error())
	}

	// Load averages are fixed-point numbers with a 2^16 scale factor,
	// see sysinfo(2).
	scale := 65536.0

	sis.mu.Lock()
	defer sis.mu.Unlock()

	// NOTE(review): this divisor assumes si.Unit == 1 (memory reported in
	// bytes), which holds on mainstream kernels. For Unit > 1 the kB value
	// would be field*Unit/1024, not field/(Unit*1024) — confirm if such
	// systems are a target.
	unit := uint64(si.Unit) * 1024 // divisor converting raw fields to kB

	sis.Uptime = time.Duration(si.Uptime) * time.Second
	sis.Loads[0] = float64(si.Loads[0]) / scale
	sis.Loads[1] = float64(si.Loads[1]) / scale
	sis.Loads[2] = float64(si.Loads[2]) / scale
	sis.Procs = uint64(si.Procs)
	sis.TotalRam = uint64(si.Totalram) / unit
	sis.FreeRam = uint64(si.Freeram) / unit
	sis.SharedRam = uint64(si.Sharedram) / unit // fix: was never populated, stayed 0
	sis.BufferRam = uint64(si.Bufferram) / unit
	sis.TotalSwap = uint64(si.Totalswap) / unit
	sis.FreeSwap = uint64(si.Freeswap) / unit
	sis.TotalHighRam = uint64(si.Totalhigh) / unit
	sis.FreeHighRam = uint64(si.Freehigh) / unit

	return sis
}
// String implements fmt.Stringer, rendering the snapshot one field per line.
//
// NOTE(review): the value receiver copies the struct — including its
// sync.Mutex — so this read is unsynchronized and go vet (copylocks) flags
// it. Prefer ToString for a thread-safe snapshot; changing the receiver to
// *SI would alter the method set for value callers, so it is left as-is.
func (si SI) String() string {
	// XXX: Is the copy of SI done atomic? Not sure.
	// Without an outer lock this may print a junk.
	return fmt.Sprintf("uptime\t\t%v\nload\t\t%2.2f %2.2f %2.2f\nprocs\t\t%d\n"+
		"ram total\t%d kB\nram free\t%d kB\nram buffer\t%d kB\n"+
		"swap total\t%d kB\nswap free\t%d kB",
		//"high ram total\t%d kB\nhigh ram free\t%d kB\n"
		si.Uptime, si.Loads[0], si.Loads[1], si.Loads[2], si.Procs,
		si.TotalRam, si.FreeRam, si.BufferRam,
		si.TotalSwap, si.FreeSwap,
		// archaic si.TotalHighRam, si.FreeHighRam
	)
}
/*
Convert to string in a thread safe way.
Output:
uptime 279h6m21s
load 0.12 0.04 0.05
procs 143
ram total 383752 kB
ram free 254980 kB
ram buffer 7640 kB
swap total 887800 kB
swap free 879356 kB
*/
// ToString renders the snapshot in a thread-safe way: the mutex is held
// while String reads the fields, so a concurrent Get cannot tear the output.
func (si *SI) ToString() string {
	// Idiomatic order: acquire the lock first, then defer the release
	// (the original deferred the unlock before locking — equivalent at
	// runtime, but surprising to read).
	si.mu.Lock()
	defer si.mu.Unlock()
	return si.String()
}

View File

@ -126,29 +126,28 @@ func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) i
// int at_requires_grad(tensor);
func AtRequiresGrad(ts Ctensor) bool {
retVal := C.at_requires_grad((C.tensor)(ts))
retVal := C.at_requires_grad(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
retVal := C.at_defined((C.tensor)(ts))
retVal := C.at_defined(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
retVal := C.at_is_sparse((C.tensor)(ts))
retVal := C.at_is_sparse(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// void at_backward(tensor, int, int);
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
ctensor := (C.tensor)(ts)
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
C.at_backward(ctensor, ckeepGraph, ccreateGraph)
C.at_backward(ts, ckeepGraph, ccreateGraph)
}
/*
@ -169,39 +168,33 @@ func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninput
}
// void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
ctensor := (C.tensor)(tensor)
func AtCopyData(ts Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
cnumel := *(*C.size_t)(unsafe.Pointer(&numel))
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
C.at_copy_data(ts, vs, cnumel, celement_size_in_bytes)
}
// tensor at_shallow_clone(tensor);
func AtShallowClone(ts Ctensor) Ctensor {
ctensor := (C.tensor)(ts)
return C.at_shallow_clone(ctensor)
return C.at_shallow_clone(ts)
}
// tensor at_get(tensor, int index);
func AtGet(ts Ctensor, index int) Ctensor {
ctensor := (C.tensor)(ts)
cindex := *(*C.int)(unsafe.Pointer(&index))
return C.at_get(ctensor, cindex)
return C.at_get(ts, cindex)
}
// void at_copy_(tensor dst, tensor src);
func AtCopy_(dst Ctensor, src Ctensor) {
cdst := (C.tensor)(dst)
csrc := (C.tensor)(src)
C.at_copy_(cdst, csrc)
C.at_copy_(dst, src)
}
// void at_save(tensor, char *filename);
func AtSave(ts Ctensor, path string) {
ctensor := (C.tensor)(ts)
cstringPtr := C.CString(path)
defer C.free(unsafe.Pointer(cstringPtr))
C.at_save(ctensor, cstringPtr)
C.at_save(ts, cstringPtr)
}
// tensor at_load(char *filename);
@ -300,9 +293,8 @@ func AtLoadCallbackWithDevice(filename string, dataPtr unsafe.Pointer, device in
* }
* */
func AtToString(ts Ctensor, lineSize int64) string {
ctensor := (C.tensor)(ts)
clineSize := *(*C.int)(unsafe.Pointer(&lineSize))
charPtr := C.at_to_string(ctensor, clineSize)
charPtr := C.at_to_string(ts, clineSize)
goString := C.GoString(charPtr)
return goString
@ -310,8 +302,7 @@ func AtToString(ts Ctensor, lineSize int64) string {
// void at_free(tensor);
func AtFree(ts Ctensor) {
ctensor := (C.tensor)(ts)
C.at_free(ctensor)
C.at_free(ts)
}
//int at_grad_set_enabled(int b);

View File

@ -90,6 +90,6 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
// 1 1 1
// 1 1 1 ]
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
return xs.MustMatMul(l.Ws.MustT()).MustAdd(l.Bs)
clone := l.Ws.MustShallowClone().MustT(true)
return xs.MustMm(clone, false).MustAdd(l.Bs, true)
}

View File

@ -78,12 +78,11 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
}
// forward sequentially
var currTs ts.Tensor = xs
for i := 0; i < len(s.layers); i++ {
currTs = s.layers[i].Forward(currTs)
xs = s.layers[i].Forward(xs)
}
return currTs
return xs
}
// SequentialT is a sequential layer combining new layers with support for a training mode.
@ -136,7 +135,6 @@ func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
// Add appends a layer after all the current layers.
func (s *SequentialT) Add(l ts.ModuleT) {
s.layers = append(s.layers, l)
}

View File

@ -274,7 +274,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
for k, v := range vs.Vars.NamedVariables {
srcTs, _ := srcNamedVariables[k]
srcDevTs, err := srcTs.To(device)
srcDevTs, err := srcTs.To(device, false)
if err != nil {
return err
}

View File

@ -84,8 +84,8 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) {
func (it Iter2) Shuffle() (retVal Iter2) {
index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
it.xs = it.xs.MustIndexSelect(0, index)
it.ys = it.ys.MustIndexSelect(0, index)
it.xs = it.xs.MustIndexSelect(0, index, true)
it.ys = it.ys.MustIndexSelect(0, index, true)
return it
}

View File

@ -252,14 +252,14 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
nextTensor, err = currTensor.Unsqueeze(currIdx)
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
if err != nil {
return retVal, err
}
nextIdx = currIdx + 1
case "Select": // 1 field: `Index`
index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64)
nextTensor, err = currTensor.Select(currIdx, index) // TODO: double-check is `*index` or `index`
nextTensor, err = currTensor.Select(currIdx, index, true) // TODO: double-check is `*index` or `index`
if err != nil {
return retVal, err
}
@ -269,7 +269,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// NOTE: for now, just implement (Included(start), Excluded(end))` case
start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64)
end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64)
nextTensor, err = currTensor.Narrow(currIdx, start, end-start)
nextTensor, err = currTensor.Narrow(currIdx, start, end-start, true)
if err != nil {
return retVal, err
}
@ -280,11 +280,11 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
if err != nil {
return retVal, err
}
indexTensor, err = indexTensor.To(device)
indexTensor, err = indexTensor.To(device, true)
if err != nil {
return retVal, err
}
nextTensor, err = currTensor.IndexSelect(currIdx, indexTensor)
nextTensor, err = currTensor.IndexSelect(currIdx, indexTensor, true)
if err != nil {
return retVal, err
}

View File

@ -72,7 +72,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
break
}
acc := m.ForwardT(item.Data.MustTo(d), false).AccuracyForLogits(item.Label.MustTo(d)).MustView([]int64{-1}).MustFloat64Value([]int64{0})
acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}).MustFloat64Value([]int64{0})
size := float64(item.Data.MustSize()[0])
sumAccuracy += acc * size
sampleCount += size

View File

@ -8,13 +8,13 @@ import (
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
return ts.MustLogSoftmax(-1, gotch.Float.CInt()).MustNllLoss(targets)
return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
}
// AccuracyForLogits returns the average accuracy for some given logits assuming that
// targets represent ground-truth.
func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
return ts.MustArgmax(-1, false).MustEq1(targets).MustTotype(gotch.Float).MustMean(gotch.Float.CInt())
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
}
// TODO: continue

View File

@ -12,12 +12,15 @@ import (
lib "github.com/sugarme/gotch/libtch"
)
func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
func (ts Tensor) To(device gotch.Device, del bool) (retVal Tensor, err error) {
// TODO: how to get pointer to CUDA memory???
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
@ -28,9 +31,9 @@ func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
func (ts Tensor) MustTo(device gotch.Device, del bool) (retVal Tensor) {
var err error
retVal, err = ts.To(device)
retVal, err = ts.To(device, del)
if err != nil {
log.Fatal(err)
}
@ -38,9 +41,11 @@ func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
return retVal
}
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
func (ts Tensor) Matmul(other Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgMatmul(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
@ -50,8 +55,8 @@ func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
retVal, err := ts.Matmul(other)
func (ts Tensor) MustMatMul(other Tensor, del bool) (retVal Tensor) {
retVal, err := ts.Matmul(other, del)
if err != nil {
log.Fatal(err)
}
@ -61,7 +66,6 @@ func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
func (ts Tensor) Grad() (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgGrad(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
@ -82,7 +86,6 @@ func (ts Tensor) MustGrad() (retVal Tensor) {
func (ts Tensor) Detach_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgDetach_(ptr, ts.ctensor)
if err := TorchErr(); err != nil {
@ -92,7 +95,6 @@ func (ts Tensor) Detach_() {
func (ts Tensor) Zero_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgZero_(ptr, ts.ctensor)
if err := TorchErr(); err != nil {
@ -107,7 +109,6 @@ func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgSetRequiresGrad(ptr, ts.ctensor, r)
@ -127,9 +128,11 @@ func (ts Tensor) MustSetRequiresGrad(rb bool) (retVal Tensor) {
return retVal
}
func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
func (ts Tensor) Mul(other Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgMul(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
@ -139,8 +142,8 @@ func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
retVal, err := ts.Mul(other)
func (ts Tensor) MustMul(other Tensor, del bool) (retVal Tensor) {
retVal, err := ts.Mul(other, del)
if err != nil {
log.Fatal(err)
}
@ -148,9 +151,11 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
return retVal
}
func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
func (ts Tensor) Mul1(other Scalar, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgMul1(ptr, ts.ctensor, other.cscalar)
if err = TorchErr(); err != nil {
@ -160,8 +165,8 @@ func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
retVal, err := ts.Mul1(other)
func (ts Tensor) MustMul1(other Scalar, del bool) (retVal Tensor) {
retVal, err := ts.Mul1(other, del)
if err != nil {
log.Fatal(err)
}
@ -171,7 +176,6 @@ func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
func (ts Tensor) Mul_(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMul_(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
@ -188,9 +192,11 @@ func (ts Tensor) MustMul_(other Tensor) {
}
}
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
func (ts Tensor) Add(other Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
@ -200,8 +206,8 @@ func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
retVal, err := ts.Add(other)
func (ts Tensor) MustAdd(other Tensor, del bool) (retVal Tensor) {
retVal, err := ts.Add(other, del)
if err != nil {
log.Fatal(err)
}
@ -211,7 +217,6 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
func (ts Tensor) Add_(other Tensor) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
@ -219,9 +224,11 @@ func (ts Tensor) Add_(other Tensor) {
}
}
func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
func (ts Tensor) Add1(other Scalar, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgAdd1(ptr, ts.ctensor, other.cscalar)
if err = TorchErr(); err != nil {
@ -231,8 +238,8 @@ func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
retVal, err := ts.Add1(other)
func (ts Tensor) MustAdd1(other Scalar, del bool) (retVal Tensor) {
retVal, err := ts.Add1(other, del)
if err != nil {
log.Fatal(err)
@ -243,7 +250,6 @@ func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
func (ts Tensor) AddG(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
@ -263,9 +269,11 @@ func (ts Tensor) MustAddG(other Tensor) {
}
// Totype casts type of tensor to a new tensor with specified DType
func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
func (ts Tensor) Totype(dtype gotch.DType, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return retVal, err
@ -283,8 +291,8 @@ func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
// Totype casts type of tensor to a new tensor with specified DType. It will
// panic if error
func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
retVal, err := ts.Totype(dtype)
func (ts Tensor) MustTotype(dtype gotch.DType, del bool) (retVal Tensor) {
retVal, err := ts.Totype(dtype, del)
if err != nil {
log.Fatal(err)
}
@ -293,10 +301,11 @@ func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
}
// Unsqueeze unsqueezes tensor to specified dimension.
func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) {
func (ts Tensor) Unsqueeze(dim int64, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgUnsqueeze(ptr, ts.ctensor, dim)
if err = TorchErr(); err != nil {
return retVal, err
@ -308,10 +317,11 @@ func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) {
}
// Select creates a new tensor from current tensor given dim and index.
func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) {
func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgSelect(ptr, ts.ctensor, dim, index)
if err = TorchErr(); err != nil {
return retVal, err
@ -324,9 +334,11 @@ func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) {
// Narrow creates a new tensor from current tensor given dim and start index
// and length.
func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, err error) {
func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgNarrow(ptr, ts.ctensor, dim, start, length)
if err = TorchErr(); err != nil {
@ -340,9 +352,11 @@ func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, er
// IndexSelect creates a new tensor from current tensor given dim and index
// tensor.
func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) {
func (ts Tensor) IndexSelect(dim int64, index Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgIndexSelect(ptr, ts.ctensor, dim, index.ctensor)
if err = TorchErr(); err != nil {
@ -353,8 +367,8 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
return retVal, nil
}
func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
retVal, err := ts.IndexSelect(dim, index)
func (ts Tensor) MustIndexSelect(dim int64, index Tensor, del bool) (retVal Tensor) {
retVal, err := ts.IndexSelect(dim, index, del)
if err != nil {
log.Fatal(err)
}
@ -364,7 +378,6 @@ func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgZeros(ptr, size, len(size), optionsKind, optionsDevice)
if err = TorchErr(); err != nil {
@ -386,7 +399,6 @@ func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgOnes(ptr, size, len(size), optionsKind, optionsDevice)
if err = TorchErr(); err != nil {
@ -410,7 +422,6 @@ func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
func (ts Tensor) Uniform_(from float64, to float64) {
var err error
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgUniform_(ptr, ts.ctensor, from, to)
if err = TorchErr(); err != nil {
@ -418,10 +429,11 @@ func (ts Tensor) Uniform_(from float64, to float64) {
}
}
func (ts Tensor) ZerosLike() (retVal Tensor, err error) {
func (ts Tensor) ZerosLike(del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgZerosLike(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
@ -435,7 +447,6 @@ func (ts Tensor) ZerosLike() (retVal Tensor, err error) {
func (ts Tensor) Fill_(value Scalar) {
var err error
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgFill_(ptr, ts.ctensor, value.cscalar)
if err = TorchErr(); err != nil {
@ -443,10 +454,11 @@ func (ts Tensor) Fill_(value Scalar) {
}
}
func (ts Tensor) RandnLike() (retVal Tensor, err error) {
func (ts Tensor) RandnLike(del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgRandnLike(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
@ -457,10 +469,11 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
func (ts Tensor) Permute(dims []int64, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgPermute(ptr, ts.ctensor, dims, len(dims))
if err = TorchErr(); err != nil {
@ -472,9 +485,11 @@ func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
func (ts Tensor) Squeeze1(dim int64, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgSqueeze1(ptr, ts.ctensor, dim)
@ -487,9 +502,9 @@ func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
func (ts Tensor) MustSqueeze1(dim int64, del bool) (retVal Tensor) {
var err error
retVal, err = ts.Squeeze1(dim)
retVal, err = ts.Squeeze1(dim, del)
if err != nil {
log.Fatal(err)
}
@ -527,9 +542,11 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
func (ts Tensor) Mm(mat2 Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgMm(ptr, ts.ctensor, mat2.ctensor)
if err = TorchErr(); err != nil {
@ -541,8 +558,8 @@ func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
retVal, err := ts.Mm(mat2)
func (ts Tensor) MustMm(mat2 Tensor, del bool) (retVal Tensor) {
retVal, err := ts.Mm(mat2, del)
if err != nil {
log.Fatal(err)
}
@ -550,9 +567,11 @@ func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
return retVal
}
func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
func (ts Tensor) LogSoftmax(dim int64, dtype int32, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype)
if err = TorchErr(); err != nil {
@ -564,8 +583,8 @@ func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
retVal, err := ts.LogSoftmax(dim, dtype)
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32, del bool) (retVal Tensor) {
retVal, err := ts.LogSoftmax(dim, dtype, del)
if err != nil {
log.Fatal(err)
}
@ -573,10 +592,11 @@ func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
return retVal
}
func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
func (ts Tensor) NllLoss(target Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
// NOTE: uncomment this causes panic
// defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
weight := NewTensor()
@ -594,8 +614,8 @@ func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
retVal, err := ts.NllLoss(target)
func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
retVal, err := ts.NllLoss(target, del)
if err != nil {
log.Fatal(err)
}
@ -603,9 +623,11 @@ func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
return retVal
}
func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
func (ts Tensor) Argmax(dim int64, keepDim bool, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
var ckeepDim int = 0
if keepDim {
@ -622,8 +644,8 @@ func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
retVal, err := ts.Argmax(dim, keepDim)
func (ts Tensor) MustArgmax(dim int64, keepDim bool, del bool) (retVal Tensor) {
retVal, err := ts.Argmax(dim, keepDim, del)
if err != nil {
log.Fatal(err)
}
@ -631,9 +653,11 @@ func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
return retVal
}
func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
func (ts Tensor) Mean(dtype int32, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgMean(ptr, ts.ctensor, dtype)
if err = TorchErr(); err != nil {
@ -645,8 +669,8 @@ func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustMean(dtype int32) (retVal Tensor) {
retVal, err := ts.Mean(dtype)
func (ts Tensor) MustMean(dtype int32, del bool) (retVal Tensor) {
retVal, err := ts.Mean(dtype, del)
if err != nil {
log.Fatal(err)
}
@ -677,9 +701,11 @@ func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
return retVal
}
func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
func (ts Tensor) Div1(other Scalar, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgDiv1(ptr, ts.ctensor, other.cscalar)
if err = TorchErr(); err != nil {
@ -691,8 +717,8 @@ func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
retVal, err := ts.Div1(other)
func (ts Tensor) MustDiv1(other Scalar, del bool) (retVal Tensor) {
retVal, err := ts.Div1(other, del)
if err != nil {
log.Fatal(err)
}
@ -702,7 +728,6 @@ func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgRandperm(ptr, n, optionKind.CInt(), optionDevice.CInt())
if err = TorchErr(); err != nil {
@ -725,7 +750,6 @@ func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (r
func (ts Tensor) Clamp_(min Scalar, max Scalar) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
if err := TorchErr(); err != nil {
@ -735,7 +759,6 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) {
func (ts Tensor) Relu_() {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgRelu_(ptr, ts.ctensor)
if err := TorchErr(); err != nil {
@ -743,9 +766,11 @@ func (ts Tensor) Relu_() {
}
}
func (ts Tensor) Relu() (retVal Tensor, err error) {
func (ts Tensor) Relu(del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgRelu(ptr, ts.ctensor)
err = TorchErr()
@ -758,8 +783,8 @@ func (ts Tensor) Relu() (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustRelu() (retVal Tensor) {
retVal, err := ts.Relu()
func (ts Tensor) MustRelu(del bool) (retVal Tensor) {
retVal, err := ts.Relu(del)
if err != nil {
log.Fatal(err)
}
@ -767,9 +792,11 @@ func (ts Tensor) MustRelu() (retVal Tensor) {
return retVal
}
func (ts Tensor) T() (retVal Tensor, err error) {
func (ts Tensor) T(del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgT(ptr, ts.ctensor)
err = TorchErr()
@ -782,8 +809,8 @@ func (ts Tensor) T() (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustT() (retVal Tensor) {
retVal, err := ts.T()
func (ts Tensor) MustT(del bool) (retVal Tensor) {
retVal, err := ts.T(del)
if err != nil {
log.Fatal(err)
}
@ -802,10 +829,12 @@ func (ts Tensor) T_() {
}
}
func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error) {
func (ts Tensor) MseLoss(target Tensor, reduction int, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgMseLoss(ptr, ts.ctensor, target.ctensor, reduction)
err = TorchErr()
@ -818,8 +847,8 @@ func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error
return retVal, nil
}
func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
retVal, err := ts.MseLoss(target, reduction)
func (ts Tensor) MustMseLoss(target Tensor, reduction int, del bool) (retVal Tensor) {
retVal, err := ts.MseLoss(target, reduction, del)
if err != nil {
log.Fatal(err)
@ -828,9 +857,11 @@ func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
return retVal
}
func (ts Tensor) Exp() (retVal Tensor, err error) {
func (ts Tensor) Exp(del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgExp(ptr, ts.ctensor)
err = TorchErr()
@ -843,8 +874,8 @@ func (ts Tensor) Exp() (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustExp() (retVal Tensor) {
retVal, err := ts.Exp()
func (ts Tensor) MustExp(del bool) (retVal Tensor) {
retVal, err := ts.Exp(del)
if err != nil {
log.Fatal(err)
@ -864,9 +895,11 @@ func (ts Tensor) Exp_() {
}
}
func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
func (ts Tensor) Pow(exponent Scalar, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgPow(ptr, ts.ctensor, exponent.cscalar)
err = TorchErr()
@ -879,8 +912,8 @@ func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
retVal, err := ts.Pow(exponent)
func (ts Tensor) MustPow(exponent Scalar, del bool) (retVal Tensor) {
retVal, err := ts.Pow(exponent, del)
if err != nil {
log.Fatal(err)
@ -889,9 +922,11 @@ func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
return retVal
}
func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
func (ts Tensor) Sum(dtype int32, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgSum(ptr, ts.ctensor, dtype)
err = TorchErr()
@ -904,8 +939,8 @@ func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
retVal, err := ts.Sum(dtype)
func (ts Tensor) MustSum(dtype int32, del bool) (retVal Tensor) {
retVal, err := ts.Sum(dtype, del)
if err != nil {
log.Fatal(err)
@ -914,9 +949,11 @@ func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
return retVal
}
func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
func (ts Tensor) Sub(other Tensor, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgSub(ptr, ts.ctensor, other.ctensor)
err = TorchErr()
@ -929,8 +966,8 @@ func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
retVal, err := ts.Sub(other)
func (ts Tensor) MustSub(other Tensor, del bool) (retVal Tensor) {
retVal, err := ts.Sub(other, del)
if err != nil {
log.Fatal(err)
@ -939,9 +976,11 @@ func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
return retVal
}
func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
func (ts Tensor) Sub1(other Scalar, del bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
if del {
defer ts.MustDrop()
}
lib.AtgSub1(ptr, ts.ctensor, other.cscalar)
err = TorchErr()
@ -954,8 +993,8 @@ func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) MustSub1(other Scalar) (retVal Tensor) {
retVal, err := ts.Sub1(other)
func (ts Tensor) MustSub1(other Scalar, del bool) (retVal Tensor) {
retVal, err := ts.Sub1(other, del)
if err != nil {
log.Fatal(err)

View File

@ -2,6 +2,7 @@ package tensor
//#include "stdlib.h"
//#include "stdbool.h"
//#include<stdio.h>
import "C"
import (

View File

@ -16,7 +16,7 @@ import (
// (height, width, channel) -> (channel, height, width)
func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
var err error
retVal, err = tensor.Permute([]int64{2, 0, 1})
retVal, err = tensor.Permute([]int64{2, 0, 1}, true)
if err != nil {
log.Fatalf("hwcToCHW error: %v\n", err)
}
@ -25,7 +25,7 @@ func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) {
var err error
retVal, err = tensor.Permute([]int64{1, 2, 0})
retVal, err = tensor.Permute([]int64{1, 2, 0}, true)
if err != nil {
log.Fatalf("hwcToCHW error: %v\n", err)
}
@ -54,7 +54,7 @@ func Load(path string) (retVal ts.Tensor, err error) {
// The tensor input should be of kind UInt8 with values ranging from
// 0 to 255.
func Save(tensor ts.Tensor, path string) (err error) {
t, err := tensor.Totype(gotch.Uint8)
t, err := tensor.Totype(gotch.Uint8, true)
if err != nil {
err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err)
return err
@ -68,9 +68,9 @@ func Save(tensor ts.Tensor, path string) (err error) {
switch {
case len(shape) == 4 && shape[0] == 1:
return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0)).MustTo(gotch.CPU)), path)
return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0), true).MustTo(gotch.CPU, true)), path)
case len(shape) == 3:
return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path)
return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU, true)), path)
default:
err = fmt.Errorf("Unexpected size (%v) for image tensor.\n", len(shape))
return err
@ -125,7 +125,7 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
var tensorW ts.Tensor
var tensorH ts.Tensor
if resizeW != outW {
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW)
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW, true)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err
@ -135,7 +135,7 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
if resizeH == outH {
retVal = tensorW
} else {
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH)
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH, true)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err

View File

@ -82,7 +82,7 @@ func readLabels(filename string) (retVal ts.Tensor) {
log.Fatal(err)
}
retVal = labelsTs.MustTotype(gotch.Int64)
retVal = labelsTs.MustTotype(gotch.Int64, true)
return retVal
}
@ -125,7 +125,7 @@ func readImages(filename string) (retVal ts.Tensor) {
err = fmt.Errorf("create images tensor err.")
log.Fatal(err)
}
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0))
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
return retVal
}