feat(tensor): ts.MustDrop() now can call multiple times

This commit is contained in:
sugarme 2020-10-24 00:29:23 +11:00
parent 1f6c972007
commit 95db45896e
2 changed files with 12 additions and 6 deletions

View File

@ -1,7 +1,7 @@
package main
import (
"fmt"
// "fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
@ -13,9 +13,4 @@ func main() {
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
tensor.Print()
nilTs := ts.NewTensor()
fmt.Printf("nilTs val: %v", nilTs)
}

View File

@ -873,11 +873,22 @@ func (ts Tensor) MustToString(lw int64) (retVal string) {
// Drop drops (frees) the tensor
func (ts Tensor) Drop() (err error) {
if !ts.MustDefined() {
return nil
}
lib.AtFree(ts.ctensor)
if err = TorchErr(); err != nil {
return err
}
// NOTE. assign to a new undefined tensor, then check `ts.MustDefined`
// before deleting at C land. Hence `Drop` method can be called
// multiple times without worrying about double C memory delete panic.
// Other pattern is `defer ts.MustDrop()` whenever a tensor is created.
ts = NewTensor()
return nil
}