69 lines
1.2 KiB
Go
69 lines
1.2 KiB
Go
package ts
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
)
|
|
|
|
type Iterator interface {
|
|
Next() (item interface{}, ok bool)
|
|
}
|
|
|
|
type Iterable struct {
|
|
Index int64
|
|
Len int64
|
|
Content *Tensor
|
|
ItemKind gotch.DType
|
|
}
|
|
|
|
// Next implements Iterator interface
|
|
func (it *Iterable) Next() (item interface{}, ok bool) {
|
|
|
|
if it.Index == it.Len {
|
|
return nil, false
|
|
}
|
|
|
|
var err error
|
|
switch it.ItemKind.GoKind().String() {
|
|
case "int64":
|
|
item, err = it.Content.Int64Value([]int64{it.Index})
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
it.Index += 1
|
|
case "float64":
|
|
item, err = it.Content.Float64Value([]int64{it.Index})
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
it.Index += 1
|
|
default:
|
|
err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return item, true
|
|
}
|
|
|
|
// Iter creates an iterable object with specified item type.
|
|
func (ts *Tensor) Iter(dtype gotch.DType) (*Iterable, error) {
|
|
num, err := ts.Size1() // size for 1D tensor
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tmp, err := ts.ShallowClone()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
content := tmp.MustTotype(dtype, true)
|
|
|
|
return &Iterable{
|
|
Index: 0,
|
|
Len: num,
|
|
Content: content,
|
|
ItemKind: dtype,
|
|
}, nil
|
|
}
|