gotch/ts/iter.go
Goncalves Henriques, Andre (UG - Computer Science) 9257404edd Move the name of the module
2024-04-21 15:15:00 +01:00

69 lines
1.2 KiB
Go

package ts
import (
"fmt"
"log"
"git.andr3h3nriqu3s.com/andr3/gotch"
)
// Iterator yields a sequence of values one at a time.
//
// Each call to Next returns the next item together with true. Once the
// sequence is exhausted, Next reports false (Iterable, the implementation in
// this package, returns nil as the item in that case).
type Iterator interface {
	Next() (interface{}, bool)
}
// Iterable steps through the elements of a tensor one index at a time and
// satisfies Iterator through its Next method. It is normally produced by
// Tensor.Iter, which fills it from a 1-D tensor.
type Iterable struct {
	// Index is the position of the element the next call to Next reads;
	// Len is the number of elements available to iterate over.
	Index, Len int64

	// Content is the tensor the elements are read from.
	Content *Tensor

	// ItemKind selects the Go type Next converts each element to.
	ItemKind gotch.DType
}
// Next implements the Iterator interface.
//
// It reads the element at it.Index from it.Content, converted according to
// it.ItemKind (only Go kinds int64 and float64 are supported), advances Index
// by one and returns the value with ok == true. Once Index has reached (or
// passed) Len it returns (nil, false) without touching Content.
//
// Because the Iterator signature has no error result, a failed element read or
// an unsupported ItemKind terminates the program via log.Fatal.
func (it *Iterable) Next() (item interface{}, ok bool) {
	// Index and Len are exported and may be changed by callers, so stop on >=
	// rather than == to never read past the end.
	if it.Index >= it.Len {
		return nil, false
	}

	idx := []int64{it.Index}
	var err error
	switch it.ItemKind.GoKind().String() {
	case "int64":
		item, err = it.Content.Int64Value(idx)
	case "float64":
		item, err = it.Content.Float64Value(idx)
	default:
		err = fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
	}
	if err != nil {
		log.Fatal(err)
	}

	it.Index++
	return item, true
}
// Iter returns an Iterable over the elements of ts, converted to dtype.
//
// The length of the iteration comes from ts.Size1 (the size of a 1-D tensor);
// any error from Size1 is returned. The elements are read from a shallow clone
// of ts cast to dtype, so ts itself is not replaced. Iterable.Next only yields
// items whose Go kind is int64 or float64, so dtype should map to one of those.
//
// NOTE(review): the cast tensor stored in the returned Iterable's Content is
// not released by the iterator; presumably the caller should drop it when
// done — confirm.
func (ts *Tensor) Iter(dtype gotch.DType) (*Iterable, error) {
	num, err := ts.Size1() // size for 1D tensor
	if err != nil {
		return nil, err
	}

	tmp, err := ts.ShallowClone()
	if err != nil {
		return nil, err
	}

	// Use the error-returning Totype instead of MustTotype so a failed cast is
	// reported to the caller rather than aborting; del stays true as before so
	// the temporary clone is handed off to (and released by) the cast.
	content, err := tmp.Totype(dtype, true)
	if err != nil {
		return nil, err
	}

	return &Iterable{
		Index:    0,
		Len:      num,
		Content:  content,
		ItemKind: dtype,
	}, nil
}