fixed wrong tensor method 'Meshgrid'

This commit is contained in:
sugarme 2022-01-17 17:18:54 +11:00
parent 653caf4be5
commit 45ff5b2e5b
3 changed files with 5 additions and 7 deletions

View File

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- added `nn.Path.Paths()` method
- added `nn.VarStore.Summary()` method
- fixed incorrect tensor method `ts.Meshgrid` -> `Meshgrid`
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -202,6 +202,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
case "Tuple":
val := reflect.Indirect(reflect.ValueOf(iv.value))
fmt.Printf("val: %v\n", val)
switch {
// 1. Tuple is (Tensor, Tensor)
case val.Type() == reflect.TypeOf([]Tensor{}):

View File

@ -321,7 +321,7 @@ func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor)
}
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
@ -348,12 +348,8 @@ func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
return retVal, nil
}
func (ts *Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
if del {
defer ts.MustDrop()
}
retVal, err := ts.Meshgrid(tensors)
func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
retVal, err := Meshgrid(tensors)
if err != nil {
log.Fatal(err)
}