feat(wrapper/tensor): iter

This commit is contained in:
sugarme 2020-06-13 09:52:34 +10:00
parent e188f4965b
commit 49e0335469
2 changed files with 93 additions and 0 deletions

View File

@ -0,0 +1,32 @@
package main
import (
"fmt"
"reflect"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
data := [][]int64{
{1, 1, 1, 2, 2, 2, 3, 3},
{1, 1, 1, 2, 2, 2, 4, 4},
}
shape := []int64{16}
ts, err := wrapper.NewTensorFromData(data, shape)
if err != nil {
panic(err)
}
it, err := ts.Iter(reflect.Float64)
if err != nil {
panic(err)
}
for i := 0; i < int(it.Len); i++ {
v := it.Next()
fmt.Println(v)
}
}

61
wrapper/iter.go Normal file
View File

@ -0,0 +1,61 @@
package wrapper
import (
"fmt"
"log"
"reflect"
)
type Iterator interface {
Next() interface{}
}
type Iterable struct {
Index int64
Len int64
Content Tensor
ItemKind reflect.Kind
}
// Next implements Iterator interface
func (it *Iterable) Next() (retVal interface{}) {
var err error
switch it.ItemKind {
case reflect.Int64:
retVal, err = it.Content.Int64Value([]int64{it.Index})
if err != nil {
log.Fatal(err)
}
it.Index += 1
case reflect.Float64:
retVal, err = it.Content.Float64Value([]int64{it.Index})
if err != nil {
log.Fatal(err)
}
it.Index += 1
default:
err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
log.Fatal(err)
}
return retVal
}
// Iter creates an iterable object with specified item type.
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
num, err := ts.Size1() // size for 1D tensor
if err != nil {
return retVal, err
}
content, err := ts.ShallowClone()
if err != nil {
return retVal, err
}
return Iterable{
Index: 0,
Len: num,
Content: content,
ItemKind: kind,
}, nil
}