fixed check null at tensor ops return slice of tensors and clean-up
This commit is contained in:
parent
f9cb2f5cc6
commit
c1ee7689ad
2
go.mod
2
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/sugarme/gotch
|
||||
|
||||
go 1.14
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||
|
|
|
@ -4,10 +4,25 @@ package libtch
|
|||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
/*
|
||||
bool is_null(int* pointer) {
|
||||
if (NULL == pointer) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
func IsNull(ctensor Ctensor) bool {
|
||||
// return C.is_null(ctensor)
|
||||
ret := C.is_null((*C.int)(unsafe.Pointer(ctensor)))
|
||||
|
||||
return (bool)(ret)
|
||||
}
|
||||
|
||||
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
|
||||
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||
|
|
|
@ -192,6 +192,8 @@ func (vs *VarStore) Load(filepath string) error {
|
|||
x.Tensor.MustDrop()
|
||||
}
|
||||
|
||||
ts.CleanUp(2000) // 2 seconds
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -227,6 +229,9 @@ func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
|
|||
v.Tensor.Copy_(currTs)
|
||||
})
|
||||
}
|
||||
|
||||
ts.CleanUp(2000)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -286,6 +291,8 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
|
|||
x.Tensor.MustDrop()
|
||||
}
|
||||
|
||||
ts.CleanUp(2000)
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
|
@ -336,6 +343,8 @@ func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string,
|
|||
})
|
||||
}
|
||||
|
||||
ts.CleanUp(2000)
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
|
@ -406,6 +415,8 @@ func (vs *VarStore) Copy(src *VarStore) error {
|
|||
srcDevTs.MustDrop()
|
||||
}
|
||||
|
||||
ts.CleanUp(2000)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -715,6 +726,8 @@ func (p *Path) toFloat(dtype gotch.DType) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
ts.CleanUp(2000) // 2 seconds
|
||||
}
|
||||
|
||||
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
||||
|
@ -762,7 +775,7 @@ func (p *Path) ToDevice(device gotch.Device) {
|
|||
}
|
||||
}
|
||||
|
||||
ts.CleanUp(2000)
|
||||
ts.CleanUp(2000) // 2 seconds
|
||||
}
|
||||
|
||||
// ZerosNoTrain creates a new variable initialized with zeros.
|
||||
|
|
|
@ -43,3 +43,20 @@ func ExampleTensorSplitWithSizes(t *testing.T) {
|
|||
// 8 9
|
||||
// [ CPUFloatType{4,2} ]
|
||||
}
|
||||
|
||||
// Test Unbind op specific for BFloat16/Half
|
||||
func TestTensorUnbind(t *testing.T) {
|
||||
// device := gotch.CudaIfAvailable()
|
||||
device := gotch.CPU
|
||||
|
||||
dtype := gotch.BFloat16
|
||||
// dtype := gotch.Half // <- NOTE. Libtorch API Error: "arange_cpu" not implemented for 'Half'
|
||||
|
||||
x := ts.MustArange(ts.IntScalar(60), dtype, device).MustView([]int64{3, 4, 5}, true)
|
||||
|
||||
out := x.MustUnbind(0, true)
|
||||
|
||||
if len(out) != 3 {
|
||||
t.Errorf("Want 3, got %v\n", len(out))
|
||||
}
|
||||
}
|
||||
|
|
41
ts/patch.go
41
ts/patch.go
|
@ -215,8 +215,8 @@ func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -257,8 +257,8 @@ func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -294,9 +294,8 @@ func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -336,8 +335,8 @@ func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -367,8 +366,8 @@ func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -413,9 +412,8 @@ func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -460,9 +458,8 @@ func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tenso
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -488,7 +485,6 @@ func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (r
|
|||
|
||||
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
|
@ -496,9 +492,10 @@ func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
|||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
@ -532,8 +529,8 @@ func Where(condition Tensor) (retVal []*Tensor, err error) {
|
|||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if nextPtr == nil {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user