added tensor.FromCtensor()

This commit is contained in:
sugarme 2022-02-16 11:39:27 +11:00
parent 73d6c0ae86
commit 2d5031009b
2 changed files with 11 additions and 0 deletions

View File

@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- fixed make `tensor.ValueGo()` returning `[]int` instead of `[]int32`
- added more building block modules: Dropout, MaxPool2D, Parameter, Identity
- added nn.BatchNorm.Forward() with default training=true
- added exposing `tensor.Ctensor()`
- added API `tensor.FromCtensor()`
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -21,6 +21,10 @@ type Tensor struct {
ctensor lib.Ctensor
}
func (ts *Tensor) Ctensor() unsafe.Pointer {
return unsafe.Pointer(ts.ctensor)
}
// None is an undefined tensor.
// It can be used in optional tensor parameter where 'None' value used.
// `ts.MustDefined()` function is used for checking 'null'
@ -32,6 +36,11 @@ func NewTensor() *Tensor {
return &Tensor{ctensor}
}
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
cts := (lib.Ctensor)(ctensor)
return &Tensor{cts}
}
func (ts *Tensor) Dim() uint64 {
dim := lib.AtDim(ts.ctensor)
if err := TorchErr(); err != nil {