2020-06-13 02:19:01 +01:00
|
|
|
package tensor
|
2020-06-13 00:52:34 +01:00
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
|
|
|
"log"
|
2020-07-11 07:42:27 +01:00
|
|
|
|
|
|
|
"github.com/sugarme/gotch"
|
2020-06-13 00:52:34 +01:00
|
|
|
)
|
|
|
|
|
|
|
|
// Iterator is implemented by types that hand out a sequence of items one at a
// time.
//
// Each call to Next yields the following item with more set to true. When no
// items remain, more is false. Iterable.Next, for example, returns a nil item
// together with false at that point.
type Iterator interface {
	Next() (value interface{}, more bool)
}
|
|
|
|
|
|
|
|
type Iterable struct {
|
|
|
|
Index int64
|
|
|
|
Len int64
|
|
|
|
Content Tensor
|
2020-07-11 07:42:27 +01:00
|
|
|
ItemKind gotch.DType
|
2020-06-13 00:52:34 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Next implements Iterator interface
|
2020-07-11 07:42:27 +01:00
|
|
|
func (it *Iterable) Next() (retVal interface{}, ok bool) {
|
|
|
|
|
|
|
|
if it.Index == it.Len {
|
|
|
|
return retVal, false
|
|
|
|
}
|
|
|
|
|
2020-06-13 00:52:34 +01:00
|
|
|
var err error
|
2020-07-11 07:42:27 +01:00
|
|
|
switch it.ItemKind.Kind().String() {
|
|
|
|
case "int64":
|
2020-06-13 00:52:34 +01:00
|
|
|
retVal, err = it.Content.Int64Value([]int64{it.Index})
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal(err)
|
|
|
|
}
|
|
|
|
it.Index += 1
|
2020-07-11 07:42:27 +01:00
|
|
|
case "float64":
|
2020-06-13 00:52:34 +01:00
|
|
|
retVal, err = it.Content.Float64Value([]int64{it.Index})
|
|
|
|
if err != nil {
|
|
|
|
log.Fatal(err)
|
|
|
|
}
|
|
|
|
it.Index += 1
|
|
|
|
default:
|
|
|
|
err := fmt.Errorf("Iterator error: unsupported item kind (%v).\n", it.ItemKind)
|
|
|
|
log.Fatal(err)
|
|
|
|
}
|
|
|
|
|
2020-07-11 07:42:27 +01:00
|
|
|
return retVal, true
|
2020-06-13 00:52:34 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// Iter creates an iterable object with specified item type.
|
2020-07-11 07:42:27 +01:00
|
|
|
func (ts Tensor) Iter(dtype gotch.DType) (retVal Iterable, err error) {
|
2020-06-13 00:52:34 +01:00
|
|
|
num, err := ts.Size1() // size for 1D tensor
|
|
|
|
if err != nil {
|
|
|
|
return retVal, err
|
|
|
|
}
|
2020-07-11 07:42:27 +01:00
|
|
|
tmp, err := ts.ShallowClone()
|
2020-06-13 00:52:34 +01:00
|
|
|
if err != nil {
|
|
|
|
return retVal, err
|
|
|
|
}
|
2020-07-11 07:42:27 +01:00
|
|
|
content := tmp.MustTotype(dtype, true)
|
2020-06-13 00:52:34 +01:00
|
|
|
|
|
|
|
return Iterable{
|
|
|
|
Index: 0,
|
|
|
|
Len: num,
|
|
|
|
Content: content,
|
2020-07-11 07:42:27 +01:00
|
|
|
ItemKind: dtype,
|
2020-06-13 00:52:34 +01:00
|
|
|
}, nil
|
|
|
|
}
|