feat(wrapper/tensor, tenser-generated-sample): added some tensor methods
This commit is contained in:
parent
af9c1aaeb1
commit
8f32baff08
|
@ -43,4 +43,23 @@ func main() {
|
||||||
|
|
||||||
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
|
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
|
||||||
|
|
||||||
|
clone := ts.MustShallowClone()
|
||||||
|
clone.Print()
|
||||||
|
|
||||||
|
atGet := ts.MustGet(1)
|
||||||
|
atGet.Print() // 29.7
|
||||||
|
|
||||||
|
atGet = ts.MustGet(0)
|
||||||
|
atGet.Print() // 1.3
|
||||||
|
|
||||||
|
dst, err := wrapper.NewTensorFromData([]int64{1, 2}, []int64{1, 2})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst = dst.MustTotype(ts.DType())
|
||||||
|
|
||||||
|
wrapper.MustCopy_(dst, ts)
|
||||||
|
dst.Print()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -61,3 +61,9 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
C.atg_add(ptr, self, other)
|
C.atg_add(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_totype(tensor *, tensor self, int scalar_type);
|
||||||
|
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
||||||
|
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
||||||
|
C.atg_totype(ptr, self, cscalar_type)
|
||||||
|
}
|
||||||
|
|
|
@ -159,3 +159,23 @@ func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_b
|
||||||
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
|
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
|
||||||
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
|
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tensor at_shallow_clone(tensor);
|
||||||
|
func AtShallowClone(ts Ctensor) Ctensor {
|
||||||
|
ctensor := (C.tensor)(ts)
|
||||||
|
return C.at_shallow_clone(ctensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor at_get(tensor, int index);
|
||||||
|
func AtGet(ts Ctensor, index int) Ctensor {
|
||||||
|
ctensor := (C.tensor)(ts)
|
||||||
|
cindex := *(*C.int)(unsafe.Pointer(&index))
|
||||||
|
return C.at_get(ctensor, cindex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void at_copy_(tensor dst, tensor src);
|
||||||
|
func AtCopy_(dst Ctensor, src Ctensor) {
|
||||||
|
cdst := (C.tensor)(dst)
|
||||||
|
csrc := (C.tensor)(src)
|
||||||
|
C.at_copy_(cdst, csrc)
|
||||||
|
}
|
||||||
|
|
|
@ -201,3 +201,33 @@ func (ts Tensor) MustAddG(other Tensor) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Totype casts type of tensor to a new tensor with specified DType
|
||||||
|
func (ts Tensor) Totype(dtype gt.DType) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
cint, err := gt.DType2CInt(dtype)
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgTotype(ptr, ts.ctensor, cint)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Totype casts type of tensor to a new tensor with specified DType. It will
|
||||||
|
// panic if error
|
||||||
|
func (ts Tensor) MustTotype(dtype gt.DType) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Totype(dtype)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
|
@ -567,3 +567,69 @@ func (ts Tensor) Numel() (retVal uint) {
|
||||||
shape = ts.MustSize()
|
shape = ts.MustSize()
|
||||||
return uint(FlattenDim(shape))
|
return uint(FlattenDim(shape))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ShallowCopy returns a new tensor that share storage with the input tensor.
|
||||||
|
func (ts Tensor) ShallowClone() (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
ctensor := lib.AtShallowClone(ts.ctensor)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustShallowClone returns a new tensor that share storage with the input
|
||||||
|
// tensor. It will panic if error occurred
|
||||||
|
func (ts Tensor) MustShallowClone() (retVal Tensor) {
|
||||||
|
retVal, err := ts.ShallowClone()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get gets the sub-tensor at the given index.
|
||||||
|
func (ts Tensor) Get(index int) (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
ctensor := lib.AtGet(ts.ctensor, index)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
retVal = Tensor{ctensor}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustGet gets the sub-tensor at the given index. It will panic if error
|
||||||
|
// occurred.
|
||||||
|
func (ts Tensor) MustGet(index int) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Get(index)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy_ copies in-place values from the argument tensor to the input tensor.
|
||||||
|
func Copy_(self, src Tensor) (err error) {
|
||||||
|
lib.AtCopy_(self.ctensor, src.ctensor)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustCopy_ copies in-place values from the argument tensor to the input tensor.
|
||||||
|
// It will panic if error occurred.
|
||||||
|
func MustCopy_(self, src Tensor) {
|
||||||
|
if err := Copy_(self, src); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user