WIP(tensor/index): added tensor/index.go
This commit is contained in:
parent
a3b965b9c5
commit
6f0ad33578
|
@ -55,4 +55,16 @@ func main() {
|
||||||
|
|
||||||
fmt.Printf("Tensor String: \n%v\n", tsString)
|
fmt.Printf("Tensor String: \n%v\n", tsString)
|
||||||
|
|
||||||
|
imagePath := "mnist-sample.png"
|
||||||
|
|
||||||
|
imageTs, err := wrapper.LoadHwc(imagePath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = imageTs.Save("mnist-tensor-saved.png")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
BIN
example/tensor-io/mnist-sample.png
Normal file
BIN
example/tensor-io/mnist-sample.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 7.7 KiB |
BIN
example/tensor-io/mnist-tensor-saved.png
Normal file
BIN
example/tensor-io/mnist-tensor-saved.png
Normal file
Binary file not shown.
274
tensor/index.go
Normal file
274
tensor/index.go
Normal file
|
@ -0,0 +1,274 @@
|
||||||
|
package tensor
|
||||||
|
|
||||||
|
// Indexing operations for tensor
|
||||||
|
// It defines a `i` indexing operation. This can be used in various scenarios.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
// Using an integer index returns slice obtained by selecting elements with
|
||||||
|
// specified index. Negative values can be used for the index, and `..` can
|
||||||
|
// be used to get all the indexes from a given dimension.
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// ts := wrapper.OfSlice(int[1,2,3,4,5,6])
|
||||||
|
// ts.View((2,3))
|
||||||
|
// t := ts.i(1)
|
||||||
|
// t = ts.i(.., -2)
|
||||||
|
//```
|
||||||
|
//
|
||||||
|
// Indexes like `..1`, `1..` or `1..2` can be used to narrow a dimension.
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// ts := wrapper.OfSlice(int[1,2,3,4,5,6])
|
||||||
|
// ts.View((2,3))
|
||||||
|
// t := ts.i((..,1..))
|
||||||
|
// t.Size() // [2,2]
|
||||||
|
// t = t.Contiguous()
|
||||||
|
// tsSlice := t.View(-1) // [2,3,5,6]
|
||||||
|
// t := ts.i((..1, ..))
|
||||||
|
// t.Size() // [1,3]
|
||||||
|
// t = t.Contiguous()
|
||||||
|
// t.View(-1) // [1,2,3]
|
||||||
|
// t = ts.i((.., 1..2))
|
||||||
|
// t.Size() // [2,1]
|
||||||
|
// t = t.Contiguous()
|
||||||
|
// t = t.View(-1) // [2,5]
|
||||||
|
// t = ts.i((.., 1..=2))
|
||||||
|
// t.Size() // [2,2]
|
||||||
|
// t = t.Contiguous()
|
||||||
|
// t.View(-1) // [2,3,5,6]
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// `NewAxis` index can be used to insert a dimension.
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// ts := wrapper.OfSlice(int[1,2,3,4,5,6])
|
||||||
|
// ts.View((2,3))
|
||||||
|
// t := ts.i((NewAxis,))
|
||||||
|
// t.Size() // [1,2,3]
|
||||||
|
// t = ts.i((..,..,NewAxis))
|
||||||
|
// t.Size() // [2,3,1]
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// Unlike NumPy, the `i` operation does not support advanced indexing.
|
||||||
|
// The result can be different from NumPy with same set of arguments.
|
||||||
|
// For example, `tensor.i(..1, []int{0,3}, []int{2,1,3})` does narrowing
|
||||||
|
// on first dimension, and index selection on second and third dimensions.
|
||||||
|
// The analogous NumPy indexing `array[:1, [0, 3], [2, 1, 3]]` throws
|
||||||
|
// shape mismatch error due to advanced indexing rule. Another distinction
|
||||||
|
// is that `i` guarantees the input and result tensor shares the same
|
||||||
|
// underlying storage, while NumPy may copy the tensor in certain scenarios.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Tensor type `super` wrapper.Tensor
|
||||||
|
type Tensor struct {
|
||||||
|
wrapper.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
type NewAxis struct{}
|
||||||
|
|
||||||
|
// TensorIndexer is an interface which defines method `From`
|
||||||
|
// for any type to fulfill to become an tensor indexer
|
||||||
|
type TensorIndexer interface {
|
||||||
|
From(interface{}) TensorIndexer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Below is a list of all types that implement `TensorIndexer`
|
||||||
|
// So that they can act as tensor indexer.
|
||||||
|
// type Select struct{}
|
||||||
|
// type Narrow struct {
|
||||||
|
// bound int64
|
||||||
|
// }
|
||||||
|
// type IndexSelect struct {
|
||||||
|
// tensor wrapper.Tensor
|
||||||
|
// }
|
||||||
|
// type InsertNewAxis struct{}
|
||||||
|
//
|
||||||
|
// // Implementing `TensorIndexer`
|
||||||
|
// func (sel Select) From(index interface{}) TensorIndexer {
|
||||||
|
// return sel.Select(index.(int64))
|
||||||
|
// }
|
||||||
|
// func (sel Select) new(index int64) Select {
|
||||||
|
// return Select{index: index}
|
||||||
|
// }
|
||||||
|
|
||||||
|
type SelectFn func(int64) TensorIndexer
|
||||||
|
type NarrowFn func(from int64, to int64) TensorIndexer
|
||||||
|
type IndexSelectFn func(ts Tensor) TensorIndexer
|
||||||
|
type InsertNewAxisFn func() TensorIndexer
|
||||||
|
|
||||||
|
func (sel SelectFn) From(index int64) TensorIndexer {
|
||||||
|
return sel(index)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: implement `TensorIndexer` for the rest
|
||||||
|
|
||||||
|
// NOTE: all the below variables will have `TensorIndexer` trait.
|
||||||
|
// In other words, they are `TensorIndexer` type.
|
||||||
|
//
|
||||||
|
// Alternatively, we can create a enum-like of TensorIndexer using map.
|
||||||
|
// E.g. TensorIndexers = map[string]interface{}
|
||||||
|
// TensorIndexers["Select"] = SelectFn
|
||||||
|
// TensorIndexers["Narrow"] = NarrowFn
|
||||||
|
// TensorIndexers["IndexSelect"] = IndexSelectFn
|
||||||
|
// TensorIndexers["InsertNewAxis"] = InsertNewAxisFn
|
||||||
|
var (
|
||||||
|
Select SelectFn
|
||||||
|
Narrow NarrowFn
|
||||||
|
IndexSelect IndexSelectFn
|
||||||
|
InsertNewAxis InsertNewAxisFn
|
||||||
|
)
|
||||||
|
|
||||||
|
type IndexOp interface {
|
||||||
|
Idx(index interface{}) Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// implement IndexOp for Tensor:
|
||||||
|
// =============================
|
||||||
|
|
||||||
|
// Idx implements `IndexOp` interface for Tensor
|
||||||
|
//
|
||||||
|
// NOTE:
|
||||||
|
// - `index`: expects type `TensorIndexer` or `[]TensorIndexer`
|
||||||
|
func (ts *Tensor) Idx(index interface{}) (retVal Tensor) {
|
||||||
|
|
||||||
|
// indexTyp := reflect.TypeOf(index)
|
||||||
|
indexVal := reflect.ValueOf(index)
|
||||||
|
|
||||||
|
var indexes []TensorIndexer
|
||||||
|
|
||||||
|
switch indexVal.Kind().String() { // TODO: double check whether it 'Interface' or 'TensorIndexer'???
|
||||||
|
case "TensorIndexer": // T: A
|
||||||
|
indexes = append(indexes, index.(TensorIndexer))
|
||||||
|
case "Slice": // T: []TensorIndexer
|
||||||
|
switch len(index.([]TensorIndexer)) {
|
||||||
|
case 1: // T: [A]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
indexes = append(indexes, idxA)
|
||||||
|
case 2: // T: [A, B]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
idxB := index.([]TensorIndexer)[1]
|
||||||
|
indexes = append(indexes, idxA, idxB)
|
||||||
|
case 3: // T: [A, B, C]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
idxB := index.([]TensorIndexer)[1]
|
||||||
|
idxC := index.([]TensorIndexer)[2]
|
||||||
|
indexes = append(indexes, idxA, idxB, idxC)
|
||||||
|
case 4: // T: [A, B, C, D]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
idxB := index.([]TensorIndexer)[1]
|
||||||
|
idxC := index.([]TensorIndexer)[2]
|
||||||
|
idxD := index.([]TensorIndexer)[3]
|
||||||
|
indexes = append(indexes, idxA, idxB, idxC, idxD)
|
||||||
|
case 5: // T: [A, B, C, D, E]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
idxB := index.([]TensorIndexer)[1]
|
||||||
|
idxC := index.([]TensorIndexer)[2]
|
||||||
|
idxD := index.([]TensorIndexer)[3]
|
||||||
|
idxE := index.([]TensorIndexer)[4]
|
||||||
|
indexes = append(indexes, idxA, idxB, idxC, idxD, idxE)
|
||||||
|
case 6: // T: [A, B, C, D, E, F]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
idxB := index.([]TensorIndexer)[1]
|
||||||
|
idxC := index.([]TensorIndexer)[2]
|
||||||
|
idxD := index.([]TensorIndexer)[3]
|
||||||
|
idxE := index.([]TensorIndexer)[4]
|
||||||
|
idxF := index.([]TensorIndexer)[5]
|
||||||
|
indexes = append(indexes, idxA, idxB, idxC, idxD, idxE, idxF)
|
||||||
|
case 7: // T: [A, B, C, D, E, F, G]
|
||||||
|
idxA := index.([]TensorIndexer)[0]
|
||||||
|
idxB := index.([]TensorIndexer)[1]
|
||||||
|
idxC := index.([]TensorIndexer)[2]
|
||||||
|
idxD := index.([]TensorIndexer)[3]
|
||||||
|
idxE := index.([]TensorIndexer)[4]
|
||||||
|
idxF := index.([]TensorIndexer)[5]
|
||||||
|
idxG := index.([]TensorIndexer)[6]
|
||||||
|
indexes = append(indexes, idxA, idxB, idxC, idxD, idxE, idxF, idxG)
|
||||||
|
default:
|
||||||
|
log.Fatalf("Invalid input 'index' slice length (%v) - max is 7\n", len(index.([]TensorIndexer)))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Fatalf("Invalid 'index' type (%v) - Expected type 'TensorIndexer' or '[]TensorIndexer'\n.", indexVal.Kind().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return ts.mustIndexer(indexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tensor Methods:
|
||||||
|
// ===============
|
||||||
|
func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
// Make sure number of non-newaxis is not exceed number of dimensions
|
||||||
|
var nonNewAxis []TensorIndexer
|
||||||
|
for _, ti := range indexSpec {
|
||||||
|
if reflect.ValueOf(ti).String() != "InsertNewAxis" {
|
||||||
|
nonNewAxis = append(nonNewAxis, ti)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tsShape, err := ts.Size()
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
tsLen := len(tsShape)
|
||||||
|
if len(nonNewAxis) > tsLen {
|
||||||
|
err = fmt.Errorf("Too many indices for tensor of dimension %v\n", tsLen)
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure tensor conforms the format
|
||||||
|
for _, spec := range indexSpec {
|
||||||
|
// If `spec` is `IndexSelectFn` function and either
|
||||||
|
if reflect.ValueOf(spec).String() == "IndexSelectFn" {
|
||||||
|
|
||||||
|
// 1. its input tensor has dimension > 1, throw error.
|
||||||
|
f, err := wrapper.NewFunc(spec)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Indexer Func Error: %v\n", err)
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
// list of `spec` function input parameters.
|
||||||
|
inArgs := f.Info().InArgs
|
||||||
|
tsVal := inArgs[0] // reflect.Value
|
||||||
|
inputTensor := reflect.ValueOf(tsVal).Interface().(Tensor)
|
||||||
|
inputTensorShape, err := inputTensor.Size()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Indexer Func Error: %v\n", err)
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
if len(inputTensorShape) != 1 {
|
||||||
|
err = fmt.Errorf("Multi-dimenstional tensor is not supported for indexing.")
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. its input tensor has an unsupported dtype
|
||||||
|
if inputTensor.DType() != gotch.Int64 ||
|
||||||
|
inputTensor.DType() != gotch.Int16 ||
|
||||||
|
inputTensor.DType() != gotch.Int8 ||
|
||||||
|
inputTensor.DType() != gotch.Int {
|
||||||
|
|
||||||
|
err = fmt.Errorf("The dtype of tensor used as indices must be one of: 'int64', 'int16', 'int8', 'int'. \n")
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, apply indexing from left to right.
|
||||||
|
// TODO: implement it
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *Tensor) mustIndexer(indexSpec []TensorIndexer) (retVal Tensor) {
|
||||||
|
retVal, err := ts.indexer(indexSpec)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user