feat(wrapper/tensor): iter
This commit is contained in:
parent
e188f4965b
commit
49e0335469
32
example/tensor-iterator/main.go
Normal file
32
example/tensor-iterator/main.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
shape := []int64{16}
|
||||
|
||||
ts, err := wrapper.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
it, err := ts.Iter(reflect.Float64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i := 0; i < int(it.Len); i++ {
|
||||
v := it.Next()
|
||||
fmt.Println(v)
|
||||
}
|
||||
|
||||
}
|
61
wrapper/iter.go
Normal file
61
wrapper/iter.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package wrapper
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Iterator interface {
|
||||
Next() interface{}
|
||||
}
|
||||
|
||||
type Iterable struct {
|
||||
Index int64
|
||||
Len int64
|
||||
Content Tensor
|
||||
ItemKind reflect.Kind
|
||||
}
|
||||
|
||||
// Next implements Iterator interface
|
||||
func (it *Iterable) Next() (retVal interface{}) {
|
||||
var err error
|
||||
switch it.ItemKind {
|
||||
case reflect.Int64:
|
||||
retVal, err = it.Content.Int64Value([]int64{it.Index})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
it.Index += 1
|
||||
case reflect.Float64:
|
||||
retVal, err = it.Content.Float64Value([]int64{it.Index})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
it.Index += 1
|
||||
default:
|
||||
err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Iter creates an iterable object with specified item type.
|
||||
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
|
||||
num, err := ts.Size1() // size for 1D tensor
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
content, err := ts.ShallowClone()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Iterable{
|
||||
Index: 0,
|
||||
Len: num,
|
||||
Content: content,
|
||||
ItemKind: kind,
|
||||
}, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user