WIP(wrapper/tensor): CopyData

This commit is contained in:
sugarme 2020-06-08 13:28:07 +10:00
parent c85aa7d6c4
commit fe6c76a2b8
7 changed files with 220 additions and 9 deletions

View File

@ -0,0 +1,39 @@
package main
import (
"fmt"
"log"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
// TODO: check that the Go element type of `data` matches the tensor DType.
// For example, if `data` is []int and the DType is Bool, it still runs
// but produces a wrong result.
// data := [][]int16{
// {1, 1, 1, 2, 2, 2, 3, 3},
// {1, 1, 1, 2, 2, 2, 4, 4},
// }
// shape := []int64{2, 8}
data := []int16{1, 1, 1, 2, 2, 2, 3, 3}
shape := []int64{1, 8}
ts, err := wrapper.NewTensorFromData(data, shape)
if err != nil {
log.Fatal(err)
}
ts.Print()
// `dst` must have the same element type as the tensor DType (Int16 here)
// and at least `numel` elements.
numel := ts.Numel()
dst := make([]int16, numel)
ts.MustCopyData(dst, numel)
fmt.Println(dst)
}
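The TODO at the top of this example points at a missing check: NewTensorFromData does not verify that the Go element type of `data` agrees with the tensor DType. A minimal sketch of such a guard, assuming the gotch DType constants used later in this commit and that DType values are directly comparable (the helper name `checkElemType` is hypothetical, not part of this commit):

// checkElemType is a hypothetical guard: it rejects a data slice whose Go
// element type does not correspond to the given DType.
// It assumes `import "github.com/sugarme/gotch"` and `import "fmt"`.
func checkElemType(data interface{}, dtype gotch.DType) error {
    ok := false
    switch data.(type) {
    case []uint8:
        ok = dtype == gotch.Uint8
    case []int8:
        ok = dtype == gotch.Int8
    case []int16:
        ok = dtype == gotch.Int16
    case []int32:
        ok = dtype == gotch.Int
    case []int64:
        ok = dtype == gotch.Int64
    case []float32:
        ok = dtype == gotch.Float
    case []float64:
        ok = dtype == gotch.Double
    case []bool:
        ok = dtype == gotch.Bool
    }
    if !ok {
        return fmt.Errorf("element type of data (%T) does not match DType %v", data, dtype)
    }
    return nil
}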

View File

@ -20,10 +20,10 @@ func main() {
y.Backward()
xgrad := x.MustGrad()
xgrad.Print()
xgrad.Print() // [2.0]
z.Backward()
xgrad = x.MustGrad()
xgrad.Print()
xgrad.Print() // [5.0] due to accumulated 2.0 + 3.0
}
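The comments added above show gradient accumulation: the 3.0 contributed by `z.Backward()` is added on top of the 2.0 already stored in `x`'s grad. If accumulation is not wanted, the grad can be reset between the two passes with the `ZeroGrad` added in this commit; a minimal sketch reusing the `x`, `y`, `z` of this example:

y.Backward()
x.MustGrad().Print() // [2.0]
x.ZeroGrad()         // reset the accumulated gradient before the next pass
z.Backward()
x.MustGrad().Print() // [3.0] (only z's contribution, not the accumulated 5.0)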

View File

@ -0,0 +1,54 @@
package main
import (
"fmt"
"log"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
x := wrapper.TensorFrom([]float64{2.0})
x = x.MustSetRequiresGrad(true)
x.ZeroGrad()
xmul := wrapper.TensorFrom([]float64{3.0})
xadd := wrapper.TensorFrom([]float64{5.0})
x1 := x.MustMul(xmul)
x2 := x1.MustMul(xmul)
x3 := x2.MustMul(xmul)
y := x3.MustAdd(xadd)
inputs := []wrapper.Tensor{x}
dy_over_dx, err := wrapper.RunBackward([]wrapper.Tensor{y}, inputs, true, true)
if err != nil {
log.Fatal(err)
}
fmt.Printf("dy_over_dx length: %v\n", len(dy_over_dx))
// dy_over_dx1 := dy_over_dx[0]
// err = dy_over_dx1.Backward()
// if err != nil {
// log.Fatalf("Errors:\n, %v", err)
// }
dy_over_dx[0].MustBackward()
x.MustGrad().Print()
}
/* // Compute a second order derivative using run_backward.
* let mut x = Tensor::from(42.0).set_requires_grad(true);
* let y = &x * &x * &x + &x + &x * &x;
* x.zero_grad();
* let dy_over_dx = Tensor::run_backward(&[y], &[&x], true, true);
* assert_eq!(dy_over_dx.len(), 1);
* let dy_over_dx = &dy_over_dx[0];
* dy_over_dx.backward();
* let dy_over_dx2 = x.grad();
* assert_eq!(f64::from(&dy_over_dx2), 254.0); */
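For comparison, a Go version of the Rust snippet above could look roughly like this (a minimal sketch using only the wrapper functions appearing in this commit; 254.0 is d²y/dx² = 6x + 2 evaluated at x = 42):

x := wrapper.TensorFrom([]float64{42.0})
x = x.MustSetRequiresGrad(true)
x.ZeroGrad()

// y = x^3 + x + x^2
x2 := x.MustMul(x)
x3 := x2.MustMul(x)
y := x3.MustAdd(x).MustAdd(x2)

dyOverDx, err := wrapper.RunBackward([]wrapper.Tensor{y}, []wrapper.Tensor{x}, true, true)
if err != nil {
    log.Fatal(err)
}
dyOverDx[0].MustBackward()
x.MustGrad().Print() // expected: [254.0]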

View File

@ -40,4 +40,7 @@ func main() {
log.Fatal(err)
}
fmt.Printf("First element address: %v\n", ele1)
fmt.Printf("Number of tensor elements: %v\n", ts.Numel())
}

View File

@ -151,3 +151,11 @@ func AtRunBackward(tensorsPtr *Ctensor, ntensors int, inputsPtr *Ctensor, ninput
ccreateGraph := *(*C.int)(unsafe.Pointer(&createGraph))
C.at_run_backward(tensorsPtr, cntensors, inputsPtr, cninputs, outputsPtr, ckeepGraph, ccreateGraph)
}
// void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
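// AtCopyData copies `numel` elements, each of `element_size_in_bytes` bytes,
// from the tensor's underlying data into the memory pointed to by `vs`.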
func AtCopyData(tensor Ctensor, vs unsafe.Pointer, numel uint, element_size_in_bytes uint) {
ctensor := (C.tensor)(tensor)
cnumel := *(*C.size_t)(unsafe.Pointer(&numel))
celement_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&element_size_in_bytes))
C.at_copy_data(ctensor, vs, cnumel, celement_size_in_bytes)
}

View File

@ -181,7 +181,7 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
return retVal
}
func (ts *Tensor) AddG(other Tensor) (err error) {
func (ts Tensor) AddG(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgAdd(ptr, ts.ctensor, other.ctensor)
@ -190,12 +190,12 @@ func (ts *Tensor) AddG(other Tensor) (err error) {
return err
}
ts = &Tensor{ctensor: *ptr}
ts = Tensor{ctensor: *ptr}
return nil
}
func (ts *Tensor) MustAddG(other Tensor) {
func (ts Tensor) MustAddG(other Tensor) {
err := ts.AddG(other)
if err != nil {
log.Fatal(err)

View File

@ -56,6 +56,14 @@ func (ts Tensor) Size() (retVal []int64, err error) {
return retVal, nil
}
func (ts Tensor) MustSize() (retVal []int64) {
retVal, err := ts.Size()
if err != nil {
log.Fatal(err)
}
return retVal
}
// Size1 returns the tensor size for 1D tensors.
func (ts Tensor) Size1() (retVal int64, err error) {
shape, err := ts.Size()
@ -388,7 +396,7 @@ func (ts Tensor) ZeroGrad() {
// grad.MustDetach_().MustZero_()
// https://www.calhoun.io/using-functional-options-instead-of-method-chaining-in-go/
detach := grad.MustDetach_()
_ = detach.MustZero_()
detach.MustZero_()
}
}
@ -414,15 +422,12 @@ func (ts Tensor) MustBackward() {
// RunBackward runs the backward ...
func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraphB bool) (retVal []Tensor, err error) {
// NOTE: outputs is a slice of tensors with length = len(inputs)
// outputsPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
// defer C.free(unsafe.Pointer(outputsPtr))
var outputsPtr []*lib.Ctensor
// TODO: Are they allocated contiguously?
for i := 0; i < len(inputs); i++ {
outputPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
// defer C.free(unsafe.Pointer(outputPtr))
outputsPtr = append(outputsPtr, outputPtr)
// retVal = append(retVal, Tensor{ctensor: *outputPtr})
}
// Get first element pointer
@ -451,3 +456,105 @@ func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraph
return retVal, nil
}
// CopyDataUint8 copies `numel` elements from `self` to `dst`.
//
// NOTE: `dst` is located in Go memory. Should it be?
func (ts Tensor) CopyDataUint8(dst []uint8, numel uint) (err error) {
// NOTE: we must make sure that `dst` has at least `numel` elements;
// otherwise the copy would run out of range of the Go slice.
if len(dst) < int(numel) {
err = fmt.Errorf("CopyDataUint8 Error: length of destination slice (%v) is smaller than the number of elements to be copied (%v)", len(dst), numel)
return err
}
vs := unsafe.Pointer(&dst[0])
elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8)
if err != nil {
return err
}
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
err := ts.CopyDataUint8(dst, numel)
if err != nil {
log.Fatal(err)
}
}
// CopyData copies `numel` elements from `self` to `dst`.
// `dst` must be a slice whose Go element type corresponds to the tensor DType.
//
// NOTE: `dst` is located in Go memory. Should it be?
func (ts Tensor) CopyData(dst interface{}, numel uint) (err error) {
dtype, dlen, err := DataCheck(dst)
if err != nil {
return err
}
if dlen < int(numel) {
err = fmt.Errorf("CopyDataUint8 Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
return err
}
if ts.DType() != dtype {
err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
var vs unsafe.Pointer
switch dtype {
case gotch.Uint8:
vs = unsafe.Pointer(&dst.([]uint8)[0])
case gotch.Int8:
vs = unsafe.Pointer(&dst.([]int8)[0])
case gotch.Int16:
vs = unsafe.Pointer(&dst.([]int16)[0])
case gotch.Int:
vs = unsafe.Pointer(&dst.([]int32)[0])
case gotch.Int64:
vs = unsafe.Pointer(&dst.([]int64)[0])
case gotch.Float:
vs = unsafe.Pointer(&dst.([]float32)[0])
case gotch.Double:
vs = unsafe.Pointer(&dst.([]float64)[0])
case gotch.Bool:
vs = unsafe.Pointer(&dst.([]bool)[0])
default:
err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
elt_size_in_bytes, err := gotch.DTypeSize(dtype.(gotch.DType))
if err != nil {
return err
}
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts Tensor) MustCopyData(dst interface{}, numel uint) {
err := ts.CopyData(dst, numel)
if err != nil {
log.Fatal(err)
}
}
// Numel returns the total number of elements stored in a tensor.
func (ts Tensor) Numel() (retVal uint) {
shape := ts.MustSize()
return uint(FlattenDim(shape))
}
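A minimal usage sketch combining Numel and CopyData (assuming a Double tensor created with TensorFrom, as in the examples above; the destination element type must correspond to the tensor DType):

ts := wrapper.TensorFrom([]float64{1.5, 2.5, 3.5})
numel := ts.Numel() // 3
dst := make([]float64, numel) // Double -> float64
if err := ts.CopyData(dst, numel); err != nil {
    log.Fatal(err)
}
fmt.Println(dst) // [1.5 2.5 3.5]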