tensor/Drop: updated comment
This commit is contained in:
parent
c9092b1104
commit
1914aac74e
|
@ -1,7 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
|
@ -13,4 +13,10 @@ func main() {
|
|||
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
|
||||
|
||||
tensor.Print()
|
||||
|
||||
fmt.Printf("tensor is nil: %v\n", tensor.IsNil())
|
||||
|
||||
tensor.MustDrop()
|
||||
|
||||
fmt.Printf("tensor is nil: %v\n", tensor.IsNil())
|
||||
}
|
||||
|
|
|
@ -874,6 +874,7 @@ func (ts Tensor) MustToString(lw int64) (retVal string) {
|
|||
// Drop drops (frees) the tensor
|
||||
func (ts Tensor) Drop() (err error) {
|
||||
|
||||
// has not freed yet.
|
||||
if !ts.MustDefined() {
|
||||
return nil
|
||||
}
|
||||
|
@ -883,11 +884,13 @@ func (ts Tensor) Drop() (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
// NOTE. assign to a new undefined tensor, then check `ts.MustDefined`
|
||||
// before deleting at C land. Hence `Drop` method can be called
|
||||
// multiple times without worrying about double C memory delete panic.
|
||||
// Other pattern is `defer ts.MustDrop()` whenever a tensor is created.
|
||||
ts = NewTensor()
|
||||
// NOTE. there is no reliable way to tell if a pointer has been freed.
|
||||
// Ref. https://stackoverflow.com/questions/8300853
|
||||
// This is a hacky way: as soon as tensor is free up, turn it into "undefined" tensor.
|
||||
// So that next time, call `ts.MustDefined()` to check whether tensor has been freed.
|
||||
// This is useful as when can call `ts.MustDrop()` multiple times without worrying
|
||||
// about "free(): double free detected..." error.
|
||||
ts = NewTensor() // ts now is "undefined"
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -1165,3 +1168,10 @@ func (ts Tensor) Swish() (retVal Tensor) {
|
|||
func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
|
||||
}
|
||||
|
||||
func (ts Tensor) IsNil() bool {
|
||||
|
||||
C.free(unsafe.Pointer(ts.ctensor))
|
||||
|
||||
return unsafe.Pointer(ts.ctensor) == nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user