diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go index 8ee925c..c4a5be5 100644 --- a/libtch/c-generated-sample.go +++ b/libtch/c-generated-sample.go @@ -270,6 +270,16 @@ func AtgDiv(ptr *Ctensor, self Ctensor, other Ctensor) { C.atg_div(ptr, self, other) } +// void atg_div_(tensor *, tensor self, tensor other); +func AtgDiv_(ptr *Ctensor, self Ctensor, other Ctensor) { + C.atg_div_(ptr, self, other) +} + +// void atg_div_1(tensor *, tensor self, scalar other); +func AtgDiv1_(ptr *Ctensor, self Ctensor, other Cscalar) { + C.atg_div_1(ptr, self, other) +} + // void atg_randperm(tensor *, int64_t n, int options_kind, int options_device); func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) { cn := *(*C.int64_t)(unsafe.Pointer(&n)) @@ -358,6 +368,11 @@ func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) { C.atg_sub_(ptr, self, other) } +// void atg_sub_1(tensor *, tensor self, scalar other); +func AtgSub1_(ptr *Ctensor, self Ctensor, other Cscalar) { + C.atg_sub_1(ptr, self, other) +} + // void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups); func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) { cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0])) diff --git a/tensor/iter.go b/tensor/iter.go index 6d97d26..ef7475e 100644 --- a/tensor/iter.go +++ b/tensor/iter.go @@ -3,31 +3,37 @@ package tensor import ( "fmt" "log" - "reflect" + + "github.com/sugarme/gotch" ) type Iterator interface { - Next() interface{} + Next() (item interface{}, ok bool) } type Iterable struct { Index int64 Len int64 Content Tensor - ItemKind reflect.Kind + ItemKind gotch.DType } // Next implements Iterator interface -func (it *Iterable) Next() (retVal interface{}) { +func (it *Iterable) Next() (retVal interface{}, ok bool) { + + if it.Index == it.Len { + return retVal, false + } + var err error - switch it.ItemKind { - case reflect.Int64: + switch it.ItemKind.Kind().String() { + case "int64": retVal, err = it.Content.Int64Value([]int64{it.Index}) if err != nil { log.Fatal(err) } it.Index += 1 - case reflect.Float64: + case "float64": retVal, err = it.Content.Float64Value([]int64{it.Index}) if err != nil { log.Fatal(err) @@ -38,24 +44,25 @@ func (it *Iterable) Next() (retVal interface{}) { log.Fatal(err) } - return retVal + return retVal, true } // Iter creates an iterable object with specified item type. -func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) { +func (ts Tensor) Iter(dtype gotch.DType) (retVal Iterable, err error) { num, err := ts.Size1() // size for 1D tensor if err != nil { return retVal, err } - content, err := ts.ShallowClone() + tmp, err := ts.ShallowClone() if err != nil { return retVal, err } + content := tmp.MustTotype(dtype, true) return Iterable{ Index: 0, Len: num, Content: content, - ItemKind: kind, + ItemKind: dtype, }, nil } diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go index f55151a..47ed869 100644 --- a/tensor/tensor-generated-sample.go +++ b/tensor/tensor-generated-sample.go @@ -883,6 +883,22 @@ func (ts Tensor) MustDiv(other Tensor, del bool) (retVal Tensor) { return retVal } +func (ts Tensor) Div_(other Tensor) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + lib.AtgDiv_(ptr, ts.ctensor, other.ctensor) + if err := TorchErr(); err != nil { + log.Fatal(err) + } +} + +func (ts Tensor) Div1_(other Scalar) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + lib.AtgDiv1_(ptr, ts.ctensor, other.cscalar) + if err := TorchErr(); err != nil { + log.Fatal(err) + } +} + func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) @@ -1221,6 +1237,15 @@ func (ts Tensor) Sub_(other Tensor) { } } +func (ts Tensor) Sub1_(other Scalar) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + lib.AtgSub1_(ptr, ts.ctensor, other.cscalar) + err := TorchErr() + if err != nil { + log.Fatal(err) + } +} + func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) diff --git a/tensor/tensor_test.go b/tensor/tensor_test.go new file mode 100644 index 0000000..b90f9b8 --- /dev/null +++ b/tensor/tensor_test.go @@ -0,0 +1,87 @@ +package tensor_test + +import ( + "reflect" + "testing" + + "github.com/sugarme/gotch" + ts "github.com/sugarme/gotch/tensor" +) + +func TestInplaceAssign(t *testing.T) { + tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5}) + + tensor.Add1_(ts.IntScalar(1)) + tensor.Mul1_(ts.IntScalar(2)) + tensor.Sub1_(ts.IntScalar(1)) + + want := []int64{7, 3, 9, 3, 11} + got := tensor.Vals() + + if !reflect.DeepEqual(want, got) { + t.Errorf("Expected tensor values: %v\n", want) + t.Errorf("Got tensor values: %v\n", got) + } +} + +func TestConstantOp(t *testing.T) { + tensor := ts.MustOfSlice([]int64{3, 9, 3, 11}) + resTs1 := tensor.MustMul1(ts.IntScalar(-1), true) + + want1 := []int64{-3, -9, -3, -11} + got1 := resTs1.Vals() + if !reflect.DeepEqual(want1, got1) { + t.Errorf("Expected tensor values: %v\n", want1) + t.Errorf("Got tensor values: %v\n", got1) + } + + // TODO: more ops + +} + +func TestIter(t *testing.T) { + + tensor := ts.MustOfSlice([]int64{3, 9, 3, 11}) + + iter, err := tensor.Iter(gotch.Int64) + if err != nil { + panic(err) + } + + want := []int64{3, 9, 3, 11} + var got []int64 + + for { + item, ok := iter.Next() + if !ok { + break + } + got = append(got, item.(int64)) + } + + if !reflect.DeepEqual(want, got) { + t.Errorf("Expected tensor values: %v\n", want) + t.Errorf("Got tensor values: %v\n", got) + } + + tensor1 := ts.MustOfSlice([]float64{3.14, 15.926, 5.3589, 79.0}) + iter1, err := tensor1.Iter(gotch.Double) + + want1 := []float64{3.14, 15.926, 5.3589, 79.0} + var got1 []float64 + for { + item, ok := iter1.Next() + if !ok { + break + } + + got1 = append(got1, item.(float64)) + } + + if !reflect.DeepEqual(want1, got1) { + t.Errorf("Expected tensor values: %v\n", want1) + t.Errorf("Got tensor values: %v\n", got1) + } +} + +// TODO: more tests