fixed ununique tensor name caused double free tensor error

This commit is contained in:
sugarme 2023-08-31 22:42:07 +10:00
parent 1cffab577c
commit f0b87eb001

View File

@ -50,9 +50,10 @@ type bigStruct struct {
//
// For heap allocation see. https://stackoverflow.com/questions/10866195
type Tensor struct {
d *bigStruct
name string
ctensor lib.Ctensor
d *bigStruct
name string
ctensor lib.Ctensor
calledFrom string
}
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
@ -63,20 +64,26 @@ func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
x := new(Tensor)
x.ctensor = ctensor
x.name = name
x.d = new(bigStruct)
atomic.AddInt64(&TensorCount, 1)
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, nbytes)
lock.Lock()
if _, ok := ExistingTensors[name]; ok {
name = fmt.Sprintf("%s_%09d", name, TensorCount)
}
ExistingTensors[name] = struct{}{}
lock.Unlock()
x.name = name
if gotch.Debug {
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
}
x.calledFrom = "newTensor()"
runtime.SetFinalizer(x, freeCTensor)
return x
@ -124,20 +131,33 @@ func (ts *Tensor) Ctensor() unsafe.Pointer {
// free releases C allocated memory.
func freeCTensor(ts *Tensor) error {
lock.Lock()
defer lock.Unlock()
// Just return if it has been deleted previously!
if unsafe.Pointer(ts.ctensor) == nil {
if gotch.Debug {
log.Printf("INFO: ctensor is nil. Nothing to delete here...\n")
}
if ts == nil || ts.ctensor == nil {
return nil
}
nbytes := ts.nbytes()
atomic.AddInt64(&AllocatedMem, -nbytes)
delete(ExistingTensors, ts.name)
lock.Lock()
defer lock.Unlock()
if _, ok := ExistingTensors[ts.name]; !ok {
log.Printf("WARNING: Probably double free tensor %q. Called from %q. Just skipping...\n", ts.name, ts.calledFrom)
return nil
}
if gotch.Debug {
shape, err := ts.Size()
if err != nil {
err = fmt.Errorf("ERROR: failed to release tensor %q: %w\n", ts.name, err)
}
log.Printf(err.Error())
numel := uint(FlattenDim(shape))
dtype := ts.DType()
nbytes := int64(numel * dtype.Size())
atomic.AddInt64(&AllocatedMem, -nbytes)
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
}
lib.AtFree(ts.ctensor)
if err := TorchErr(); err != nil {
@ -145,9 +165,7 @@ func freeCTensor(ts *Tensor) error {
return err
}
if gotch.Debug {
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
}
delete(ExistingTensors, ts.name)
// IMPORTANT. make it nil so won't double free.
ts.ctensor = nil
@ -198,6 +216,11 @@ func (ts *Tensor) Dim() uint64 {
// to that slice.
func (ts *Tensor) Size() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
if dim < 0 || dim > 100 {
err := fmt.Errorf("Invalid dim: %v\n", dim)
return nil, err
}
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
@ -1163,13 +1186,22 @@ func (ts *Tensor) MustToString(lw int64) string {
// Drop drops (frees) the tensor
func (ts *Tensor) Drop() error {
if ts.ctensor == nil {
return nil
}
// Clear SetFinalizer on ts so no double free tensor.
// Ref. https://pkg.go.dev/runtime#SetFinalizer
runtime.SetFinalizer(ts, nil)
ts.calledFrom = "ts.Drop()"
return freeCTensor(ts)
}
// MustDrop drops the tensor. It will be panic if error
func (ts *Tensor) MustDrop() {
if err := ts.Drop(); err != nil {
log.Fatal(err)
panic(err)
}
}