gotch/tensor/iter.go

62 lines
1.1 KiB
Go
Raw Normal View History

package tensor
2020-06-13 00:52:34 +01:00
import (
"fmt"
"log"
"reflect"
)
// Iterator is implemented by types that hand out their items one at a time.
// Each call to Next yields one item as an untyped value; the interface itself
// carries no end-of-sequence signal, so callers need another way to know when
// to stop.
type Iterator interface {
// Next returns the next item.
Next() interface{}
}
// Iterable holds the state needed to walk over the elements of a tensor one
// at a time. Values are normally obtained from Tensor.Iter, and Next is
// defined on the pointer receiver (*Iterable).
type Iterable struct {
// Index is the position of the element the next call to Next will read;
// Next increments it after a successful read.
Index int64
// Len is the element count recorded by Tensor.Iter (taken from Size1).
// NOTE(review): Next does not check Index against Len.
Len int64
// Content is the tensor the elements are read from. Tensor.Iter stores the
// result of ShallowClone here.
Content Tensor
// ItemKind selects how elements are read. Next handles reflect.Int64 and
// reflect.Float64 only.
ItemKind reflect.Kind
}
// Next returns the element of Content at the current Index and then advances
// Index by one. It implements the Iterator interface.
//
// ItemKind selects the accessor: reflect.Int64 reads through Int64Value and
// reflect.Float64 through Float64Value. An error from the accessor, or any
// other ItemKind, is passed to log.Fatal, which terminates the process.
// Len is not consulted here, so callers must stop at the end of the data
// themselves.
func (it *Iterable) Next() (retVal interface{}) {
	position := []int64{it.Index}

	// Collect the failure from whichever branch runs so that there is a
	// single fatal exit and a single place where Index advances.
	var failure error
	switch it.ItemKind {
	case reflect.Int64:
		retVal, failure = it.Content.Int64Value(position)
	case reflect.Float64:
		retVal, failure = it.Content.Float64Value(position)
	default:
		failure = fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
	}
	if failure != nil {
		log.Fatal(failure)
	}

	it.Index++
	return retVal
}
// Iter builds an Iterable over ts whose items will be read as the given kind.
//
// The element count comes from ts.Size1 (the original comment describes it as
// the size of a 1D tensor) and the iterated content is the result of
// ts.ShallowClone. If either call fails, its error is returned together with
// a zero-valued Iterable. The returned Iterable starts at Index 0.
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
	length, err := ts.Size1()
	if err != nil {
		return Iterable{}, err
	}

	// Clone only after the size lookup has succeeded.
	clone, err := ts.ShallowClone()
	if err != nil {
		return Iterable{}, err
	}

	retVal = Iterable{
		Len:      length,
		Content:  clone,
		ItemKind: kind,
	}
	return retVal, nil
}