feat(wrapper/tensor): iter
This commit is contained in:
parent
e188f4965b
commit
49e0335469
32
example/tensor-iterator/main.go
Normal file
32
example/tensor-iterator/main.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
wrapper "github.com/sugarme/gotch/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
data := [][]int64{
|
||||||
|
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||||
|
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||||
|
}
|
||||||
|
shape := []int64{16}
|
||||||
|
|
||||||
|
ts, err := wrapper.NewTensorFromData(data, shape)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
it, err := ts.Iter(reflect.Float64)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < int(it.Len); i++ {
|
||||||
|
v := it.Next()
|
||||||
|
fmt.Println(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
61
wrapper/iter.go
Normal file
61
wrapper/iter.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package wrapper
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Iterator interface {
|
||||||
|
Next() interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Iterable struct {
|
||||||
|
Index int64
|
||||||
|
Len int64
|
||||||
|
Content Tensor
|
||||||
|
ItemKind reflect.Kind
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next implements Iterator interface
|
||||||
|
func (it *Iterable) Next() (retVal interface{}) {
|
||||||
|
var err error
|
||||||
|
switch it.ItemKind {
|
||||||
|
case reflect.Int64:
|
||||||
|
retVal, err = it.Content.Int64Value([]int64{it.Index})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
it.Index += 1
|
||||||
|
case reflect.Float64:
|
||||||
|
retVal, err = it.Content.Float64Value([]int64{it.Index})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
it.Index += 1
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iter creates an iterable object with specified item type.
|
||||||
|
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
|
||||||
|
num, err := ts.Size1() // size for 1D tensor
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
content, err := ts.ShallowClone()
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Iterable{
|
||||||
|
Index: 0,
|
||||||
|
Len: num,
|
||||||
|
Content: content,
|
||||||
|
ItemKind: kind,
|
||||||
|
}, nil
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user