fixed check null at tensor ops return slice of tensors and clean-up

This commit is contained in:
sugarme 2023-07-07 16:08:15 +10:00
parent f9cb2f5cc6
commit c1ee7689ad
5 changed files with 66 additions and 24 deletions

2
go.mod
View File

@ -1,6 +1,6 @@
module github.com/sugarme/gotch
go 1.14
go 1.19
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0

View File

@ -4,10 +4,25 @@ package libtch
//#include "stdbool.h"
//#include "torch_api.h"
/*
bool is_null(int* pointer) {
if (NULL == pointer) {
return true;
}
return false;
}
*/
import "C"
import "unsafe"
func IsNull(ctensor Ctensor) bool {
// return C.is_null(ctensor)
ret := C.is_null((*C.int)(unsafe.Pointer(ctensor)))
return (bool)(ret)
}
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);

View File

@ -192,6 +192,8 @@ func (vs *VarStore) Load(filepath string) error {
x.Tensor.MustDrop()
}
ts.CleanUp(2000) // 2 seconds
return nil
}
@ -227,6 +229,9 @@ func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
v.Tensor.Copy_(currTs)
})
}
ts.CleanUp(2000)
return nil
}
@ -286,6 +291,8 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
x.Tensor.MustDrop()
}
ts.CleanUp(2000)
return missingVariables, nil
}
@ -336,6 +343,8 @@ func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string,
})
}
ts.CleanUp(2000)
return missingVariables, nil
}
@ -406,6 +415,8 @@ func (vs *VarStore) Copy(src *VarStore) error {
srcDevTs.MustDrop()
}
ts.CleanUp(2000)
return nil
}
@ -715,6 +726,8 @@ func (p *Path) toFloat(dtype gotch.DType) {
}
}
}
ts.CleanUp(2000) // 2 seconds
}
// ToFloat casts all variables in current path and subpaths to `Float` precision.
@ -762,7 +775,7 @@ func (p *Path) ToDevice(device gotch.Device) {
}
}
ts.CleanUp(2000)
ts.CleanUp(2000) // 2 seconds
}
// ZerosNoTrain creates a new variable initialized with zeros.

View File

@ -43,3 +43,20 @@ func ExampleTensorSplitWithSizes(t *testing.T) {
// 8 9
// [ CPUFloatType{4,2} ]
}
// Test Unbind op specific for BFloat16/Half
func TestTensorUnbind(t *testing.T) {
// device := gotch.CudaIfAvailable()
device := gotch.CPU
dtype := gotch.BFloat16
// dtype := gotch.Half // <- NOTE. Libtorch API Error: "arange_cpu" not implemented for 'Half'
x := ts.MustArange(ts.IntScalar(60), dtype, device).MustView([]int64{3, 4, 5}, true)
out := x.MustUnbind(0, true)
if len(out) != 3 {
t.Errorf("Want 3, got %v\n", len(out))
}
}

View File

@ -215,8 +215,8 @@ func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -257,8 +257,8 @@ func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -294,9 +294,8 @@ func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -336,8 +335,8 @@ func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -367,8 +366,8 @@ func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -413,9 +412,8 @@ func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -460,9 +458,8 @@ func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tenso
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -488,7 +485,6 @@ func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (r
// tensor *atg_unbind(tensor self, int64_t dim);
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
if err = TorchErr(); err != nil {
return retVal, err
@ -496,9 +492,10 @@ func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
@ -532,8 +529,8 @@ func Where(condition Tensor) (retVal []*Tensor, err error) {
currentPtr := ctensorsPtr
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if nextPtr == nil {
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}