fix(tensor/index): fixed incorrect counting of number of NewAxis

This commit is contained in:
sugarme 2020-07-10 11:10:26 +10:00
parent 22333d4544
commit 94c1af9671
4 changed files with 174 additions and 18 deletions

View File

@ -1,27 +1,25 @@
package main
import (
// "github.com/sugarme/gotch"
"fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
func main() {
data := [][]int64{
{1, 1, 1, 2, 2, 2, 3, 3},
{1, 1, 1, 2, 2, 2, 4, 4},
}
// shape := []int64{2, 8}
shape := []int64{2, 2, 4}
t, err := ts.NewTensorFromData(data, shape)
if err != nil {
panic(err)
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
var idxs []ts.TensorIndexer = []ts.TensorIndexer{
// ts.NewNarrow(0, tensor.MustSize()[0]),
// ts.NewNarrow(0, tensor.MustSize()[1]),
ts.NewInsertNewAxis(),
}
t.Print()
result := tensor.Idx(idxs)
idx := ts.NewNarrow(0, 3)
fmt.Printf("Original Ts shape: %v\n", tensor.MustSize())
fmt.Printf("Result Ts shape: %v\n", result.MustSize())
selTs := t.Idx(idx)
selTs.Print()
}

View File

@ -642,3 +642,39 @@ func AtgReflectionPad2d(ptr *Ctensor, self Ctensor, paddingData []int64, padding
C.atg_reflection_pad2d(ptr, self, cpaddingDataPtr, cpaddingLen)
}
// void atg_arange(tensor *, scalar end, int options_kind, int options_device);
func AtgArange(ptr *Ctensor, end Cscalar, optionsKind int32, optionsDevice int32) {
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
C.atg_arange(ptr, end, coptionsKind, coptionsDevice)
}
// void atg_arange1(tensor *, scalar start, scalar end, int options_kind, int options_device);
func AtgArange1(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32) {
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
C.atg_arange1(ptr, start, end, coptionsKind, coptionsDevice)
}
// void atg_arange2(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
func AtgArange2(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32) {
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
C.atg_arange2(ptr, start, end, step, coptionsKind, coptionsDevice)
}
// void atg_arange_out(tensor *, tensor out, scalar end);
func AtgArangeOut(ptr *Ctensor, out Ctensor, end Cscalar) {
C.atg_arange_out(ptr, out, end)
}
// void atg_arange_out1(tensor *, tensor out, scalar start, scalar end);
func AtgArangeOut1(ptr *Ctensor, out Ctensor, start Cscalar, end Cscalar) {
C.atg_arange_out1(ptr, out, start, end)
}

View File

@ -193,18 +193,19 @@ func (ts *Tensor) Idx(index interface{}) (retVal Tensor) {
func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
// Make sure number of non-newaxis is not exceed number of dimensions
var nonNewAxis []TensorIndexer
var numNewAxis int = 0
for _, ti := range indexSpec {
if reflect.ValueOf(ti).String() != "InsertNewAxis" {
nonNewAxis = append(nonNewAxis, ti)
if reflect.TypeOf(ti).Name() == "InsertNewAxis" {
numNewAxis += 1
}
}
tsShape, err := ts.Size()
if err != nil {
return retVal, err
}
tsLen := len(tsShape)
if len(nonNewAxis) > tsLen {
if len(indexSpec) > tsLen+numNewAxis {
err = fmt.Errorf("Too many indices for tensor of dimension %v\n", tsLen)
return retVal, err
}
@ -253,6 +254,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
fmt.Println(currIdx)
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
if err != nil {
return retVal, err

View File

@ -1966,3 +1966,123 @@ func (ts Tensor) MustReflectionPad2d(paddingData []int64) (retVal Tensor) {
return retVal
}
func Arange(end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgArange(ptr, end.cscalar, kind.CInt(), device.CInt())
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustArange(end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor) {
retVal, err := Arange(end, kind, device)
if err != nil {
log.Fatal(err)
}
return retVal
}
func Arange1(start Scalar, end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgArange1(ptr, start.cscalar, end.cscalar, kind.CInt(), device.CInt())
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustArange1(start, end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor) {
retVal, err := Arange1(start, end, kind, device)
if err != nil {
log.Fatal(err)
}
return retVal
}
func Arange2(start Scalar, end Scalar, step Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgArange2(ptr, start.cscalar, end.cscalar, step.cscalar, kind.CInt(), device.CInt())
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustArange2(start Scalar, end Scalar, step Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor) {
retVal, err := Arange2(start, end, step, kind, device)
if err != nil {
log.Fatal(err)
}
return retVal
}
func ArangeOut(out Tensor, end Scalar) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgArangeOut(ptr, out.ctensor, end.cscalar)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustArangeOut(out Tensor, end Scalar) (retVal Tensor) {
retVal, err := ArangeOut(out, end)
if err != nil {
log.Fatal(err)
}
return retVal
}
func ArangeOut1(out Tensor, start, end Scalar) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgArangeOut1(ptr, out.ctensor, start.cscalar, end.cscalar)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func MustArangeOut1(out Tensor, start, end Scalar) (retVal Tensor) {
retVal, err := ArangeOut1(out, start, end)
if err != nil {
log.Fatal(err)
}
return retVal
}