feat(wrapper/tensor, tenser-generated-sample): added some tensor methods
This commit is contained in:
parent
af9c1aaeb1
commit
8f32baff08
|
@ -43,4 +43,23 @@ func main() {
|
|||
|
||||
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
|
||||
|
||||
clone := ts.MustShallowClone()
|
||||
clone.Print()
|
||||
|
||||
atGet := ts.MustGet(1)
|
||||
atGet.Print() // 29.7
|
||||
|
||||
atGet = ts.MustGet(0)
|
||||
atGet.Print() // 1.3
|
||||
|
||||
dst, err := wrapper.NewTensorFromData([]int64{1, 2}, []int64{1, 2})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
dst = dst.MustTotype(ts.DType())
|
||||
|
||||
wrapper.MustCopy_(dst, ts)
|
||||
dst.Print()
|
||||
|
||||
}
|
||||
|
|
|
@ -61,3 +61,9 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|||
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_add(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_totype(tensor *, tensor self, int scalar_type);
|
||||
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
||||
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
||||
C.atg_totype(ptr, self, cscalar_type)
|
||||
}
|
||||
|
|
|
@ -159,3 +159,23 @@ func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_b
|
|||
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
|
||||
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
|
||||
}
|
||||
|
||||
// tensor at_shallow_clone(tensor);
|
||||
func AtShallowClone(ts Ctensor) Ctensor {
|
||||
ctensor := (C.tensor)(ts)
|
||||
return C.at_shallow_clone(ctensor)
|
||||
}
|
||||
|
||||
// tensor at_get(tensor, int index);
|
||||
func AtGet(ts Ctensor, index int) Ctensor {
|
||||
ctensor := (C.tensor)(ts)
|
||||
cindex := *(*C.int)(unsafe.Pointer(&index))
|
||||
return C.at_get(ctensor, cindex)
|
||||
}
|
||||
|
||||
// void at_copy_(tensor dst, tensor src);
|
||||
func AtCopy_(dst Ctensor, src Ctensor) {
|
||||
cdst := (C.tensor)(dst)
|
||||
csrc := (C.tensor)(src)
|
||||
C.at_copy_(cdst, csrc)
|
||||
}
|
||||
|
|
|
@ -201,3 +201,33 @@ func (ts Tensor) MustAddG(other Tensor) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Totype casts type of tensor to a new tensor with specified DType
|
||||
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
cint, err := gt.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
lib.AtgTotype(ptr, ts.ctensor, cint)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Totype casts type of tensor to a new tensor with specified DType. It will
|
||||
// panic if error
|
||||
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
|
||||
retVal, err := ts.Totype(dtype)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -567,3 +567,69 @@ func (ts Tensor) Numel() (retVal uint) {
|
|||
shape = ts.MustSize()
|
||||
return uint(FlattenDim(shape))
|
||||
}
|
||||
|
||||
// ShallowCopy returns a new tensor that share storage with the input tensor.
|
||||
func (ts Tensor) ShallowClone() (retVal Tensor, err error) {
|
||||
|
||||
ctensor := lib.AtShallowClone(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// MustShallowClone returns a new tensor that share storage with the input
|
||||
// tensor. It will panic if error occurred
|
||||
func (ts Tensor) MustShallowClone() (retVal Tensor) {
|
||||
retVal, err := ts.ShallowClone()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Get gets the sub-tensor at the given index.
|
||||
func (ts Tensor) Get(index int) (retVal Tensor, err error) {
|
||||
|
||||
ctensor := lib.AtGet(ts.ctensor, index)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = Tensor{ctensor}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// MustGet gets the sub-tensor at the given index. It will panic if error
|
||||
// occurred.
|
||||
func (ts Tensor) MustGet(index int) (retVal Tensor) {
|
||||
retVal, err := ts.Get(index)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Copy_ copies in-place values from the argument tensor to the input tensor.
|
||||
func Copy_(self, src Tensor) (err error) {
|
||||
lib.AtCopy_(self.ctensor, src.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustCopy_ copies in-place values from the argument tensor to the input tensor.
|
||||
// It will panic if error occurred.
|
||||
func MustCopy_(self, src Tensor) {
|
||||
if err := Copy_(self, src); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user