From c1ee7689ad6887a1cdfe34d78d39d88884f95e9c Mon Sep 17 00:00:00 2001 From: sugarme Date: Fri, 7 Jul 2023 16:08:15 +1000 Subject: [PATCH] fixed check null at tensor ops return slice of tensors and clean-up --- go.mod | 2 +- libtch/patch.go | 15 +++++++++++++++ nn/varstore.go | 15 ++++++++++++++- ts/patch-example_test.go | 17 +++++++++++++++++ ts/patch.go | 41 +++++++++++++++++++--------------------- 5 files changed, 66 insertions(+), 24 deletions(-) diff --git a/go.mod b/go.mod index cd7a7b9..5ec0e19 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/sugarme/gotch -go 1.14 +go 1.19 require ( github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 diff --git a/libtch/patch.go b/libtch/patch.go index 83b2f54..1bcb873 100644 --- a/libtch/patch.go +++ b/libtch/patch.go @@ -4,10 +4,25 @@ package libtch //#include "stdbool.h" //#include "torch_api.h" +/* +bool is_null(int* pointer) { + if (NULL == pointer) { + return true; + } + return false; +} +*/ import "C" import "unsafe" +func IsNull(ctensor Ctensor) bool { + // return C.is_null(ctensor) + ret := C.is_null((*C.int)(unsafe.Pointer(ctensor))) + + return (bool)(ret) +} + // NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`: // tensor *atg_align_tensors(tensor *tensors_data, int tensors_len); // tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len); diff --git a/nn/varstore.go b/nn/varstore.go index e15d20a..ee9a2b3 100644 --- a/nn/varstore.go +++ b/nn/varstore.go @@ -192,6 +192,8 @@ func (vs *VarStore) Load(filepath string) error { x.Tensor.MustDrop() } + ts.CleanUp(2000) // 2 seconds + return nil } @@ -227,6 +229,9 @@ func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error { v.Tensor.Copy_(currTs) }) } + + ts.CleanUp(2000) + return nil } @@ -286,6 +291,8 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) { x.Tensor.MustDrop() } + ts.CleanUp(2000) + return missingVariables, nil } @@ -336,6 +343,8 @@ func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string, }) } + ts.CleanUp(2000) + return missingVariables, nil } @@ -406,6 +415,8 @@ func (vs *VarStore) Copy(src *VarStore) error { srcDevTs.MustDrop() } + ts.CleanUp(2000) + return nil } @@ -715,6 +726,8 @@ func (p *Path) toFloat(dtype gotch.DType) { } } } + + ts.CleanUp(2000) // 2 seconds } // ToFloat casts all variables in current path and subpaths to `Float` precision. @@ -762,7 +775,7 @@ func (p *Path) ToDevice(device gotch.Device) { } } - ts.CleanUp(2000) + ts.CleanUp(2000) // 2 seconds } // ZerosNoTrain creates a new variable initialized with zeros. diff --git a/ts/patch-example_test.go b/ts/patch-example_test.go index 361106a..7a956bf 100644 --- a/ts/patch-example_test.go +++ b/ts/patch-example_test.go @@ -43,3 +43,20 @@ func ExampleTensorSplitWithSizes(t *testing.T) { // 8 9 // [ CPUFloatType{4,2} ] } + +// Test Unbind op specific for BFloat16/Half +func TestTensorUnbind(t *testing.T) { + // device := gotch.CudaIfAvailable() + device := gotch.CPU + + dtype := gotch.BFloat16 + // dtype := gotch.Half // <- NOTE. Libtorch API Error: "arange_cpu" not implemented for 'Half' + + x := ts.MustArange(ts.IntScalar(60), dtype, device).MustView([]int64{3, 4, 5}, true) + + out := x.MustUnbind(0, true) + + if len(out) != 3 { + t.Errorf("Want 3, got %v\n", len(out)) + } +} diff --git a/ts/patch.go b/ts/patch.go index e039f98..54e0f2b 100644 --- a/ts/patch.go +++ b/ts/patch.go @@ -215,8 +215,8 @@ func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -257,8 +257,8 @@ func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -294,9 +294,8 @@ func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - // calculate the next pointer value - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -336,8 +335,8 @@ func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -367,8 +366,8 @@ func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -413,9 +412,8 @@ func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - // calculate the next pointer value - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -460,9 +458,8 @@ func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tenso currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - // calculate the next pointer value - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -488,7 +485,6 @@ func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (r // tensor *atg_unbind(tensor self, int64_t dim); func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) { - ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim) if err = TorchErr(); err != nil { return retVal, err @@ -496,9 +492,10 @@ func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) + for { - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break } @@ -532,8 +529,8 @@ func Where(condition Tensor) (retVal []*Tensor, err error) { currentPtr := ctensorsPtr retVal = append(retVal, newTensor(*currentPtr)) for { - nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) - if nextPtr == nil { + nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { break }