feat(wrapper/index): completed
This commit is contained in:
parent
2eaf926bc4
commit
e188f4965b
49
example/tensor-index/main.go
Normal file
49
example/tensor-index/main.go
Normal file
|
@ -0,0 +1,49 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
shape := []int64{2, 8}
|
||||
// shape := []int64{2, 2, 4}
|
||||
|
||||
ts, err := wrapper.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ts.Print()
|
||||
|
||||
// Select
|
||||
s := wrapper.NewSelect(7)
|
||||
// selectedTs := ts.Idx(s)
|
||||
// selectedTs.Print()
|
||||
|
||||
// Narrow (start inclusive, end exclusive)
|
||||
n := wrapper.NewNarrow(0, 1)
|
||||
// narrowedTs := ts.Idx(n)
|
||||
// narrowedTs.Print()
|
||||
|
||||
// InsertNewAxis
|
||||
// i := wrapper.NewInsertNewAxis()
|
||||
// newAxisTs := ts.Idx(i)
|
||||
// newAxisTs.Print()
|
||||
|
||||
// IndexSelect
|
||||
// idxTensor := wrapper.MustOfSlice([]int64{0, 1})
|
||||
// is := wrapper.NewIndexSelect(idxTensor)
|
||||
// isTs := ts.Idx(is)
|
||||
// isTs.Print()
|
||||
|
||||
// Combined
|
||||
var tsIndexes []wrapper.TensorIndexer = []wrapper.TensorIndexer{n, s}
|
||||
combinedTs := ts.Idx(tsIndexes)
|
||||
|
||||
combinedTs.Print()
|
||||
|
||||
}
|
|
@ -67,3 +67,30 @@ func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
|||
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
||||
C.atg_totype(ptr, self, cscalar_type)
|
||||
}
|
||||
|
||||
// void atg_unsqueeze(tensor *, tensor self, int64_t dim);
|
||||
func AtgUnsqueeze(ptr *Ctensor, self Ctensor, dim int64) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
C.atg_unsqueeze(ptr, self, cdim)
|
||||
}
|
||||
|
||||
// void atg_select(tensor *, tensor self, int64_t dim, int64_t index);
|
||||
func AtgSelect(ptr *Ctensor, self Ctensor, dim int64, index int64) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
cindex := *(*C.int64_t)(unsafe.Pointer(&index))
|
||||
C.atg_select(ptr, self, cdim, cindex)
|
||||
}
|
||||
|
||||
// void atg_narrow(tensor *, tensor self, int64_t dim, int64_t start, int64_t length);
|
||||
func AtgNarrow(ptr *Ctensor, self Ctensor, dim int64, start int64, length int64) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
cstart := *(*C.int64_t)(unsafe.Pointer(&start))
|
||||
clength := *(*C.int64_t)(unsafe.Pointer(&length))
|
||||
C.atg_narrow(ptr, self, cdim, cstart, clength)
|
||||
}
|
||||
|
||||
// void atg_index_select(tensor *, tensor self, int64_t dim, tensor index);
|
||||
func AtgIndexSelect(ptr *Ctensor, self Ctensor, dim int64, index Ctensor) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
C.atg_index_select(ptr, self, cdim, index)
|
||||
}
|
||||
|
|
177
wrapper/index.go
177
wrapper/index.go
|
@ -70,55 +70,45 @@ type NewAxis struct{}
|
|||
|
||||
// TensorIndexer is an interface which defines method `From`
|
||||
// for any type to fulfill to become an tensor indexer
|
||||
type TensorIndexer interface {
|
||||
From(interface{}) TensorIndexer
|
||||
type TensorIndexer interface{}
|
||||
type Select struct{ Index int64 }
|
||||
type Narrow struct {
|
||||
Start int64
|
||||
End int64
|
||||
}
|
||||
type IndexSelect struct{ Index Tensor }
|
||||
type InsertNewAxis struct{}
|
||||
|
||||
// NewSelect creates an tensor indexer with given index.
|
||||
// `index` must be in range of tensor dimension. E.g. tensor shape [2,8]
|
||||
// will have size = 2, hence `index` should be in range from [0,2)
|
||||
func NewSelect(index int64) Select {
|
||||
return Select{index}
|
||||
}
|
||||
|
||||
// Below is a list of all types that implement `TensorIndexer`
|
||||
// So that they can act as tensor indexer.
|
||||
// type Select struct{}
|
||||
// type Narrow struct {
|
||||
// bound int64
|
||||
// }
|
||||
// type IndexSelect struct {
|
||||
// tensor wrapper.Tensor
|
||||
// }
|
||||
// type InsertNewAxis struct{}
|
||||
//
|
||||
// // Implementing `TensorIndexer`
|
||||
// func (sel Select) From(index interface{}) TensorIndexer {
|
||||
// return sel.Select(index.(int64))
|
||||
// }
|
||||
// func (sel Select) new(index int64) Select {
|
||||
// return Select{index: index}
|
||||
// }
|
||||
|
||||
type SelectFn func(int64) TensorIndexer
|
||||
type NarrowFn func(from int64, to int64) TensorIndexer
|
||||
type IndexSelectFn func(ts Tensor) TensorIndexer
|
||||
type InsertNewAxisFn func() TensorIndexer
|
||||
|
||||
func (sel SelectFn) From(index int64) TensorIndexer {
|
||||
return sel(index)
|
||||
func NewNarrow(start, end int64) Narrow {
|
||||
return Narrow{Start: start, End: end}
|
||||
}
|
||||
|
||||
// TODO: implement `TensorIndexer` for the rest
|
||||
func NewIndexSelect(ts Tensor) IndexSelect {
|
||||
return IndexSelect{Index: ts}
|
||||
}
|
||||
|
||||
// NOTE: all the below variables will have `TensorIndexer` trait.
|
||||
// In other words, they are `TensorIndexer` type.
|
||||
func NewInsertNewAxis() InsertNewAxis {
|
||||
return InsertNewAxis{}
|
||||
}
|
||||
|
||||
// type SelectFn func(int64)
|
||||
// type NarrowFn func(from int64, to int64)
|
||||
// type IndexSelectFn func(ts Tensor)
|
||||
// type InsertNewAxisFn func()
|
||||
//
|
||||
// Alternatively, we can create a enum-like of TensorIndexer using map.
|
||||
// E.g. TensorIndexers = map[string]interface{}
|
||||
// TensorIndexers["Select"] = SelectFn
|
||||
// TensorIndexers["Narrow"] = NarrowFn
|
||||
// TensorIndexers["IndexSelect"] = IndexSelectFn
|
||||
// TensorIndexers["InsertNewAxis"] = InsertNewAxisFn
|
||||
var (
|
||||
Select SelectFn
|
||||
Narrow NarrowFn
|
||||
IndexSelect IndexSelectFn
|
||||
InsertNewAxis InsertNewAxisFn
|
||||
)
|
||||
// var (
|
||||
// // Select SelectFn
|
||||
// // Narrow NarrowFn
|
||||
// // IndexSelect IndexSelectFn
|
||||
// // InsertNewAxis InsertNewAxisFn
|
||||
// )
|
||||
|
||||
type IndexOp interface {
|
||||
Idx(index interface{}) Tensor
|
||||
|
@ -138,10 +128,10 @@ func (ts *Tensor) Idx(index interface{}) (retVal Tensor) {
|
|||
|
||||
var indexes []TensorIndexer
|
||||
|
||||
switch indexVal.Kind().String() { // TODO: double check whether it 'Interface' or 'TensorIndexer'???
|
||||
case "TensorIndexer": // T: A
|
||||
switch indexVal.Kind().String() {
|
||||
case "struct": // T: A
|
||||
indexes = append(indexes, index.(TensorIndexer))
|
||||
case "Slice": // T: []TensorIndexer
|
||||
case "slice": // T: []TensorIndexer
|
||||
switch len(index.([]TensorIndexer)) {
|
||||
case 1: // T: [A]
|
||||
idxA := index.([]TensorIndexer)[0]
|
||||
|
@ -218,37 +208,31 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
|
||||
// Make sure tensor conforms the format
|
||||
for _, spec := range indexSpec {
|
||||
// If `spec` is `IndexSelectFn` function and either
|
||||
if reflect.ValueOf(spec).String() == "IndexSelectFn" {
|
||||
// If `spec` is `IndexSelect` type and
|
||||
if reflect.TypeOf(spec).Name() == "IndexSelect" {
|
||||
if reflect.ValueOf(spec).Kind() == reflect.Struct {
|
||||
inputTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(Tensor)
|
||||
|
||||
// 1. its input tensor has dimension > 1, throw error.
|
||||
f, err := NewFunc(spec)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Indexer Func Error: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
// list of `spec` function input parameters.
|
||||
inArgs := f.Info().InArgs
|
||||
tsVal := inArgs[0] // reflect.Value
|
||||
inputTensor := reflect.ValueOf(tsVal).Interface().(Tensor)
|
||||
inputTensorShape, err := inputTensor.Size()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Indexer Func Error: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
if len(inputTensorShape) != 1 {
|
||||
err = fmt.Errorf("Multi-dimenstional tensor is not supported for indexing.")
|
||||
return retVal, err
|
||||
}
|
||||
// 1. Either its input tensor has dimension > 1, throw error.
|
||||
inputTensorShape, err := inputTensor.Size()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Indexer Func Error: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
if len(inputTensorShape) != 1 {
|
||||
err = fmt.Errorf("Multi-dimenstional tensor is not supported for indexing.")
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// 2. its input tensor has an unsupported dtype
|
||||
if inputTensor.DType() != gotch.Int64 ||
|
||||
inputTensor.DType() != gotch.Int16 ||
|
||||
inputTensor.DType() != gotch.Int8 ||
|
||||
inputTensor.DType() != gotch.Int {
|
||||
// 2. Or its input tensor has an unsupported dtype
|
||||
if inputTensor.DType() != gotch.Int64 &&
|
||||
inputTensor.DType() != gotch.Int16 &&
|
||||
inputTensor.DType() != gotch.Int8 &&
|
||||
inputTensor.DType() != gotch.Int {
|
||||
|
||||
err = fmt.Errorf("The dtype of tensor used as indices must be one of: 'int64', 'int16', 'int8', 'int'. \n")
|
||||
return retVal, err
|
||||
err = fmt.Errorf("The dtype of tensor used (%v) as indices must be one of: 'int64', 'int16', 'int8', 'int'. \n", inputTensor.DType())
|
||||
return retVal, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -257,53 +241,46 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
var (
|
||||
currTensor Tensor = ts.MustShallowClone()
|
||||
currIdx int64 = 0
|
||||
nextTensor Tensor
|
||||
nextIdx int64
|
||||
)
|
||||
|
||||
// `spec` is a function type implements `TensorIndexer`
|
||||
for _, spec := range indexSpec {
|
||||
var (
|
||||
nextTensor Tensor
|
||||
nextIdx int64
|
||||
)
|
||||
|
||||
// get info of `spec` function
|
||||
f, err := NewFunc(spec)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Indexer Func Error: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
// list of `spec` function input parameters.
|
||||
inArgs := f.Info().InArgs
|
||||
fmt.Printf("spec type: %v\n", reflect.TypeOf(spec).Name())
|
||||
|
||||
// Now, specific indexOp depending on `TensorIndexer` func
|
||||
switch reflect.ValueOf(spec).Kind().String() {
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx + 1
|
||||
case "SelectFn": // 1 param of `(index int64)`
|
||||
indexVal := inArgs[0]
|
||||
index := reflect.ValueOf(indexVal).Interface().(int64)
|
||||
case "Select": // 1 field: `Index`
|
||||
index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64)
|
||||
nextTensor, err = currTensor.Select(currIdx, index) // TODO: double-check is `*index` or `index`
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx // not advanced because select() squeezes dimension
|
||||
case "NarrowFn": // 2 params: `(start, end int64)`
|
||||
case "Narrow": // 2 fields: `(Start, End int64)`
|
||||
// TODO: implement for `Unbounded`, `Included`, `Excluded` ranges
|
||||
// NOTE: for now, just implement (Included(start), Excluded(end))` case
|
||||
start := reflect.ValueOf(inArgs[0]).Interface().(int64)
|
||||
end := reflect.ValueOf(inArgs[1]).Interface().(int64)
|
||||
start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64)
|
||||
end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64)
|
||||
nextTensor, err = currTensor.Narrow(currIdx, start, end-start)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx + 1
|
||||
case "IndexSelectFn": // 1 param `(indexTensor Tensor)`
|
||||
indexTensor := reflect.ValueOf(inArgs[0]).Interface().(Tensor)
|
||||
indexTensor, err = indexTensor.ToDevice(currTensor.Device())
|
||||
case "IndexSelect": // 1 field `(Index Tensor)`
|
||||
indexTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(Tensor)
|
||||
device, err := currTensor.Device()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
indexTensor, err = indexTensor.To(device)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -312,7 +289,9 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx + 1
|
||||
}
|
||||
} // end of switch
|
||||
|
||||
currTensor.Print()
|
||||
|
||||
currTensor = nextTensor
|
||||
currIdx = nextIdx
|
||||
|
@ -320,7 +299,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
|
||||
retVal = currTensor
|
||||
|
||||
return
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) mustIndexer(indexSpec []TensorIndexer) (retVal Tensor) {
|
||||
|
|
|
@ -231,3 +231,65 @@ func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Unsqueeze unsqueezes tensor to specified dimension.
|
||||
func (ts Tensor) Unsqueeze(dim int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgUnsqueeze(ptr, ts.ctensor, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Select creates a new tensor from current tensor given dim and index.
|
||||
func (ts Tensor) Select(dim int64, index int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSelect(ptr, ts.ctensor, dim, index)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Narrow creates a new tensor from current tensor given dim and start index
|
||||
// and length.
|
||||
func (ts Tensor) Narrow(dim int64, start int64, length int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgNarrow(ptr, ts.ctensor, dim, start, length)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// IndexSelect creates a new tensor from current tensor given dim and index
|
||||
// tensor.
|
||||
func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgIndexSelect(ptr, ts.ctensor, dim, index.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user