fix(tensor/index): fixed incorrect counting of number of NewAxis
This commit is contained in:
parent
22333d4544
commit
94c1af9671
|
@ -1,27 +1,25 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// "github.com/sugarme/gotch"
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
// shape := []int64{2, 8}
|
||||
shape := []int64{2, 2, 4}
|
||||
|
||||
t, err := ts.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
|
||||
var idxs []ts.TensorIndexer = []ts.TensorIndexer{
|
||||
// ts.NewNarrow(0, tensor.MustSize()[0]),
|
||||
// ts.NewNarrow(0, tensor.MustSize()[1]),
|
||||
ts.NewInsertNewAxis(),
|
||||
}
|
||||
|
||||
t.Print()
|
||||
result := tensor.Idx(idxs)
|
||||
|
||||
idx := ts.NewNarrow(0, 3)
|
||||
fmt.Printf("Original Ts shape: %v\n", tensor.MustSize())
|
||||
fmt.Printf("Result Ts shape: %v\n", result.MustSize())
|
||||
|
||||
selTs := t.Idx(idx)
|
||||
selTs.Print()
|
||||
}
|
||||
|
|
|
@ -642,3 +642,39 @@ func AtgReflectionPad2d(ptr *Ctensor, self Ctensor, paddingData []int64, padding
|
|||
|
||||
C.atg_reflection_pad2d(ptr, self, cpaddingDataPtr, cpaddingLen)
|
||||
}
|
||||
|
||||
// void atg_arange(tensor *, scalar end, int options_kind, int options_device);
|
||||
func AtgArange(ptr *Ctensor, end Cscalar, optionsKind int32, optionsDevice int32) {
|
||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
|
||||
C.atg_arange(ptr, end, coptionsKind, coptionsDevice)
|
||||
}
|
||||
|
||||
// void atg_arange1(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
||||
func AtgArange1(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32) {
|
||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
|
||||
C.atg_arange1(ptr, start, end, coptionsKind, coptionsDevice)
|
||||
}
|
||||
|
||||
// void atg_arange2(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
|
||||
func AtgArange2(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32) {
|
||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
|
||||
C.atg_arange2(ptr, start, end, step, coptionsKind, coptionsDevice)
|
||||
}
|
||||
|
||||
// void atg_arange_out(tensor *, tensor out, scalar end);
|
||||
func AtgArangeOut(ptr *Ctensor, out Ctensor, end Cscalar) {
|
||||
|
||||
C.atg_arange_out(ptr, out, end)
|
||||
}
|
||||
|
||||
// void atg_arange_out1(tensor *, tensor out, scalar start, scalar end);
|
||||
func AtgArangeOut1(ptr *Ctensor, out Ctensor, start Cscalar, end Cscalar) {
|
||||
|
||||
C.atg_arange_out1(ptr, out, start, end)
|
||||
}
|
||||
|
|
|
@ -193,18 +193,19 @@ func (ts *Tensor) Idx(index interface{}) (retVal Tensor) {
|
|||
func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
||||
|
||||
// Make sure number of non-newaxis is not exceed number of dimensions
|
||||
var nonNewAxis []TensorIndexer
|
||||
var numNewAxis int = 0
|
||||
for _, ti := range indexSpec {
|
||||
if reflect.ValueOf(ti).String() != "InsertNewAxis" {
|
||||
nonNewAxis = append(nonNewAxis, ti)
|
||||
if reflect.TypeOf(ti).Name() == "InsertNewAxis" {
|
||||
numNewAxis += 1
|
||||
}
|
||||
}
|
||||
|
||||
tsShape, err := ts.Size()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
tsLen := len(tsShape)
|
||||
if len(nonNewAxis) > tsLen {
|
||||
if len(indexSpec) > tsLen+numNewAxis {
|
||||
err = fmt.Errorf("Too many indices for tensor of dimension %v\n", tsLen)
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -253,6 +254,7 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
fmt.Println(currIdx)
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
|
|
|
@ -1966,3 +1966,123 @@ func (ts Tensor) MustReflectionPad2d(paddingData []int64) (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Arange(end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgArange(ptr, end.cscalar, kind.CInt(), device.CInt())
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
|
||||
}
|
||||
|
||||
func MustArange(end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor) {
|
||||
|
||||
retVal, err := Arange(end, kind, device)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Arange1(start Scalar, end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgArange1(ptr, start.cscalar, end.cscalar, kind.CInt(), device.CInt())
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
|
||||
}
|
||||
|
||||
func MustArange1(start, end Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor) {
|
||||
|
||||
retVal, err := Arange1(start, end, kind, device)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Arange2(start Scalar, end Scalar, step Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgArange2(ptr, start.cscalar, end.cscalar, step.cscalar, kind.CInt(), device.CInt())
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustArange2(start Scalar, end Scalar, step Scalar, kind gotch.DType, device gotch.Device) (retVal Tensor) {
|
||||
|
||||
retVal, err := Arange2(start, end, step, kind, device)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func ArangeOut(out Tensor, end Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgArangeOut(ptr, out.ctensor, end.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustArangeOut(out Tensor, end Scalar) (retVal Tensor) {
|
||||
retVal, err := ArangeOut(out, end)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func ArangeOut1(out Tensor, start, end Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgArangeOut1(ptr, out.ctensor, start.cscalar, end.cscalar)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustArangeOut1(out Tensor, start, end Scalar) (retVal Tensor) {
|
||||
retVal, err := ArangeOut1(out, start, end)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user