WIP(wrapper/tensor): CopyData
This commit is contained in:
parent
c85aa7d6c4
commit
fe6c76a2b8
39
example/tensor-copy-data/main.go
Normal file
39
example/tensor-copy-data/main.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// TODO: Check Go type of data and tensor DType
|
||||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
// data := [][]int16{
|
||||
// {1, 1, 1, 2, 2, 2, 3, 3},
|
||||
// {1, 1, 1, 2, 2, 2, 4, 4},
|
||||
// }
|
||||
// shape := []int64{2, 8}
|
||||
|
||||
data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
|
||||
shape := []int64{1, 8}
|
||||
|
||||
ts, err := wrapper.NewTensorFromData(data, shape)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ts.Print()
|
||||
|
||||
numel := uint(11)
|
||||
// dst := make([]uint8, numel)
|
||||
var dst = make([]uint8, 1)
|
||||
|
||||
ts.MustCopyData(dst, numel)
|
||||
|
||||
fmt.Println(dst)
|
||||
|
||||
}
|
|
@ -20,10 +20,10 @@ func main() {
|
|||
|
||||
y.Backward()
|
||||
xgrad := x.MustGrad()
|
||||
xgrad.Print()
|
||||
xgrad.Print() // [2.0]
|
||||
z.Backward()
|
||||
xgrad = x.MustGrad()
|
||||
xgrad.Print()
|
||||
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
|
||||
|
||||
}
|
||||
|
||||
|
|
54
example/tensor-run-backward/main.go
Normal file
54
example/tensor-run-backward/main.go
Normal file
|
@ -0,0 +1,54 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
x := wrapper.TensorFrom([]float64{2.0})
|
||||
x = x.MustSetRequiresGrad(true)
|
||||
x.ZeroGrad()
|
||||
|
||||
xmul := wrapper.TensorFrom([]float64{3.0})
|
||||
xadd := wrapper.TensorFrom([]float64{5.0})
|
||||
|
||||
x1 := x.MustMul(xmul)
|
||||
x2 := x1.MustMul(xmul)
|
||||
x3 := x2.MustMul(xmul)
|
||||
|
||||
y := x3.MustAdd(xadd)
|
||||
|
||||
inputs := []wrapper.Tensor{x}
|
||||
|
||||
dy_over_dx, err := wrapper.RunBackward([]wrapper.Tensor{y}, inputs, true, true)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("dy_over_dx length: %v\n", len(dy_over_dx))
|
||||
|
||||
// dy_over_dx1 := dy_over_dx[0]
|
||||
// err = dy_over_dx1.Backward()
|
||||
// if err != nil {
|
||||
// log.Fatalf("Errors:\n, %v", err)
|
||||
// }
|
||||
|
||||
dy_over_dx[0].MustBackward()
|
||||
|
||||
x.MustGrad().Print()
|
||||
|
||||
}
|
||||
|
||||
/* // Compute a second order derivative using run_backward.
|
||||
* let mut x = Tensor::from(42.0).set_requires_grad(true);
|
||||
* let y = &x * &x * &x + &x + &x * &x;
|
||||
* x.zero_grad();
|
||||
* let dy_over_dx = Tensor::run_backward(&[y], &[&x], true, true);
|
||||
* assert_eq!(dy_over_dx.len(), 1);
|
||||
* let dy_over_dx = &dy_over_dx[0];
|
||||
* dy_over_dx.backward();
|
||||
* let dy_over_dx2 = x.grad();
|
||||
* assert_eq!(f64::from(&dy_over_dx2), 254.0); */
|
|
@ -40,4 +40,7 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("First element address: %v\n", ele1)
|
||||
|
||||
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
|
||||
|
||||
}
|
||||
|
|
|
@ -151,3 +151,11 @@ func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninput
|
|||
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
|
||||
C.at_run_backward(tensorsPtr, cntensors, inputsPtr, cninputs, outputsPtr, ckeepGraph, ccreateGraph)
|
||||
}
|
||||
|
||||
// void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
|
||||
func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
|
||||
ctensor := (C.tensor)(tensor)
|
||||
cnumel := *(*C.size_t)(unsafe.Pointer(&numel))
|
||||
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
|
||||
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
|
||||
}
|
||||
|
|
|
@ -181,7 +181,7 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts *Tensor) AddG(other Tensor) (err error) {
|
||||
func (ts Tensor) AddG(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
|
||||
|
@ -190,12 +190,12 @@ func (ts *Tensor) AddG(other Tensor) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
ts = &Tensor{ctensor: *ptr}
|
||||
ts = Tensor{ctensor: *ptr}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustAddG(other Tensor) {
|
||||
func (ts Tensor) MustAddG(other Tensor) {
|
||||
err := ts.AddG(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
|
|
@ -56,6 +56,14 @@ func (ts Tensor) Size() (retVal []int64, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSize() (retVal []int64) {
|
||||
retVal, err := ts.Size()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for 1D tensors.
|
||||
func (ts Tensor) Size1() (retVal int64, err error) {
|
||||
shape, err := ts.Size()
|
||||
|
@ -388,7 +396,7 @@ func (ts Tensor) ZeroGrad() {
|
|||
// grad.MustDetach_().MustZero_()
|
||||
// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
|
||||
detach := grad.MustDetach_()
|
||||
_ = detach.MustZero_()
|
||||
detach.MustZero_()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -414,15 +422,12 @@ func (ts Tensor) MustBackward() {
|
|||
// RunBackward runs the backward ...
|
||||
func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraphB bool) (retVal []Tensor, err error) {
|
||||
// NOTE: outputs is a slice of tensors with length = len(inputs)
|
||||
// outputsPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// defer C.free(unsafe.Pointer(outputsPtr))
|
||||
var outputsPtr []*lib.Ctensor
|
||||
// TODO: Are they allocated continouslly???
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// defer C.free(unsafe.Pointer(outputPtr))
|
||||
outputsPtr = append(outputsPtr, outputPtr)
|
||||
// retVal = append(retVal, Tensor{ctensor: *outputPtr})
|
||||
}
|
||||
|
||||
// Get first element pointer
|
||||
|
@ -451,3 +456,105 @@ func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraph
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// CopyDataUint8 copies `numel` elements from `self` to `dst`.
|
||||
//
|
||||
// NOTE: `dst` located in Go memory. Should it be?
|
||||
func (ts Tensor) CopyDataUint8(dst []uint8, numel uint) (err error) {
|
||||
|
||||
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
|
||||
// there will be memory leak and or out of range error.
|
||||
if len(dst) < int(numel) {
|
||||
err = fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", len(dst), numel)
|
||||
return err
|
||||
}
|
||||
|
||||
vs := unsafe.Pointer(&dst[0])
|
||||
elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
|
||||
err := ts.CopyDataUint8(dst, numel)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// CopyData copies `numel` elements from `self` to `dst`.
|
||||
// `dst` should be a slice of Go type equivalent to tensor type.
|
||||
//
|
||||
// NOTE: `dst` located in Go memory. Should it be?
|
||||
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
|
||||
|
||||
dtype, dlen, err := DataCheck(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if dlen < int(numel) {
|
||||
err = fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
|
||||
return err
|
||||
}
|
||||
|
||||
if ts.DType() != dtype {
|
||||
err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||
return err
|
||||
}
|
||||
|
||||
var vs unsafe.Pointer
|
||||
switch dtype {
|
||||
case gotch.Uint8:
|
||||
vs = unsafe.Pointer(&dst.([]uint8)[0])
|
||||
case gotch.Int8:
|
||||
vs = unsafe.Pointer(&dst.([]int8)[0])
|
||||
case gotch.Int16:
|
||||
vs = unsafe.Pointer(&dst.([]int16)[0])
|
||||
case gotch.Int:
|
||||
vs = unsafe.Pointer(&dst.([]int32)[0])
|
||||
case gotch.Int64:
|
||||
vs = unsafe.Pointer(&dst.([]int64)[0])
|
||||
case gotch.Float:
|
||||
vs = unsafe.Pointer(&dst.([]float32)[0])
|
||||
case gotch.Double:
|
||||
vs = unsafe.Pointer(&dst.([]float64)[0])
|
||||
case gotch.Bool:
|
||||
vs = unsafe.Pointer(&dst.([]bool)[0])
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||
return err
|
||||
}
|
||||
|
||||
elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustCopyData(dst interface{}, numel uint) {
|
||||
err := ts.CopyData(dst, numel)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Numel returns the total number of elements stored in a tensor.
|
||||
func (ts Tensor) Numel() (retVal uint) {
|
||||
var shape []int64
|
||||
shape = ts.MustSize()
|
||||
return uint(FlattenDim(shape))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user