feat(wrapper/tensor, tenser-generated-sample): added some tensor methods

This commit is contained in:
sugarme 2020-06-08 16:13:23 +10:00
parent af9c1aaeb1
commit 8f32baff08
5 changed files with 141 additions and 0 deletions

View File

@ -43,4 +43,23 @@ func main() {
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
clone := ts.MustShallowClone()
clone.Print()
atGet := ts.MustGet(1)
atGet.Print() // 29.7
atGet = ts.MustGet(0)
atGet.Print() // 1.3
dst, err := wrapper.NewTensorFromData([]int64{1, 2}, []int64{1, 2})
if err != nil {
panic(err)
}
dst = dst.MustTotype(ts.DType())
wrapper.MustCopy_(dst, ts)
dst.Print()
}

View File

@ -61,3 +61,9 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_add(ptr, self, other)
}
// void atg_totype(tensor *, tensor self, int scalar_type);
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
C.atg_totype(ptr, self, cscalar_type)
}

View File

@ -159,3 +159,23 @@ func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_b
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
}
// tensor at_shallow_clone(tensor);
func AtShallowClone(ts Ctensor) Ctensor {
ctensor := (C.tensor)(ts)
return C.at_shallow_clone(ctensor)
}
// tensor at_get(tensor, int index);
func AtGet(ts Ctensor, index int) Ctensor {
ctensor := (C.tensor)(ts)
cindex := *(*C.int)(unsafe.Pointer(&index))
return C.at_get(ctensor, cindex)
}
// void at_copy_(tensor dst, tensor src);
func AtCopy_(dst Ctensor, src Ctensor) {
cdst := (C.tensor)(dst)
csrc := (C.tensor)(src)
C.at_copy_(cdst, csrc)
}

View File

@ -201,3 +201,33 @@ func (ts Tensor) MustAddG(other Tensor) {
log.Fatal(err)
}
}
// Totype casts type of tensor to a new tensor with specified DType
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
cint, err := gt.DType2CInt(dtype)
if err != nil {
return retVal, err
}
lib.AtgTotype(ptr, ts.ctensor, cint)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
// Totype casts type of tensor to a new tensor with specified DType. It will
// panic if error
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
retVal, err := ts.Totype(dtype)
if err != nil {
log.Fatal(err)
}
return retVal
}

View File

@ -567,3 +567,69 @@ func (ts Tensor) Numel() (retVal uint) {
shape = ts.MustSize()
return uint(FlattenDim(shape))
}
// ShallowCopy returns a new tensor that share storage with the input tensor.
func (ts Tensor) ShallowClone() (retVal Tensor, err error) {
ctensor := lib.AtShallowClone(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor}
return retVal, nil
}
// MustShallowClone returns a new tensor that share storage with the input
// tensor. It will panic if error occurred
func (ts Tensor) MustShallowClone() (retVal Tensor) {
retVal, err := ts.ShallowClone()
if err != nil {
log.Fatal(err)
}
return retVal
}
// Get gets the sub-tensor at the given index.
func (ts Tensor) Get(index int) (retVal Tensor, err error) {
ctensor := lib.AtGet(ts.ctensor, index)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor}
return retVal, nil
}
// MustGet gets the sub-tensor at the given index. It will panic if error
// occurred.
func (ts Tensor) MustGet(index int) (retVal Tensor) {
retVal, err := ts.Get(index)
if err != nil {
log.Fatal(err)
}
return retVal
}
// Copy_ copies in-place values from the argument tensor to the input tensor.
func Copy_(self, src Tensor) (err error) {
lib.AtCopy_(self.ctensor, src.ctensor)
if err = TorchErr(); err != nil {
return err
}
return nil
}
// MustCopy_ copies in-place values from the argument tensor to the input tensor.
// It will panic if error occurred.
func MustCopy_(self, src Tensor) {
if err := Copy_(self, src); err != nil {
log.Fatal(err)
}
}