diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b2ae46..ffb2f9b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - reworked `ts.Format()` - Added conv2d benchmark - Fixed #88 memory leak at `example/char-rnn` -- Added missing tensor `Stride()` +- Added missing tensor `Stride()` and `MustDataPtr()` ## [Nofix] - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box. diff --git a/ts/tensor.go b/ts/tensor.go index 7aadd7a..208f1d1 100644 --- a/ts/tensor.go +++ b/ts/tensor.go @@ -674,6 +674,15 @@ func (ts *Tensor) DataPtr() (unsafe.Pointer, error) { return datPtr, nil } +func (ts *Tensor) MustDataPtr() unsafe.Pointer { + p, err := ts.DataPtr() + if err != nil { + panic(err) + } + + return p +} + // Defined returns true is the tensor is defined. func (ts *Tensor) Defined() (bool, error) { state := lib.AtDefined(ts.ctensor)