gotch/ts/iter.go

69 lines
1.2 KiB
Go
Raw Permalink Normal View History

2022-03-12 07:20:20 +00:00
package ts
2020-06-13 00:52:34 +01:00
import (
"fmt"
"log"
2020-07-11 07:42:27 +01:00
2024-04-21 15:15:00 +01:00
"git.andr3h3nriqu3s.com/andr3/gotch"
2020-06-13 00:52:34 +01:00
)
// Iterator is implemented by values that hand out their elements one at a
// time. Each call to Next yields the following element with ok set to true;
// once there is nothing left to yield, Next reports ok == false.
type Iterator interface {
	// Next returns the next element and whether one was available.
	Next() (item interface{}, ok bool)
}
type Iterable struct {
Index int64
Len int64
Content *Tensor
2020-07-11 07:42:27 +01:00
ItemKind gotch.DType
2020-06-13 00:52:34 +01:00
}
// Next implements Iterator interface
func (it *Iterable) Next() (item interface{}, ok bool) {
2020-07-11 07:42:27 +01:00
if it.Index == it.Len {
return nil, false
2020-07-11 07:42:27 +01:00
}
2020-06-13 00:52:34 +01:00
var err error
2023-07-06 15:01:23 +01:00
switch it.ItemKind.GoKind().String() {
2020-07-11 07:42:27 +01:00
case "int64":
item, err = it.Content.Int64Value([]int64{it.Index})
2020-06-13 00:52:34 +01:00
if err != nil {
log.Fatal(err)
}
it.Index += 1
2020-07-11 07:42:27 +01:00
case "float64":
item, err = it.Content.Float64Value([]int64{it.Index})
2020-06-13 00:52:34 +01:00
if err != nil {
log.Fatal(err)
}
it.Index += 1
default:
err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
log.Fatal(err)
}
return item, true
2020-06-13 00:52:34 +01:00
}
// Iter creates an iterable object with specified item type.
func (ts *Tensor) Iter(dtype gotch.DType) (*Iterable, error) {
2020-06-13 00:52:34 +01:00
num, err := ts.Size1() // size for 1D tensor
if err != nil {
return nil, err
2020-06-13 00:52:34 +01:00
}
2020-07-11 07:42:27 +01:00
tmp, err := ts.ShallowClone()
2020-06-13 00:52:34 +01:00
if err != nil {
return nil, err
2020-06-13 00:52:34 +01:00
}
2020-07-11 07:42:27 +01:00
content := tmp.MustTotype(dtype, true)
2020-06-13 00:52:34 +01:00
return &Iterable{
2020-06-13 00:52:34 +01:00
Index: 0,
Len: num,
Content: content,
2020-07-11 07:42:27 +01:00
ItemKind: dtype,
2020-06-13 00:52:34 +01:00
}, nil
}