added tensor.FromCtensor()
This commit is contained in:
parent
73d6c0ae86
commit
2d5031009b
|
@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- fixed make `tensor.ValueGo()` returning `[]int` instead of `[]int32`
|
||||
- added more building block modules: Dropout, MaxPool2D, Parameter, Identity
|
||||
- added nn.BatchNorm.Forward() with default training=true
|
||||
- added exposing `tensor.Ctensor()`
|
||||
- added API `tensor.FromCtensor()`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -21,6 +21,10 @@ type Tensor struct {
|
|||
ctensor lib.Ctensor
|
||||
}
|
||||
|
||||
func (ts *Tensor) Ctensor() unsafe.Pointer {
|
||||
return unsafe.Pointer(ts.ctensor)
|
||||
}
|
||||
|
||||
// None is an undefined tensor.
|
||||
// It can be used in optional tensor parameter where 'None' value used.
|
||||
// `ts.MustDefined()` function is used for checking 'null'
|
||||
|
@ -32,6 +36,11 @@ func NewTensor() *Tensor {
|
|||
return &Tensor{ctensor}
|
||||
}
|
||||
|
||||
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
|
||||
cts := (lib.Ctensor)(ctensor)
|
||||
return &Tensor{cts}
|
||||
}
|
||||
|
||||
func (ts *Tensor) Dim() uint64 {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
|
|
Loading…
Reference in New Issue
Block a user