diff --git a/example/tensor1/main.go b/example/tensor1/main.go index f93df91..bfc8587 100644 --- a/example/tensor1/main.go +++ b/example/tensor1/main.go @@ -43,4 +43,23 @@ func main() { fmt.Printf("Number of tensor elements: %v\n", ts.Numel()) + clone := ts.MustShallowClone() + clone.Print() + + atGet := ts.MustGet(1) + atGet.Print() // 29.7 + + atGet = ts.MustGet(0) + atGet.Print() // 1.3 + + dst, err := wrapper.NewTensorFromData([]int64{1, 2}, []int64{1, 2}) + if err != nil { + panic(err) + } + + dst = dst.MustTotype(ts.DType()) + + wrapper.MustCopy_(dst, ts) + dst.Print() + } diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go index 653f033..5044619 100644 --- a/libtch/c-generated-sample.go +++ b/libtch/c-generated-sample.go @@ -61,3 +61,9 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) { func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) { C.atg_add(ptr, self, other) } + +// void atg_totype(tensor *, tensor self, int scalar_type); +func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) { + cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type)) + C.atg_totype(ptr, self, cscalar_type) +} diff --git a/libtch/tensor.go b/libtch/tensor.go index b2a9bf7..bd7000e 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -159,3 +159,23 @@ func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_b celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes)) C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes) } + +// tensor at_shallow_clone(tensor); +func AtShallowClone(ts Ctensor) Ctensor { + ctensor := (C.tensor)(ts) + return C.at_shallow_clone(ctensor) +} + +// tensor at_get(tensor, int index); +func AtGet(ts Ctensor, index int) Ctensor { + ctensor := (C.tensor)(ts) + cindex := *(*C.int)(unsafe.Pointer(&index)) + return C.at_get(ctensor, cindex) +} + +// void at_copy_(tensor dst, tensor src); +func AtCopy_(dst Ctensor, src Ctensor) { + cdst := (C.tensor)(dst) + csrc := (C.tensor)(src) + C.at_copy_(cdst, csrc) +} diff --git a/wrapper/tensor-generated-sample.go b/wrapper/tensor-generated-sample.go index 791820d..89d3cab 100644 --- a/wrapper/tensor-generated-sample.go +++ b/wrapper/tensor-generated-sample.go @@ -201,3 +201,33 @@ func (ts Tensor) MustAddG(other Tensor) { log.Fatal(err) } } + +// Totype casts type of tensor to a new tensor with specified DType +func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + cint, err := gt.DType2CInt(dtype) + if err != nil { + return retVal, err + } + + lib.AtgTotype(ptr, ts.ctensor, cint) + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +// Totype casts type of tensor to a new tensor with specified DType. It will +// panic if error +func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) { + retVal, err := ts.Totype(dtype) + if err != nil { + log.Fatal(err) + } + + return retVal +} diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 654a334..2d335eb 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -567,3 +567,69 @@ func (ts Tensor) Numel() (retVal uint) { shape = ts.MustSize() return uint(FlattenDim(shape)) } + +// ShallowCopy returns a new tensor that share storage with the input tensor. +func (ts Tensor) ShallowClone() (retVal Tensor, err error) { + + ctensor := lib.AtShallowClone(ts.ctensor) + + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor} + + return retVal, nil +} + +// MustShallowClone returns a new tensor that share storage with the input +// tensor. It will panic if error occurred +func (ts Tensor) MustShallowClone() (retVal Tensor) { + retVal, err := ts.ShallowClone() + if err != nil { + log.Fatal(err) + } + + return retVal +} + +// Get gets the sub-tensor at the given index. +func (ts Tensor) Get(index int) (retVal Tensor, err error) { + + ctensor := lib.AtGet(ts.ctensor, index) + if err = TorchErr(); err != nil { + return retVal, err + } + retVal = Tensor{ctensor} + + return retVal, nil +} + +// MustGet gets the sub-tensor at the given index. It will panic if error +// occurred. +func (ts Tensor) MustGet(index int) (retVal Tensor) { + retVal, err := ts.Get(index) + if err != nil { + log.Fatal(err) + } + return retVal +} + +// Copy_ copies in-place values from the argument tensor to the input tensor. +func Copy_(self, src Tensor) (err error) { + lib.AtCopy_(self.ctensor, src.ctensor) + + if err = TorchErr(); err != nil { + return err + } + + return nil +} + +// MustCopy_ copies in-place values from the argument tensor to the input tensor. +// It will panic if error occurred. +func MustCopy_(self, src Tensor) { + if err := Copy_(self, src); err != nil { + log.Fatal(err) + } +}