reverse changed #10

This commit is contained in:
sugarme 2020-10-27 17:30:58 +11:00
parent 1914aac74e
commit d93cf1b996

View File

@ -873,25 +873,11 @@ func (ts Tensor) MustToString(lw int64) (retVal string) {
// Drop drops (frees) the tensor
func (ts Tensor) Drop() (err error) {
// has not freed yet.
if !ts.MustDefined() {
return nil
}
lib.AtFree(ts.ctensor)
if err = TorchErr(); err != nil {
return err
}
// NOTE. there is no reliable way to tell if a pointer has been freed.
// Ref. https://stackoverflow.com/questions/8300853
// This is a hacky way: as soon as tensor is free up, turn it into "undefined" tensor.
// So that next time, call `ts.MustDefined()` to check whether tensor has been freed.
// This is useful as when can call `ts.MustDrop()` multiple times without worrying
// about "free(): double free detected..." error.
ts = NewTensor() // ts now is "undefined"
return nil
}
@ -1168,10 +1154,3 @@ func (ts Tensor) Swish() (retVal Tensor) {
func (ts Tensor) AvgPool2DDefault(ksize int64, del bool) (retVal Tensor) {
return ts.MustAvgPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, false, true, 1, del)
}
func (ts Tensor) IsNil() bool {
C.free(unsafe.Pointer(ts.ctensor))
return unsafe.Pointer(ts.ctensor) == nil
}