fixed wrong tensor method 'Meshgrid'

This commit is contained in:
sugarme 2022-01-17 17:18:54 +11:00
parent 653caf4be5
commit 45ff5b2e5b
3 changed files with 5 additions and 7 deletions

View File

@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
- added `nn.Path.Paths()` method - added `nn.Path.Paths()` method
- added `nn.VarStore.Summary()` method - added `nn.VarStore.Summary()` method
- fixed incorrect tensor method `ts.Meshgrid` -> `Meshgrid`
## [Nofix] ## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box. - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -202,6 +202,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
case "Tuple": case "Tuple":
val := reflect.Indirect(reflect.ValueOf(iv.value)) val := reflect.Indirect(reflect.ValueOf(iv.value))
fmt.Printf("val: %v\n", val)
switch { switch {
// 1. Tuple is (Tensor, Tensor) // 1. Tuple is (Tensor, Tensor)
case val.Type() == reflect.TypeOf([]Tensor{}): case val.Type() == reflect.TypeOf([]Tensor{}):

View File

@ -321,7 +321,7 @@ func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor)
} }
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len); // tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) { func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
var ctensors []lib.Ctensor var ctensors []lib.Ctensor
for _, t := range tensors { for _, t := range tensors {
@ -348,12 +348,8 @@ func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
return retVal, nil return retVal, nil
} }
func (ts *Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) { func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
if del { retVal, err := Meshgrid(tensors)
defer ts.MustDrop()
}
retVal, err := ts.Meshgrid(tensors)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }