fixed check null at tensor ops return slice of tensors and clean-up
This commit is contained in:
parent
f9cb2f5cc6
commit
c1ee7689ad
2
go.mod
2
go.mod
|
@ -1,6 +1,6 @@
|
||||||
module github.com/sugarme/gotch
|
module github.com/sugarme/gotch
|
||||||
|
|
||||||
go 1.14
|
go 1.19
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||||
|
|
|
@ -4,10 +4,25 @@ package libtch
|
||||||
|
|
||||||
//#include "stdbool.h"
|
//#include "stdbool.h"
|
||||||
//#include "torch_api.h"
|
//#include "torch_api.h"
|
||||||
|
/*
|
||||||
|
bool is_null(int* pointer) {
|
||||||
|
if (NULL == pointer) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*/
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import "unsafe"
|
import "unsafe"
|
||||||
|
|
||||||
|
func IsNull(ctensor Ctensor) bool {
|
||||||
|
// return C.is_null(ctensor)
|
||||||
|
ret := C.is_null((*C.int)(unsafe.Pointer(ctensor)))
|
||||||
|
|
||||||
|
return (bool)(ret)
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
|
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
|
||||||
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||||
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
|
|
@ -192,6 +192,8 @@ func (vs *VarStore) Load(filepath string) error {
|
||||||
x.Tensor.MustDrop()
|
x.Tensor.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.CleanUp(2000) // 2 seconds
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,6 +229,9 @@ func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
|
||||||
v.Tensor.Copy_(currTs)
|
v.Tensor.Copy_(currTs)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.CleanUp(2000)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,6 +291,8 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
|
||||||
x.Tensor.MustDrop()
|
x.Tensor.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.CleanUp(2000)
|
||||||
|
|
||||||
return missingVariables, nil
|
return missingVariables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,6 +343,8 @@ func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.CleanUp(2000)
|
||||||
|
|
||||||
return missingVariables, nil
|
return missingVariables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -406,6 +415,8 @@ func (vs *VarStore) Copy(src *VarStore) error {
|
||||||
srcDevTs.MustDrop()
|
srcDevTs.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.CleanUp(2000)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -715,6 +726,8 @@ func (p *Path) toFloat(dtype gotch.DType) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ts.CleanUp(2000) // 2 seconds
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
||||||
|
@ -762,7 +775,7 @@ func (p *Path) ToDevice(device gotch.Device) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.CleanUp(2000)
|
ts.CleanUp(2000) // 2 seconds
|
||||||
}
|
}
|
||||||
|
|
||||||
// ZerosNoTrain creates a new variable initialized with zeros.
|
// ZerosNoTrain creates a new variable initialized with zeros.
|
||||||
|
|
|
@ -43,3 +43,20 @@ func ExampleTensorSplitWithSizes(t *testing.T) {
|
||||||
// 8 9
|
// 8 9
|
||||||
// [ CPUFloatType{4,2} ]
|
// [ CPUFloatType{4,2} ]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test Unbind op specific for BFloat16/Half
|
||||||
|
func TestTensorUnbind(t *testing.T) {
|
||||||
|
// device := gotch.CudaIfAvailable()
|
||||||
|
device := gotch.CPU
|
||||||
|
|
||||||
|
dtype := gotch.BFloat16
|
||||||
|
// dtype := gotch.Half // <- NOTE. Libtorch API Error: "arange_cpu" not implemented for 'Half'
|
||||||
|
|
||||||
|
x := ts.MustArange(ts.IntScalar(60), dtype, device).MustView([]int64{3, 4, 5}, true)
|
||||||
|
|
||||||
|
out := x.MustUnbind(0, true)
|
||||||
|
|
||||||
|
if len(out) != 3 {
|
||||||
|
t.Errorf("Want 3, got %v\n", len(out))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
41
ts/patch.go
41
ts/patch.go
|
@ -215,8 +215,8 @@ func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
if nextPtr == nil {
|
if *nextPtr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,8 +257,8 @@ func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
if nextPtr == nil {
|
if *nextPtr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -294,9 +294,8 @@ func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
// calculate the next pointer value
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
if *nextPtr == nil {
|
||||||
if nextPtr == nil {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -336,8 +335,8 @@ func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
if nextPtr == nil {
|
if *nextPtr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,8 +366,8 @@ func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
if nextPtr == nil {
|
if *nextPtr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -413,9 +412,8 @@ func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
// calculate the next pointer value
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
if *nextPtr == nil {
|
||||||
if nextPtr == nil {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -460,9 +458,8 @@ func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tenso
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
// calculate the next pointer value
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
if *nextPtr == nil {
|
||||||
if nextPtr == nil {
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -488,7 +485,6 @@ func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (r
|
||||||
|
|
||||||
// tensor *atg_unbind(tensor self, int64_t dim);
|
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||||
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
||||||
|
|
||||||
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
|
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
|
@ -496,9 +492,10 @@ func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
||||||
|
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
if nextPtr == nil {
|
if *nextPtr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -532,8 +529,8 @@ func Where(condition Tensor) (retVal []*Tensor, err error) {
|
||||||
currentPtr := ctensorsPtr
|
currentPtr := ctensorsPtr
|
||||||
retVal = append(retVal, newTensor(*currentPtr))
|
retVal = append(retVal, newTensor(*currentPtr))
|
||||||
for {
|
for {
|
||||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||||
if nextPtr == nil {
|
if *nextPtr == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user