tensor/Drop: updated comment

This commit is contained in:
sugarme 2020-10-24 09:36:06 +11:00
parent c9092b1104
commit 1914aac74e
2 changed files with 22 additions and 6 deletions

View File

@ -1,7 +1,7 @@
package main
import (
// "fmt"
"fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
@ -13,4 +13,10 @@ func main() {
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
tensor.Print()
fmt.Printf("tensor is nil: %v\n", tensor.IsNil())
tensor.MustDrop()
fmt.Printf("tensor is nil: %v\n", tensor.IsNil())
}

View File

@ -874,6 +874,7 @@ func (ts Tensor) MustToString(lw int64) (retVal string) {
// Drop drops (frees) the tensor
func (ts Tensor) Drop() (err error) {
// has not freed yet.
if !ts.MustDefined() {
return nil
}
@ -883,11 +884,13 @@ func (ts Tensor) Drop() (err error) {
return err
}
// NOTE. assign to a new undefined tensor, then check `ts.MustDefined`
// before deleting at C land. Hence `Drop` method can be called
// multiple times without worrying about double C memory delete panic.
// Other pattern is `defer ts.MustDrop()` whenever a tensor is created.
ts = NewTensor()
// NOTE. there is no reliable way to tell if a pointer has been freed.
// Ref. https://stackoverflow.com/questions/8300853
// This is a hacky way: as soon as tensor is free up, turn it into "undefined" tensor.
// So that next time, call `ts.MustDefined()` to check whether tensor has been freed.
// This is useful as when can call `ts.MustDrop()` multiple times without worrying
// about "free(): double free detected..." error.
ts = NewTensor() // ts now is "undefined"
return nil
}
@ -1165,3 +1168,10 @@ func (ts Tensor) Swish() (retVal Tensor) {
func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) {
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
}
func (ts Tensor) IsNil() bool {
C.free(unsafe.Pointer(ts.ctensor))
return unsafe.Pointer(ts.ctensor) == nil
}