diff --git a/example/basic/main.go b/example/basic/main.go index f58c056..57951e4 100644 --- a/example/basic/main.go +++ b/example/basic/main.go @@ -1,7 +1,7 @@ package main import ( - // "fmt" + "fmt" "github.com/sugarme/gotch" ts "github.com/sugarme/gotch/tensor" @@ -13,4 +13,10 @@ func main() { tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true) tensor.Print() + + fmt.Printf("tensor is nil: %v\n", tensor.IsNil()) + + tensor.MustDrop() + + fmt.Printf("tensor is nil: %v\n", tensor.IsNil()) } diff --git a/tensor/tensor.go b/tensor/tensor.go index 4a752b8..af38988 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -874,6 +874,7 @@ func (ts Tensor) MustToString(lw int64) (retVal string) { // Drop drops (frees) the tensor func (ts Tensor) Drop() (err error) { + // has not freed yet. if !ts.MustDefined() { return nil } @@ -883,11 +884,13 @@ func (ts Tensor) Drop() (err error) { return err } - // NOTE. assign to a new undefined tensor, then check `ts.MustDefined` - // before deleting at C land. Hence `Drop` method can be called - // multiple times without worrying about double C memory delete panic. - // Other pattern is `defer ts.MustDrop()` whenever a tensor is created. - ts = NewTensor() + // NOTE. there is no reliable way to tell if a pointer has been freed. + // Ref. https://stackoverflow.com/questions/8300853 + // This is a hacky way: as soon as tensor is free up, turn it into "undefined" tensor. + // So that next time, call `ts.MustDefined()` to check whether tensor has been freed. + // This is useful as when can call `ts.MustDrop()` multiple times without worrying + // about "free(): double free detected..." error. + ts = NewTensor() // ts now is "undefined" return nil } @@ -1165,3 +1168,10 @@ func (ts Tensor) Swish() (retVal Tensor) { func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) { return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del) } + +func (ts Tensor) IsNil() bool { + + C.free(unsafe.Pointer(ts.ctensor)) + + return unsafe.Pointer(ts.ctensor) == nil +}