fixed ununique tensor name caused double free tensor error
This commit is contained in:
parent
1cffab577c
commit
f0b87eb001
70
ts/tensor.go
70
ts/tensor.go
|
@ -50,9 +50,10 @@ type bigStruct struct {
|
|||
//
|
||||
// For heap allocation see. https://stackoverflow.com/questions/10866195
|
||||
type Tensor struct {
|
||||
d *bigStruct
|
||||
name string
|
||||
ctensor lib.Ctensor
|
||||
d *bigStruct
|
||||
name string
|
||||
ctensor lib.Ctensor
|
||||
calledFrom string
|
||||
}
|
||||
|
||||
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
|
||||
|
@ -63,20 +64,26 @@ func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
|
|||
|
||||
x := new(Tensor)
|
||||
x.ctensor = ctensor
|
||||
x.name = name
|
||||
x.d = new(bigStruct)
|
||||
|
||||
atomic.AddInt64(&TensorCount, 1)
|
||||
nbytes := x.nbytes()
|
||||
atomic.AddInt64(&AllocatedMem, nbytes)
|
||||
lock.Lock()
|
||||
if _, ok := ExistingTensors[name]; ok {
|
||||
name = fmt.Sprintf("%s_%09d", name, TensorCount)
|
||||
}
|
||||
ExistingTensors[name] = struct{}{}
|
||||
lock.Unlock()
|
||||
|
||||
x.name = name
|
||||
|
||||
if gotch.Debug {
|
||||
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
|
||||
}
|
||||
|
||||
x.calledFrom = "newTensor()"
|
||||
|
||||
runtime.SetFinalizer(x, freeCTensor)
|
||||
|
||||
return x
|
||||
|
@ -124,20 +131,33 @@ func (ts *Tensor) Ctensor() unsafe.Pointer {
|
|||
|
||||
// free releases C allocated memory.
|
||||
func freeCTensor(ts *Tensor) error {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
// Just return if it has been deleted previously!
|
||||
if unsafe.Pointer(ts.ctensor) == nil {
|
||||
if gotch.Debug {
|
||||
log.Printf("INFO: ctensor is nil. Nothing to delete here...\n")
|
||||
}
|
||||
if ts == nil || ts.ctensor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
nbytes := ts.nbytes()
|
||||
atomic.AddInt64(&AllocatedMem, -nbytes)
|
||||
delete(ExistingTensors, ts.name)
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if _, ok := ExistingTensors[ts.name]; !ok {
|
||||
log.Printf("WARNING: Probably double free tensor %q. Called from %q. Just skipping...\n", ts.name, ts.calledFrom)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if gotch.Debug {
|
||||
shape, err := ts.Size()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ERROR: failed to release tensor %q: %w\n", ts.name, err)
|
||||
}
|
||||
log.Printf(err.Error())
|
||||
|
||||
numel := uint(FlattenDim(shape))
|
||||
dtype := ts.DType()
|
||||
nbytes := int64(numel * dtype.Size())
|
||||
atomic.AddInt64(&AllocatedMem, -nbytes)
|
||||
|
||||
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
|
||||
}
|
||||
|
||||
lib.AtFree(ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
|
@ -145,9 +165,7 @@ func freeCTensor(ts *Tensor) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if gotch.Debug {
|
||||
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
|
||||
}
|
||||
delete(ExistingTensors, ts.name)
|
||||
|
||||
// IMPORTANT. make it nil so won't double free.
|
||||
ts.ctensor = nil
|
||||
|
@ -198,6 +216,11 @@ func (ts *Tensor) Dim() uint64 {
|
|||
// to that slice.
|
||||
func (ts *Tensor) Size() ([]int64, error) {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
if dim < 0 || dim > 100 {
|
||||
err := fmt.Errorf("Invalid dim: %v\n", dim)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sz := make([]int64, dim)
|
||||
szPtr, err := DataAsPtr(sz)
|
||||
if err != nil {
|
||||
|
@ -1163,13 +1186,22 @@ func (ts *Tensor) MustToString(lw int64) string {
|
|||
|
||||
// Drop drops (frees) the tensor
|
||||
func (ts *Tensor) Drop() error {
|
||||
if ts.ctensor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clear SetFinalizer on ts so no double free tensor.
|
||||
// Ref. https://pkg.go.dev/runtime#SetFinalizer
|
||||
runtime.SetFinalizer(ts, nil)
|
||||
|
||||
ts.calledFrom = "ts.Drop()"
|
||||
return freeCTensor(ts)
|
||||
}
|
||||
|
||||
// MustDrop drops the tensor. It will be panic if error
|
||||
func (ts *Tensor) MustDrop() {
|
||||
if err := ts.Drop(); err != nil {
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user