fixed wrong tensor method 'Meshgrid'
This commit is contained in:
parent
653caf4be5
commit
45ff5b2e5b
|
@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
## [Unreleased]
|
||||
- added `nn.Path.Paths()` method
|
||||
- added `nn.VarStore.Summary()` method
|
||||
- fixed incorrect tensor method `ts.Meshgrid` -> `Meshgrid`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -202,6 +202,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
|
|||
|
||||
case "Tuple":
|
||||
val := reflect.Indirect(reflect.ValueOf(iv.value))
|
||||
fmt.Printf("val: %v\n", val)
|
||||
switch {
|
||||
// 1. Tuple is (Tensor, Tensor)
|
||||
case val.Type() == reflect.TypeOf([]Tensor{}):
|
||||
|
|
|
@ -321,7 +321,7 @@ func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor)
|
|||
}
|
||||
|
||||
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
||||
func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
|
@ -348,12 +348,8 @@ func (ts *Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := ts.Meshgrid(tensors)
|
||||
func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
|
||||
retVal, err := Meshgrid(tensors)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user