feat(wrapper/index): completed

This commit is contained in:
sugarme 2020-06-13 00:21:59 +10:00
parent 2eaf926bc4
commit e188f4965b
4 changed files with 216 additions and 99 deletions

View File

@ -0,0 +1,49 @@
package main
import (
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
data := [][]int64{
{1, 1, 1, 2, 2, 2, 3, 3},
{1, 1, 1, 2, 2, 2, 4, 4},
}
shape := []int64{2, 8}
// shape := []int64{2, 2, 4}
ts, err := wrapper.NewTensorFromData(data, shape)
if err != nil {
panic(err)
}
ts.Print()
// Select
s := wrapper.NewSelect(7)
// selectedTs := ts.Idx(s)
// selectedTs.Print()
// Narrow (start inclusive, end exclusive)
n := wrapper.NewNarrow(0, 1)
// narrowedTs := ts.Idx(n)
// narrowedTs.Print()
// InsertNewAxis
// i := wrapper.NewInsertNewAxis()
// newAxisTs := ts.Idx(i)
// newAxisTs.Print()
// IndexSelect
// idxTensor := wrapper.MustOfSlice([]int64{0, 1})
// is := wrapper.NewIndexSelect(idxTensor)
// isTs := ts.Idx(is)
// isTs.Print()
// Combined
var tsIndexes []wrapper.TensorIndexer = []wrapper.TensorIndexer{n, s}
combinedTs := ts.Idx(tsIndexes)
combinedTs.Print()
}

View File

@ -67,3 +67,30 @@ func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
C.atg_totype(ptr, self, cscalar_type)
}
// void atg_unsqueeze(tensor *, tensor self, int64_t dim);
func AtgUnsqueeze(ptr *Ctensor, self Ctensor, dim int64) {
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
C.atg_unsqueeze(ptr, self, cdim)
}
// void atg_select(tensor *, tensor self, int64_t dim, int64_t index);
func AtgSelect(ptr *Ctensor, self Ctensor, dim int64, index int64) {
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
cindex := *(*C.int64_t)(unsafe.Pointer(&index))
C.atg_select(ptr, self, cdim, cindex)
}
// void atg_narrow(tensor *, tensor self, int64_t dim, int64_t start, int64_t length);
func AtgNarrow(ptr *Ctensor, self Ctensor, dim int64, start int64, length int64) {
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
cstart := *(*C.int64_t)(unsafe.Pointer(&start))
clength := *(*C.int64_t)(unsafe.Pointer(&length))
C.atg_narrow(ptr, self, cdim, cstart, clength)
}
// void atg_index_select(tensor *, tensor self, int64_t dim, tensor index);
func AtgIndexSelect(ptr *Ctensor, self Ctensor, dim int64, index Ctensor) {
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
C.atg_index_select(ptr, self, cdim, index)
}

View File

@ -70,55 +70,45 @@ type NewAxis struct{}
// TensorIndexer is an interface which defines method `From`
// for any type to fulfill to become an tensor indexer
type TensorIndexer interface {
From(interface{}) TensorIndexer
type TensorIndexer interface{}
type Select struct{ Index int64 }
type Narrow struct {
Start int64
End int64
}
type IndexSelect struct{ Index Tensor }
type InsertNewAxis struct{}
// NewSelect creates an tensor indexer with given index.
// `index` must be in range of tensor dimension. E.g. tensor shape [2,8]
// will have size = 2, hence `index` should be in range from [0,2)
func NewSelect(index int64) Select {
return Select{index}
}
// Below is a list of all types that implement `TensorIndexer`
// So that they can act as tensor indexer.
// type Select struct{}
// type Narrow struct {
// bound int64
// }
// type IndexSelect struct {
// tensor wrapper.Tensor
// }
// type InsertNewAxis struct{}
//
// // Implementing `TensorIndexer`
// func (sel Select) From(index interface{}) TensorIndexer {
// return sel.Select(index.(int64))
// }
// func (sel Select) new(index int64) Select {
// return Select{index: index}
// }
type SelectFn func(int64) TensorIndexer
type NarrowFn func(from int64, to int64) TensorIndexer
type IndexSelectFn func(ts Tensor) TensorIndexer
type InsertNewAxisFn func() TensorIndexer
func (sel SelectFn) From(index int64) TensorIndexer {
return sel(index)
func NewNarrow(start, end int64) Narrow {
return Narrow{Start: start, End: end}
}
// TODO: implement `TensorIndexer` for the rest
func NewIndexSelect(ts Tensor) IndexSelect {
return IndexSelect{Index: ts}
}
// NOTE: all the below variables will have `TensorIndexer` trait.
// In other words, they are `TensorIndexer` type.
func NewInsertNewAxis() InsertNewAxis {
return InsertNewAxis{}
}
// type SelectFn func(int64)
// type NarrowFn func(from int64, to int64)
// type IndexSelectFn func(ts Tensor)
// type InsertNewAxisFn func()
//
// Alternatively, we can create a enum-like of TensorIndexer using map.
// E.g. TensorIndexers = map[string]interface{}
// TensorIndexers["Select"] = SelectFn
// TensorIndexers["Narrow"] = NarrowFn
// TensorIndexers["IndexSelect"] = IndexSelectFn
// TensorIndexers["InsertNewAxis"] = InsertNewAxisFn
var (
Select SelectFn
Narrow NarrowFn
IndexSelect IndexSelectFn
InsertNewAxis InsertNewAxisFn
)
// var (
// // Select SelectFn
// // Narrow NarrowFn
// // IndexSelect IndexSelectFn
// // InsertNewAxis InsertNewAxisFn
// )
type IndexOp interface {
Idx(index interface{}) Tensor
@ -138,10 +128,10 @@ func (ts *Tensor) Idx(index interface{}) (retVal Tensor) {
var indexes []TensorIndexer
switch indexVal.Kind().String() { // TODO: double check whether it 'Interface' or 'TensorIndexer'???
case "TensorIndexer": // T: A
switch indexVal.Kind().String() {
case "struct": // T: A
indexes = append(indexes, index.(TensorIndexer))
case "Slice": // T: []TensorIndexer
case "slice": // T: []TensorIndexer
switch len(index.([]TensorIndexer)) {
case 1: // T: [A]
idxA := index.([]TensorIndexer)[0]
@ -218,37 +208,31 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// Make sure tensor conforms the format
for _, spec := range indexSpec {
// If `spec` is `IndexSelectFn` function and either
if reflect.ValueOf(spec).String() == "IndexSelectFn" {
// If `spec` is `IndexSelect` type and
if reflect.TypeOf(spec).Name() == "IndexSelect" {
if reflect.ValueOf(spec).Kind() == reflect.Struct {
inputTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(Tensor)
// 1. its input tensor has dimension > 1, throw error.
f, err := NewFunc(spec)
if err != nil {
err = fmt.Errorf("Indexer Func Error: %v\n", err)
return retVal, err
}
// list of `spec` function input parameters.
inArgs := f.Info().InArgs
tsVal := inArgs[0] // reflect.Value
inputTensor := reflect.ValueOf(tsVal).Interface().(Tensor)
inputTensorShape, err := inputTensor.Size()
if err != nil {
err = fmt.Errorf("Indexer Func Error: %v\n", err)
return retVal, err
}
if len(inputTensorShape) != 1 {
err = fmt.Errorf("Multi-dimenstional tensor is not supported for indexing.")
return retVal, err
}
// 1. Either its input tensor has dimension > 1, throw error.
inputTensorShape, err := inputTensor.Size()
if err != nil {
err = fmt.Errorf("Indexer Func Error: %v\n", err)
return retVal, err
}
if len(inputTensorShape) != 1 {
err = fmt.Errorf("Multi-dimenstional tensor is not supported for indexing.")
return retVal, err
}
// 2. its input tensor has an unsupported dtype
if inputTensor.DType() != gotch.Int64 ||
inputTensor.DType() != gotch.Int16 ||
inputTensor.DType() != gotch.Int8 ||
inputTensor.DType() != gotch.Int {
// 2. Or its input tensor has an unsupported dtype
if inputTensor.DType() != gotch.Int64 &&
inputTensor.DType() != gotch.Int16 &&
inputTensor.DType() != gotch.Int8 &&
inputTensor.DType() != gotch.Int {
err = fmt.Errorf("The dtype of tensor used as indices must be one of: 'int64', 'int16', 'int8', 'int'. \n")
return retVal, err
err = fmt.Errorf("The dtype of tensor used (%v) as indices must be one of: 'int64', 'int16', 'int8', 'int'. \n", inputTensor.DType())
return retVal, err
}
}
}
}
@ -257,53 +241,46 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
var (
currTensor Tensor = ts.MustShallowClone()
currIdx int64 = 0
nextTensor Tensor
nextIdx int64
)
// `spec` is a function type implements `TensorIndexer`
for _, spec := range indexSpec {
var (
nextTensor Tensor
nextIdx int64
)
// get info of `spec` function
f, err := NewFunc(spec)
if err != nil {
err = fmt.Errorf("Indexer Func Error: %v\n", err)
return retVal, err
}
// list of `spec` function input parameters.
inArgs := f.Info().InArgs
fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
// Now, specific indexOp depending on `TensorIndexer` func
switch reflect.ValueOf(spec).Kind().String() {
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
nextTensor, err = currTensor.Unsqueeze(currIdx)
if err != nil {
return retVal, err
}
nextIdx = currIdx + 1
case "SelectFn": // 1 param of `(index int64)`
indexVal := inArgs[0]
index := reflect.ValueOf(indexVal).Interface().(int64)
case "Select": // 1 field: `Index`
index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64)
nextTensor, err = currTensor.Select(currIdx, index) // TODO: double-check is `*index` or `index`
if err != nil {
return retVal, err
}
nextIdx = currIdx // not advanced because select() squeezes dimension
case "NarrowFn": // 2 params: `(start, end int64)`
case "Narrow": // 2 fields: `(Start, End int64)`
// TODO: implement for `Unbounded`, `Included`, `Excluded` ranges
// NOTE: for now, just implement (Included(start), Excluded(end))` case
start := reflect.ValueOf(inArgs[0]).Interface().(int64)
end := reflect.ValueOf(inArgs[1]).Interface().(int64)
start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64)
end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64)
nextTensor, err = currTensor.Narrow(currIdx, start, end-start)
if err != nil {
return retVal, err
}
nextIdx = currIdx + 1
case "IndexSelectFn": // 1 param `(indexTensor Tensor)`
indexTensor := reflect.ValueOf(inArgs[0]).Interface().(Tensor)
indexTensor, err = indexTensor.ToDevice(currTensor.Device())
case "IndexSelect": // 1 field `(Index Tensor)`
indexTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(Tensor)
device, err := currTensor.Device()
if err != nil {
return retVal, err
}
indexTensor, err = indexTensor.To(device)
if err != nil {
return retVal, err
}
@ -312,7 +289,9 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
return retVal, err
}
nextIdx = currIdx + 1
}
} // end of switch
currTensor.Print()
currTensor = nextTensor
currIdx = nextIdx
@ -320,7 +299,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
retVal = currTensor
return
return retVal, nil
}
func (ts Tensor) mustIndexer(indexSpec []TensorIndexer) (retVal Tensor) {

View File

@ -231,3 +231,65 @@ func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
return retVal
}
// Unsqueeze unsqueezes tensor to specified dimension.
func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgUnsqueeze(ptr, ts.ctensor, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
// Select creates a new tensor from current tensor given dim and index.
func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgSelect(ptr, ts.ctensor, dim, index)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
// Narrow creates a new tensor from current tensor given dim and start index
// and length.
func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgNarrow(ptr, ts.ctensor, dim, start, length)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
// IndexSelect creates a new tensor from current tensor given dim and index
// tensor.
func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgIndexSelect(ptr, ts.ctensor, dim, index.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}