reverse changed #10
This commit is contained in:
parent
1914aac74e
commit
d93cf1b996
|
@ -873,25 +873,11 @@ func (ts Tensor) MustToString(lw int64) (retVal string) {
|
|||
|
||||
// Drop drops (frees) the tensor
|
||||
func (ts Tensor) Drop() (err error) {
|
||||
|
||||
// has not freed yet.
|
||||
if !ts.MustDefined() {
|
||||
return nil
|
||||
}
|
||||
|
||||
lib.AtFree(ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// NOTE. there is no reliable way to tell if a pointer has been freed.
|
||||
// Ref. https://stackoverflow.com/questions/8300853
|
||||
// This is a hacky way: as soon as tensor is free up, turn it into "undefined" tensor.
|
||||
// So that next time, call `ts.MustDefined()` to check whether tensor has been freed.
|
||||
// This is useful as when can call `ts.MustDrop()` multiple times without worrying
|
||||
// about "free(): double free detected..." error.
|
||||
ts = NewTensor() // ts now is "undefined"
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -1168,10 +1154,3 @@ func (ts Tensor) Swish() (retVal Tensor) {
|
|||
func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
|
||||
}
|
||||
|
||||
func (ts Tensor) IsNil() bool {
|
||||
|
||||
C.free(unsafe.Pointer(ts.ctensor))
|
||||
|
||||
return unsafe.Pointer(ts.ctensor) == nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user