feat(tensor): add option to delete tensor after applying operator to free up C memory
This commit is contained in:
parent
4ffe5feb7a
commit
f36d2482a1
|
@ -13,7 +13,7 @@ const (
|
|||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
epochs = 200
|
||||
epochs = 500
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
|
@ -28,22 +28,23 @@ func runLinear() {
|
|||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
|
||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, true)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
loss.MustBackward()
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
ws.Add_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0), true))
|
||||
bs.Add_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0), true))
|
||||
})
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
||||
|
||||
loss.MustDrop()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,8 +16,7 @@ const (
|
|||
LabelNN int64 = 10
|
||||
MnistDirNN string = "../../data/mnist"
|
||||
|
||||
epochsNN = 50
|
||||
batchSizeNN = 256
|
||||
epochsNN = 200
|
||||
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
@ -27,12 +26,10 @@ var l nn.Linear
|
|||
func netInit(vs nn.Path) ts.Module {
|
||||
n := nn.Seq()
|
||||
|
||||
l = nn.NewLinear(vs.Sub("layer1"), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
|
||||
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
||||
|
||||
n.Add(l)
|
||||
|
||||
n.AddFn(nn.ForwardWith(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu()
|
||||
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu(true)
|
||||
}))
|
||||
|
||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
||||
|
@ -45,9 +42,11 @@ func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer
|
|||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
testAccuracy := m.Forward(testX).AccuracyForLogits(testY).Values()[0]
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
||||
testAccuracy := m.Forward(testX).AccuracyForLogits(testY)
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy.Values()[0]*100)
|
||||
|
||||
loss.MustDrop()
|
||||
testAccuracy.MustDrop()
|
||||
}
|
||||
|
||||
func runNN() {
|
||||
|
@ -62,9 +61,7 @@ func runNN() {
|
|||
}
|
||||
|
||||
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||
|
||||
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
74
example/tensor-memory/main.go
Normal file
74
example/tensor-memory/main.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
// "runtime"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func createTensors(samples int) []ts.Tensor {
|
||||
n := int(10e6)
|
||||
var data []float64
|
||||
for i := 0; i < n; i++ {
|
||||
data = append(data, float64(i))
|
||||
}
|
||||
|
||||
var tensors []ts.Tensor
|
||||
s := ts.FloatScalar(float64(0.23))
|
||||
|
||||
// for i := 0; i < samples; i++ {
|
||||
for i := 0; i < 1; i++ {
|
||||
t := ts.MustOfSlice(data).MustMul1(s, true)
|
||||
|
||||
// t1.MustDrop()
|
||||
// t.MustDrop()
|
||||
// t1 = ts.Tensor{}
|
||||
// t = ts.Tensor{}
|
||||
// runtime.GC()
|
||||
|
||||
// fmt.Printf("t values: %v", t.Values())
|
||||
// fmt.Printf("t1 values: %v", t1.Values())
|
||||
tensors = append(tensors, t)
|
||||
}
|
||||
|
||||
return tensors
|
||||
}
|
||||
|
||||
func dropTensors(tensors []ts.Tensor) {
|
||||
for _, t := range tensors {
|
||||
t.MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
var si *SI
|
||||
si = Get()
|
||||
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
|
||||
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
|
||||
|
||||
startRAM := si.TotalRam - si.FreeRam
|
||||
|
||||
epochs := 50
|
||||
// var m runtime.MemStats
|
||||
|
||||
for i := 0; i < epochs; i++ {
|
||||
// runtime.ReadMemStats(&m)
|
||||
// t0 := float64(m.Sys) / 1024 / 1024
|
||||
|
||||
tensors := createTensors(10000)
|
||||
|
||||
// runtime.ReadMemStats(&m)
|
||||
// t1 := float64(m.Sys) / 1024 / 1024
|
||||
|
||||
dropTensors(tensors)
|
||||
|
||||
// runtime.ReadMemStats(&m)
|
||||
// t2 := float64(m.Sys) / 1024 / 1024
|
||||
|
||||
// fmt.Printf("Epoch: %v \t Start Mem [%.3f MiB] \t Alloc Mem [%.3f MiB] \t Free Mem [%.3f MiB]\n", i, t0, t1, t2)
|
||||
si = Get()
|
||||
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", i, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
|
||||
}
|
||||
}
|
124
example/tensor-memory/sysinfo.go
Normal file
124
example/tensor-memory/sysinfo.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
// A wrapper around the linux syscall sysinfo(2).
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Go-ized http://man7.org/linux/man-pages/man2/sysinfo.2.html
|
||||
type SI struct {
|
||||
Uptime time.Duration // time since boot
|
||||
Loads [3]float64 // 1, 5, and 15 minute load averages, see e.g. UPTIME(1)
|
||||
Procs uint64 // number of current processes
|
||||
TotalRam uint64 // total usable main memory size [kB]
|
||||
FreeRam uint64 // available memory size [kB]
|
||||
SharedRam uint64 // amount of shared memory [kB]
|
||||
BufferRam uint64 // memory used by buffers [kB]
|
||||
TotalSwap uint64 // total swap space size [kB]
|
||||
FreeSwap uint64 // swap space still available [kB]
|
||||
TotalHighRam uint64 // total high memory size [kB]
|
||||
FreeHighRam uint64 // available high memory size [kB]
|
||||
mu sync.Mutex // ensures atomic writes; protects the following fields
|
||||
}
|
||||
|
||||
var sis = &SI{}
|
||||
|
||||
// Get the linux sysinfo data structure.
|
||||
//
|
||||
// Useful links in the wild web:
|
||||
// http://man7.org/linux/man-pages/man2/sysinfo.2.html
|
||||
// http://man7.org/linux/man-pages/man1/uptime.1.html
|
||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/zsyscall_linux_amd64.go#L1050
|
||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_amd64.go#L528
|
||||
// https://github.com/capnm/golang/blob/go1.1.1/src/pkg/syscall/ztypes_linux_arm.go#L502
|
||||
func Get() *SI {
|
||||
|
||||
/*
|
||||
// Note: uint64 is uint32 on 32 bit CPUs
|
||||
type Sysinfo_t struct {
|
||||
Uptime int64 // Seconds since boot
|
||||
Loads [3]uint64 // 1, 5, and 15 minute load averages
|
||||
Totalram uint64 // Total usable main memory size
|
||||
Freeram uint64 // Available memory size
|
||||
Sharedram uint64 // Amount of shared memory
|
||||
Bufferram uint64 // Memory used by buffers
|
||||
Totalswap uint64 // Total swap space size
|
||||
Freeswap uint64 // swap space still available
|
||||
Procs uint16 // Number of current processes
|
||||
Pad uint16
|
||||
Pad_cgo_0 [4]byte
|
||||
Totalhigh uint64 // Total high memory size
|
||||
Freehigh uint64 // Available high memory size
|
||||
Unit uint32 // Memory unit size in bytes
|
||||
X_f [0]byte
|
||||
Pad_cgo_1 [4]byte // Padding to 64 bytes
|
||||
}
|
||||
*/
|
||||
|
||||
// ~1kB garbage
|
||||
si := &syscall.Sysinfo_t{}
|
||||
|
||||
// XXX is a raw syscall thread safe?
|
||||
err := syscall.Sysinfo(si)
|
||||
if err != nil {
|
||||
panic("Commander, we have a problem. syscall.Sysinfo:" + err.Error())
|
||||
}
|
||||
scale := 65536.0 // magic
|
||||
|
||||
defer sis.mu.Unlock()
|
||||
sis.mu.Lock()
|
||||
|
||||
unit := uint64(si.Unit) * 1024 // kB
|
||||
|
||||
sis.Uptime = time.Duration(si.Uptime) * time.Second
|
||||
sis.Loads[0] = float64(si.Loads[0]) / scale
|
||||
sis.Loads[1] = float64(si.Loads[1]) / scale
|
||||
sis.Loads[2] = float64(si.Loads[2]) / scale
|
||||
sis.Procs = uint64(si.Procs)
|
||||
|
||||
sis.TotalRam = uint64(si.Totalram) / unit
|
||||
sis.FreeRam = uint64(si.Freeram) / unit
|
||||
sis.BufferRam = uint64(si.Bufferram) / unit
|
||||
sis.TotalSwap = uint64(si.Totalswap) / unit
|
||||
sis.FreeSwap = uint64(si.Freeswap) / unit
|
||||
sis.TotalHighRam = uint64(si.Totalhigh) / unit
|
||||
sis.FreeHighRam = uint64(si.Freehigh) / unit
|
||||
|
||||
return sis
|
||||
}
|
||||
|
||||
// Make the "fmt" Stringer interface happy.
|
||||
func (si SI) String() string {
|
||||
// XXX: Is the copy of SI done atomic? Not sure.
|
||||
// Without an outer lock this may print a junk.
|
||||
return fmt.Sprintf("uptime\t\t%v\nload\t\t%2.2f %2.2f %2.2f\nprocs\t\t%d\n"+
|
||||
"ram total\t%d kB\nram free\t%d kB\nram buffer\t%d kB\n"+
|
||||
"swap total\t%d kB\nswap free\t%d kB",
|
||||
//"high ram total\t%d kB\nhigh ram free\t%d kB\n"
|
||||
si.Uptime, si.Loads[0], si.Loads[1], si.Loads[2], si.Procs,
|
||||
si.TotalRam, si.FreeRam, si.BufferRam,
|
||||
si.TotalSwap, si.FreeSwap,
|
||||
// archaic si.TotalHighRam, si.FreeHighRam
|
||||
)
|
||||
}
|
||||
|
||||
/*
|
||||
Convert to string in a thread safe way.
|
||||
Output:
|
||||
uptime 279h6m21s
|
||||
load 0.12 0.04 0.05
|
||||
procs 143
|
||||
ram total 383752 kB
|
||||
ram free 254980 kB
|
||||
ram buffer 7640 kB
|
||||
swap total 887800 kB
|
||||
swap free 879356 kB
|
||||
*/
|
||||
func (si *SI) ToString() string {
|
||||
defer si.mu.Unlock()
|
||||
si.mu.Lock()
|
||||
return si.String()
|
||||
}
|
|
@ -126,29 +126,28 @@ func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) i
|
|||
|
||||
// int at_requires_grad(tensor);
|
||||
func AtRequiresGrad(ts Ctensor) bool {
|
||||
retVal := C.at_requires_grad((C.tensor)(ts))
|
||||
retVal := C.at_requires_grad(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_defined(tensor);
|
||||
func AtDefined(ts Ctensor) bool {
|
||||
retVal := C.at_defined((C.tensor)(ts))
|
||||
retVal := C.at_defined(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_sparse(tensor);
|
||||
func AtIsSparse(ts Ctensor) bool {
|
||||
retVal := C.at_is_sparse((C.tensor)(ts))
|
||||
retVal := C.at_is_sparse(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// void at_backward(tensor, int, int);
|
||||
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
|
||||
ctensor := (C.tensor)(ts)
|
||||
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
|
||||
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
|
||||
|
||||
C.at_backward(ctensor, ckeepGraph, ccreateGraph)
|
||||
C.at_backward(ts, ckeepGraph, ccreateGraph)
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -169,39 +168,33 @@ func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninput
|
|||
}
|
||||
|
||||
// void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
|
||||
func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
|
||||
ctensor := (C.tensor)(tensor)
|
||||
func AtCopyData(ts Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
|
||||
cnumel := *(*C.size_t)(unsafe.Pointer(&numel))
|
||||
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
|
||||
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
|
||||
C.at_copy_data(ts, vs, cnumel, celement_size_in_bytes)
|
||||
}
|
||||
|
||||
// tensor at_shallow_clone(tensor);
|
||||
func AtShallowClone(ts Ctensor) Ctensor {
|
||||
ctensor := (C.tensor)(ts)
|
||||
return C.at_shallow_clone(ctensor)
|
||||
return C.at_shallow_clone(ts)
|
||||
}
|
||||
|
||||
// tensor at_get(tensor, int index);
|
||||
func AtGet(ts Ctensor, index int) Ctensor {
|
||||
ctensor := (C.tensor)(ts)
|
||||
cindex := *(*C.int)(unsafe.Pointer(&index))
|
||||
return C.at_get(ctensor, cindex)
|
||||
return C.at_get(ts, cindex)
|
||||
}
|
||||
|
||||
// void at_copy_(tensor dst, tensor src);
|
||||
func AtCopy_(dst Ctensor, src Ctensor) {
|
||||
cdst := (C.tensor)(dst)
|
||||
csrc := (C.tensor)(src)
|
||||
C.at_copy_(cdst, csrc)
|
||||
C.at_copy_(dst, src)
|
||||
}
|
||||
|
||||
// void at_save(tensor, char *filename);
|
||||
func AtSave(ts Ctensor, path string) {
|
||||
ctensor := (C.tensor)(ts)
|
||||
cstringPtr := C.CString(path)
|
||||
defer C.free(unsafe.Pointer(cstringPtr))
|
||||
C.at_save(ctensor, cstringPtr)
|
||||
C.at_save(ts, cstringPtr)
|
||||
}
|
||||
|
||||
// tensor at_load(char *filename);
|
||||
|
@ -300,9 +293,8 @@ func AtLoadCallbackWithDevice(filename string, dataPtr unsafe.Pointer, device in
|
|||
* }
|
||||
* */
|
||||
func AtToString(ts Ctensor, lineSize int64) string {
|
||||
ctensor := (C.tensor)(ts)
|
||||
clineSize := *(*C.int)(unsafe.Pointer(&lineSize))
|
||||
charPtr := C.at_to_string(ctensor, clineSize)
|
||||
charPtr := C.at_to_string(ts, clineSize)
|
||||
goString := C.GoString(charPtr)
|
||||
|
||||
return goString
|
||||
|
@ -310,8 +302,7 @@ func AtToString(ts Ctensor, lineSize int64) string {
|
|||
|
||||
// void at_free(tensor);
|
||||
func AtFree(ts Ctensor) {
|
||||
ctensor := (C.tensor)(ts)
|
||||
C.at_free(ctensor)
|
||||
C.at_free(ts)
|
||||
}
|
||||
|
||||
//int at_grad_set_enabled(int b);
|
||||
|
|
|
@ -90,6 +90,6 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
|
|||
// 1 1 1
|
||||
// 1 1 1 ]
|
||||
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
|
||||
return xs.MustMatMul(l.Ws.MustT()).MustAdd(l.Bs)
|
||||
clone := l.Ws.MustShallowClone().MustT(true)
|
||||
return xs.MustMm(clone, false).MustAdd(l.Bs, true)
|
||||
}
|
||||
|
|
|
@ -78,12 +78,11 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
|||
}
|
||||
|
||||
// forward sequentially
|
||||
var currTs ts.Tensor = xs
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
currTs = s.layers[i].Forward(currTs)
|
||||
xs = s.layers[i].Forward(xs)
|
||||
}
|
||||
|
||||
return currTs
|
||||
return xs
|
||||
}
|
||||
|
||||
// SequentialT is a sequential layer combining new layers with support for a training mode.
|
||||
|
@ -136,7 +135,6 @@ func (s SequentialT) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
|||
|
||||
// Add appends a layer after all the current layers.
|
||||
func (s *SequentialT) Add(l ts.ModuleT) {
|
||||
|
||||
s.layers = append(s.layers, l)
|
||||
}
|
||||
|
||||
|
|
|
@ -274,7 +274,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
|||
|
||||
for k, v := range vs.Vars.NamedVariables {
|
||||
srcTs, _ := srcNamedVariables[k]
|
||||
srcDevTs, err := srcTs.To(device)
|
||||
srcDevTs, err := srcTs.To(device, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -84,8 +84,8 @@ func MustNewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2) {
|
|||
func (it Iter2) Shuffle() (retVal Iter2) {
|
||||
index := MustRandperm(it.totalSize, gotch.Int64, gotch.CPU)
|
||||
|
||||
it.xs = it.xs.MustIndexSelect(0, index)
|
||||
it.ys = it.ys.MustIndexSelect(0, index)
|
||||
it.xs = it.xs.MustIndexSelect(0, index, true)
|
||||
it.ys = it.ys.MustIndexSelect(0, index, true)
|
||||
return it
|
||||
}
|
||||
|
||||
|
|
|
@ -252,14 +252,14 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx)
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx + 1
|
||||
case "Select": // 1 field: `Index`
|
||||
index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64)
|
||||
nextTensor, err = currTensor.Select(currIdx, index) // TODO: double-check is `*index` or `index`
|
||||
nextTensor, err = currTensor.Select(currIdx, index, true) // TODO: double-check is `*index` or `index`
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -269,7 +269,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
// NOTE: for now, just implement (Included(start), Excluded(end))` case
|
||||
start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64)
|
||||
end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64)
|
||||
nextTensor, err = currTensor.Narrow(currIdx, start, end-start)
|
||||
nextTensor, err = currTensor.Narrow(currIdx, start, end-start, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -280,11 +280,11 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
indexTensor, err = indexTensor.To(device)
|
||||
indexTensor, err = indexTensor.To(device, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextTensor, err = currTensor.IndexSelect(currIdx, indexTensor)
|
||||
nextTensor, err = currTensor.IndexSelect(currIdx, indexTensor, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
|
|
@ -72,7 +72,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
|
|||
break
|
||||
}
|
||||
|
||||
acc := m.ForwardT(item.Data.MustTo(d), false).AccuracyForLogits(item.Label.MustTo(d)).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
acc := m.ForwardT(item.Data.MustTo(d, true), false).AccuracyForLogits(item.Label.MustTo(d, true)).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
size := float64(item.Data.MustSize()[0])
|
||||
sumAccuracy += acc * size
|
||||
sampleCount += size
|
||||
|
|
|
@ -8,13 +8,13 @@ import (
|
|||
|
||||
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
|
||||
func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
|
||||
return ts.MustLogSoftmax(-1, gotch.Float.CInt()).MustNllLoss(targets)
|
||||
return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
|
||||
}
|
||||
|
||||
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
||||
// targets represent ground-truth.
|
||||
func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
|
||||
return ts.MustArgmax(-1, false).MustEq1(targets).MustTotype(gotch.Float).MustMean(gotch.Float.CInt())
|
||||
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
||||
}
|
||||
|
||||
// TODO: continue
|
||||
|
|
|
@ -12,12 +12,15 @@ import (
|
|||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
|
||||
func (ts Tensor) To(device gotch.Device, del bool) (retVal Tensor, err error) {
|
||||
|
||||
// TODO: how to get pointer to CUDA memory???
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
|
||||
|
||||
|
@ -28,9 +31,9 @@ func (ts Tensor) To(device gotch.Device) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
|
||||
func (ts Tensor) MustTo(device gotch.Device, del bool) (retVal Tensor) {
|
||||
var err error
|
||||
retVal, err = ts.To(device)
|
||||
retVal, err = ts.To(device, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -38,9 +41,11 @@ func (ts Tensor) MustTo(device gotch.Device) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Matmul(other Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgMatmul(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -50,8 +55,8 @@ func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Matmul(other)
|
||||
func (ts Tensor) MustMatMul(other Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Matmul(other, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -61,7 +66,6 @@ func (ts Tensor) MustMatMul(other Tensor) (retVal Tensor) {
|
|||
|
||||
func (ts Tensor) Grad() (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgGrad(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -82,7 +86,6 @@ func (ts Tensor) MustGrad() (retVal Tensor) {
|
|||
|
||||
func (ts Tensor) Detach_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgDetach_(ptr, ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
|
@ -92,7 +95,6 @@ func (ts Tensor) Detach_() {
|
|||
|
||||
func (ts Tensor) Zero_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgZero_(ptr, ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
|
@ -107,7 +109,6 @@ func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
|
|||
}
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSetRequiresGrad(ptr, ts.ctensor, r)
|
||||
|
||||
|
@ -127,9 +128,11 @@ func (ts Tensor) MustSetRequiresGrad(rb bool) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Mul(other Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgMul(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -139,8 +142,8 @@ func (ts Tensor) Mul(other Tensor) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Mul(other)
|
||||
func (ts Tensor) MustMul(other Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Mul(other, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -148,9 +151,11 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Mul1(other Scalar, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgMul1(ptr, ts.ctensor, other.cscalar)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -160,8 +165,8 @@ func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Mul1(other)
|
||||
func (ts Tensor) MustMul1(other Scalar, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Mul1(other, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -171,7 +176,6 @@ func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
|
|||
|
||||
func (ts Tensor) Mul_(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgMul_(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -188,9 +192,11 @@ func (ts Tensor) MustMul_(other Tensor) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Add(other Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -200,8 +206,8 @@ func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Add(other)
|
||||
func (ts Tensor) MustAdd(other Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Add(other, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -211,7 +217,6 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
|||
|
||||
func (ts Tensor) Add_(other Tensor) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
|
@ -219,9 +224,11 @@ func (ts Tensor) Add_(other Tensor) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Add1(other Scalar, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgAdd1(ptr, ts.ctensor, other.cscalar)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -231,8 +238,8 @@ func (ts Tensor) Add1(other Scalar) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Add1(other)
|
||||
func (ts Tensor) MustAdd1(other Scalar, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Add1(other, del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -243,7 +250,6 @@ func (ts Tensor) MustAdd1(other Scalar) (retVal Tensor) {
|
|||
|
||||
func (ts Tensor) AddG(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -263,9 +269,11 @@ func (ts Tensor) MustAddG(other Tensor) {
|
|||
}
|
||||
|
||||
// Totype casts type of tensor to a new tensor with specified DType
|
||||
func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Totype(dtype gotch.DType, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
|
@ -283,8 +291,8 @@ func (ts Tensor) Totype(dtype gotch.DType) (retVal Tensor, err error) {
|
|||
|
||||
// Totype casts type of tensor to a new tensor with specified DType. It will
|
||||
// panic if error
|
||||
func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
|
||||
retVal, err := ts.Totype(dtype)
|
||||
func (ts Tensor) MustTotype(dtype gotch.DType, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Totype(dtype, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -293,10 +301,11 @@ func (ts Tensor) MustTotype(dtype gotch.DType) (retVal Tensor) {
|
|||
}
|
||||
|
||||
// Unsqueeze unsqueezes tensor to specified dimension.
|
||||
func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Unsqueeze(dim int64, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgUnsqueeze(ptr, ts.ctensor, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
|
@ -308,10 +317,11 @@ func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) {
|
|||
}
|
||||
|
||||
// Select creates a new tensor from current tensor given dim and index.
|
||||
func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Select(dim int64, index int64, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgSelect(ptr, ts.ctensor, dim, index)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
|
@ -324,9 +334,11 @@ func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) {
|
|||
|
||||
// Narrow creates a new tensor from current tensor given dim and start index
|
||||
// and length.
|
||||
func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Narrow(dim int64, start int64, length int64, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgNarrow(ptr, ts.ctensor, dim, start, length)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -340,9 +352,11 @@ func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, er
|
|||
|
||||
// IndexSelect creates a new tensor from current tensor given dim and index
|
||||
// tensor.
|
||||
func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) IndexSelect(dim int64, index Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgIndexSelect(ptr, ts.ctensor, dim, index.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -353,8 +367,8 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.IndexSelect(dim, index)
|
||||
func (ts Tensor) MustIndexSelect(dim int64, index Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.IndexSelect(dim, index, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -364,7 +378,6 @@ func (ts Tensor) MustIndexSelect(dim int64, index Tensor) (retVal Tensor) {
|
|||
|
||||
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgZeros(ptr, size, len(size), optionsKind, optionsDevice)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -386,7 +399,6 @@ func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
|
|||
|
||||
func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgOnes(ptr, size, len(size), optionsKind, optionsDevice)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -410,7 +422,6 @@ func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
|
|||
func (ts Tensor) Uniform_(from float64, to float64) {
|
||||
var err error
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgUniform_(ptr, ts.ctensor, from, to)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -418,10 +429,11 @@ func (ts Tensor) Uniform_(from float64, to float64) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) ZerosLike() (retVal Tensor, err error) {
|
||||
func (ts Tensor) ZerosLike(del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgZerosLike(ptr, ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
|
@ -435,7 +447,6 @@ func (ts Tensor) ZerosLike() (retVal Tensor, err error) {
|
|||
func (ts Tensor) Fill_(value Scalar) {
|
||||
var err error
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgFill_(ptr, ts.ctensor, value.cscalar)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -443,10 +454,11 @@ func (ts Tensor) Fill_(value Scalar) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) RandnLike() (retVal Tensor, err error) {
|
||||
func (ts Tensor) RandnLike(del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgRandnLike(ptr, ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
|
@ -457,10 +469,11 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Permute(dims []int64, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
lib.AtgPermute(ptr, ts.ctensor, dims, len(dims))
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -472,9 +485,11 @@ func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Squeeze1(dim int64, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgSqueeze1(ptr, ts.ctensor, dim)
|
||||
|
||||
|
@ -487,9 +502,9 @@ func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
|
||||
func (ts Tensor) MustSqueeze1(dim int64, del bool) (retVal Tensor) {
|
||||
var err error
|
||||
retVal, err = ts.Squeeze1(dim)
|
||||
retVal, err = ts.Squeeze1(dim, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -527,9 +542,11 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Mm(mat2 Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgMm(ptr, ts.ctensor, mat2.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -541,8 +558,8 @@ func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Mm(mat2)
|
||||
func (ts Tensor) MustMm(mat2 Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Mm(mat2, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -550,9 +567,11 @@ func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
|
||||
func (ts Tensor) LogSoftmax(dim int64, dtype int32, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -564,8 +583,8 @@ func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
|
||||
retVal, err := ts.LogSoftmax(dim, dtype)
|
||||
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.LogSoftmax(dim, dtype, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -573,10 +592,11 @@ func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) NllLoss(target Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// NOTE: uncomment this causes panic
|
||||
// defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
weight := NewTensor()
|
||||
|
||||
|
@ -594,8 +614,8 @@ func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.NllLoss(target)
|
||||
func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.NllLoss(target, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -603,9 +623,11 @@ func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Argmax(dim int64, keepDim bool, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
var ckeepDim int = 0
|
||||
if keepDim {
|
||||
|
@ -622,8 +644,8 @@ func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
|
||||
retVal, err := ts.Argmax(dim, keepDim)
|
||||
func (ts Tensor) MustArgmax(dim int64, keepDim bool, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Argmax(dim, keepDim, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -631,9 +653,11 @@ func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Mean(dtype int32, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgMean(ptr, ts.ctensor, dtype)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -645,8 +669,8 @@ func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMean(dtype int32) (retVal Tensor) {
|
||||
retVal, err := ts.Mean(dtype)
|
||||
func (ts Tensor) MustMean(dtype int32, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Mean(dtype, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -677,9 +701,11 @@ func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Div1(other Scalar, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgDiv1(ptr, ts.ctensor, other.cscalar)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -691,8 +717,8 @@ func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Div1(other)
|
||||
func (ts Tensor) MustDiv1(other Scalar, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Div1(other, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -702,7 +728,6 @@ func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
|
|||
|
||||
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgRandperm(ptr, n, optionKind.CInt(), optionDevice.CInt())
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -725,7 +750,6 @@ func MustRandperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (r
|
|||
|
||||
func (ts Tensor) Clamp_(min Scalar, max Scalar) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgClamp_(ptr, ts.ctensor, min.cscalar, max.cscalar)
|
||||
if err := TorchErr(); err != nil {
|
||||
|
@ -735,7 +759,6 @@ func (ts Tensor) Clamp_(min Scalar, max Scalar) {
|
|||
|
||||
func (ts Tensor) Relu_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgRelu_(ptr, ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
|
@ -743,9 +766,11 @@ func (ts Tensor) Relu_() {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Relu() (retVal Tensor, err error) {
|
||||
func (ts Tensor) Relu(del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgRelu(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
|
@ -758,8 +783,8 @@ func (ts Tensor) Relu() (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustRelu() (retVal Tensor) {
|
||||
retVal, err := ts.Relu()
|
||||
func (ts Tensor) MustRelu(del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Relu(del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -767,9 +792,11 @@ func (ts Tensor) MustRelu() (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) T() (retVal Tensor, err error) {
|
||||
func (ts Tensor) T(del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgT(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
|
@ -782,8 +809,8 @@ func (ts Tensor) T() (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustT() (retVal Tensor) {
|
||||
retVal, err := ts.T()
|
||||
func (ts Tensor) MustT(del bool) (retVal Tensor) {
|
||||
retVal, err := ts.T(del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -802,10 +829,12 @@ func (ts Tensor) T_() {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error) {
|
||||
func (ts Tensor) MseLoss(target Tensor, reduction int, del bool) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgMseLoss(ptr, ts.ctensor, target.ctensor, reduction)
|
||||
err = TorchErr()
|
||||
|
@ -818,8 +847,8 @@ func (ts Tensor) MseLoss(target Tensor, reduction int) (retVal Tensor, err error
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
|
||||
retVal, err := ts.MseLoss(target, reduction)
|
||||
func (ts Tensor) MustMseLoss(target Tensor, reduction int, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.MseLoss(target, reduction, del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -828,9 +857,11 @@ func (ts Tensor) MustMseLoss(target Tensor, reduction int) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Exp() (retVal Tensor, err error) {
|
||||
func (ts Tensor) Exp(del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgExp(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
|
@ -843,8 +874,8 @@ func (ts Tensor) Exp() (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustExp() (retVal Tensor) {
|
||||
retVal, err := ts.Exp()
|
||||
func (ts Tensor) MustExp(del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Exp(del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -864,9 +895,11 @@ func (ts Tensor) Exp_() {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Pow(exponent Scalar, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgPow(ptr, ts.ctensor, exponent.cscalar)
|
||||
err = TorchErr()
|
||||
|
@ -879,8 +912,8 @@ func (ts Tensor) Pow(exponent Scalar) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Pow(exponent)
|
||||
func (ts Tensor) MustPow(exponent Scalar, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Pow(exponent, del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -889,9 +922,11 @@ func (ts Tensor) MustPow(exponent Scalar) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Sum(dtype int32, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgSum(ptr, ts.ctensor, dtype)
|
||||
err = TorchErr()
|
||||
|
@ -904,8 +939,8 @@ func (ts Tensor) Sum(dtype int32) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
|
||||
retVal, err := ts.Sum(dtype)
|
||||
func (ts Tensor) MustSum(dtype int32, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Sum(dtype, del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -914,9 +949,11 @@ func (ts Tensor) MustSum(dtype int32) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Sub(other Tensor, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgSub(ptr, ts.ctensor, other.ctensor)
|
||||
err = TorchErr()
|
||||
|
@ -929,8 +966,8 @@ func (ts Tensor) Sub(other Tensor) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Sub(other)
|
||||
func (ts Tensor) MustSub(other Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Sub(other, del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -939,9 +976,11 @@ func (ts Tensor) MustSub(other Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Sub1(other Scalar, del bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgSub1(ptr, ts.ctensor, other.cscalar)
|
||||
err = TorchErr()
|
||||
|
@ -954,8 +993,8 @@ func (ts Tensor) Sub1(other Scalar) (retVal Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSub1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Sub1(other)
|
||||
func (ts Tensor) MustSub1(other Scalar, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Sub1(other, del)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
|
|
@ -2,6 +2,7 @@ package tensor
|
|||
|
||||
//#include "stdlib.h"
|
||||
//#include "stdbool.h"
|
||||
//#include<stdio.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
|
|
|
@ -16,7 +16,7 @@ import (
|
|||
// (height, width, channel) -> (channel, height, width)
|
||||
func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||
var err error
|
||||
retVal, err = tensor.Permute([]int64{2, 0, 1})
|
||||
retVal, err = tensor.Permute([]int64{2, 0, 1}, true)
|
||||
if err != nil {
|
||||
log.Fatalf("hwcToCHW error: %v\n", err)
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
|
|||
|
||||
func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||
var err error
|
||||
retVal, err = tensor.Permute([]int64{1, 2, 0})
|
||||
retVal, err = tensor.Permute([]int64{1, 2, 0}, true)
|
||||
if err != nil {
|
||||
log.Fatalf("hwcToCHW error: %v\n", err)
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ func Load(path string) (retVal ts.Tensor, err error) {
|
|||
// The tensor input should be of kind UInt8 with values ranging from
|
||||
// 0 to 255.
|
||||
func Save(tensor ts.Tensor, path string) (err error) {
|
||||
t, err := tensor.Totype(gotch.Uint8)
|
||||
t, err := tensor.Totype(gotch.Uint8, true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err)
|
||||
return err
|
||||
|
@ -68,9 +68,9 @@ func Save(tensor ts.Tensor, path string) (err error) {
|
|||
|
||||
switch {
|
||||
case len(shape) == 4 && shape[0] == 1:
|
||||
return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0)).MustTo(gotch.CPU)), path)
|
||||
return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0), true).MustTo(gotch.CPU, true)), path)
|
||||
case len(shape) == 3:
|
||||
return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path)
|
||||
return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU, true)), path)
|
||||
default:
|
||||
err = fmt.Errorf("Unexpected size (%v) for image tensor.\n", len(shape))
|
||||
return err
|
||||
|
@ -125,7 +125,7 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
|
|||
var tensorW ts.Tensor
|
||||
var tensorH ts.Tensor
|
||||
if resizeW != outW {
|
||||
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW)
|
||||
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW, true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
|
@ -135,7 +135,7 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
|
|||
if resizeH == outH {
|
||||
retVal = tensorW
|
||||
} else {
|
||||
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH)
|
||||
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH, true)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
|
|
|
@ -82,7 +82,7 @@ func readLabels(filename string) (retVal ts.Tensor) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
retVal = labelsTs.MustTotype(gotch.Int64)
|
||||
retVal = labelsTs.MustTotype(gotch.Int64, true)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
@ -125,7 +125,7 @@ func readImages(filename string) (retVal ts.Tensor) {
|
|||
err = fmt.Errorf("create images tensor err.")
|
||||
log.Fatal(err)
|
||||
}
|
||||
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0))
|
||||
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user