diff --git a/example/tensor-iterator/main.go b/example/tensor-iterator/main.go new file mode 100644 index 0000000..4c8c94c --- /dev/null +++ b/example/tensor-iterator/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "reflect" + + wrapper "github.com/sugarme/gotch/wrapper" +) + +func main() { + data := [][]int64{ + {1, 1, 1, 2, 2, 2, 3, 3}, + {1, 1, 1, 2, 2, 2, 4, 4}, + } + shape := []int64{16} + + ts, err := wrapper.NewTensorFromData(data, shape) + if err != nil { + panic(err) + } + + it, err := ts.Iter(reflect.Float64) + if err != nil { + panic(err) + } + + for i := 0; i < int(it.Len); i++ { + v := it.Next() + fmt.Println(v) + } + +} diff --git a/wrapper/iter.go b/wrapper/iter.go new file mode 100644 index 0000000..9842aac --- /dev/null +++ b/wrapper/iter.go @@ -0,0 +1,61 @@ +package wrapper + +import ( + "fmt" + "log" + "reflect" +) + +type Iterator interface { + Next() interface{} +} + +type Iterable struct { + Index int64 + Len int64 + Content Tensor + ItemKind reflect.Kind +} + +// Next implements Iterator interface +func (it *Iterable) Next() (retVal interface{}) { + var err error + switch it.ItemKind { + case reflect.Int64: + retVal, err = it.Content.Int64Value([]int64{it.Index}) + if err != nil { + log.Fatal(err) + } + it.Index += 1 + case reflect.Float64: + retVal, err = it.Content.Float64Value([]int64{it.Index}) + if err != nil { + log.Fatal(err) + } + it.Index += 1 + default: + err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind) + log.Fatal(err) + } + + return retVal +} + +// Iter creates an iterable object with specified item type. +func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) { + num, err := ts.Size1() // size for 1D tensor + if err != nil { + return retVal, err + } + content, err := ts.ShallowClone() + if err != nil { + return retVal, err + } + + return Iterable{ + Index: 0, + Len: num, + Content: content, + ItemKind: kind, + }, nil +}