example/tensor: simple create a tensor from slice. Just make it work with API
This commit is contained in:
parent 816e6109ea
commit 5f167e3b67
@@ -1,31 +1,129 @@
package main

//#include <stdlib.h>
import "C"

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"reflect"
	"unsafe"

	t "github.com/sugarme/gotch/torch"
	t "github.com/sugarme/gotch/torch/libtch"
)

type Tensor struct {
	c_tensor *t.C_tensor
}

func FnOfSlice(data []float64) (retVal Tensor, err error) {
	dataLen := len(data)
	dat := unsafe.Pointer(data)
func FnOfSlice() (retVal Tensor, err error) {

	c_tensor := t.AtTensorOfData(dat, int64(dataLen), 1, 7, 7)
	data := []int{1, 2, 3, 4, 5, 6}
	nflattened := len(data)
	dtype := 3          // Kind.Int
	eltSizeInBytes := 4 // Element Size in Byte for Int dtype

	nbytes := eltSizeInBytes * int(uintptr(nflattened))

	dataPtr := C.malloc(C.size_t(nbytes))

	// Recall: 1 << 30 = 2^30, used only as a large upper bound for the array type below
	// Ref. See more at https://stackoverflow.com/questions/48756732
	dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]

	buf := bytes.NewBuffer(dataSlice[:0:nbytes])

	encodeTensor(buf, reflect.ValueOf(data), []int64{1})

	c_tensor := t.AtTensorOfData(dataPtr, int64(nflattened), 1, uint(eltSizeInBytes), int32(dtype))

	retVal = Tensor{c_tensor}

	return retVal, nil
}

func numElements(shape []int) int {
	n := 1
	for _, d := range shape {
		n *= d
	}
	return n
}

func main() {

	t := t.NewTensor()

	fmt.Printf("Type of t: %v\n", reflect.TypeOf(t))

	res, err := FnOfSlice()
	if err != nil {
		fmt.Println(err)
	}

	fmt.Println(res)
}

func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
	switch v.Kind() {
	case reflect.Bool:
		b := byte(0)
		if v.Bool() {
			b = 1
		}
		if err := w.WriteByte(b); err != nil {
			return err
		}
	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
		if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
			return err
		}

	case reflect.Array, reflect.Slice:
		// If current dimension is a slice, verify that it has the expected size
		// Go's type system makes that guarantee for arrays.
		if v.Kind() == reflect.Slice {
			expected := int(shape[0])
			if v.Len() != expected {
				return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
			}
		}

		// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
		if len(shape) == 1 && v.Len() > 0 {
			switch v.Index(0).Kind() {
			case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
				return binary.Write(w, nativeEndian, v.Interface())
			}
		}

		subShape := shape[1:]
		for i := 0; i < v.Len(); i++ {
			err := encodeTensor(w, v.Index(i), subShape)
			if err != nil {
				return err
			}
		}

	default:
		return fmt.Errorf("unsupported type %v", v.Type())
	}
	return nil
}

var nativeEndian binary.ByteOrder

func init() {
	buf := [2]byte{}
	*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)

	switch buf {
	case [2]byte{0xCD, 0xAB}:
		nativeEndian = binary.LittleEndian
	case [2]byte{0xAB, 0xCD}:
		nativeEndian = binary.BigEndian
	default:
		panic("Could not determine native endianness.")
	}
}
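
Aside, not part of the diff: the new FnOfSlice above relies on a common cgo pattern — allocate a buffer with C.malloc, view it as a Go byte slice, and fill it with binary.Write before the pointer is handed to at_tensor_of_data. Below is a minimal standalone sketch of just that pattern; it uses only the standard library plus cgo, and the int32 element type and little-endian order are illustrative assumptions, not taken from the commit.

package main

//#include <stdlib.h>
import "C"

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"unsafe"
)

func main() {
	data := []int32{1, 2, 3, 4, 5, 6}
	eltSize := 4 // an int32 occupies 4 bytes
	nbytes := eltSize * len(data)

	// Allocate C-owned memory; the Go GC will neither move nor free it.
	dataPtr := C.malloc(C.size_t(nbytes))
	defer C.free(dataPtr)

	// View the C buffer as a Go byte slice. 1 << 30 is only a large bound
	// for the fake array type; nbytes caps the usable length.
	dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]

	// Encode the Go slice into the C buffer (little-endian for brevity;
	// the example above detects the native byte order instead).
	buf := bytes.NewBuffer(dataSlice[:0:nbytes])
	if err := binary.Write(buf, binary.LittleEndian, data); err != nil {
		fmt.Println(err)
		return
	}

	fmt.Println(dataSlice) // the raw element bytes now sit in C memory
}

A C-owned buffer is used because the resulting pointer is handed to native code that reads it outside the Go runtime's control.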

@@ -11,14 +11,11 @@ type Device struct {
	Value int
}

type Cuda struct {
	Device
	Count int
}
type Cuda Device

var (
	CPU Device = Device{Name: "CPU", Value: -1}
	CUDA Cuda = Cuda{Name: "CUDA", Value: 0, Count: 1}
	CUDA Cuda = Cuda{Name: "CUDA", Value: 0}
)

func CudaBuilder(v uint) Device {

@@ -35,29 +32,31 @@ func (cu Cuda) DeviceCount() int64 {
	return int64(cInt)
}

// IsAvailable returns true if cuda support is available
func (cu Cuda) IsAvailable() bool {
	return lib.Atc_cuda_is_available()
}

// CudnnIsAvailable returns true if cudnn support is available
func (cu Cuda) CudnnIsAvailable() bool {
	return lib.Atc_cudnn_is_available()
}

// CudnnSetBenchmark sets cudnn benchmark mode
//
// When set cudnn will try to optimize the generators during the first network
// runs and then use the optimized architecture in the following runs. This can
// result in significant performance improvements.
func (cu Cuda) CudnnSetBenchmark(b bool) {
	switch b {
	case true:
		lib.Atc_set_benchmark_cudnn(1)
	case false:
		lib.Act_cuda_benchmark_cudd(0)
	}
}
/*
 *
 * // IsAvailable returns true if cuda support is available
 * func (cu Cuda) IsAvailable() bool {
 * 	return lib.Atc_cuda_is_available()
 * }
 *
 * // CudnnIsAvailable returns true if cudnn support is available
 * func (cu Cuda) CudnnIsAvailable() bool {
 * 	return lib.Atc_cudnn_is_available()
 * }
 *
 * // CudnnSetBenchmark sets cudnn benchmark mode
 * //
 * // When set cudnn will try to optimize the generators during the first network
 * // runs and then use the optimized architecture in the following runs. This can
 * // result in significant performance improvements.
 * func (cu Cuda) CudnnSetBenchmark(b bool) {
 * 	switch b {
 * 	case true:
 * 		lib.Atc_set_benchmark_cudnn(1)
 * 	case false:
 * 		lib.Act_cuda_benchmark_cudd(0)
 * 	}
 * } */

// Device methods:
//================
@@ -72,6 +71,7 @@ func (d Device) CInt() CInt {
		return CInt(deviceIndex)
	default:
		log.Fatal("Not reachable")
		return 0
	}
}

@@ -84,13 +84,14 @@ func (d Device) OfCInt(v CInt) Device {
	default:
		log.Fatalf("Unexpected device %v", v)
	}
	return Device{}
}

// CudaIfAvailable returns a GPU device if available, else default to CPU
func (d Device) CudaIfAvailable() Device {
	switch {
	case CUDA.IsAvailable():
		return CudaBuilder(0)
	// case CUDA.IsAvailable():
	// 	return CudaBuilder(0)
	default:
		return CPU
	}
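
Aside, not part of the diff: the device.go change above swaps an embedded-struct definition of Cuda for a defined type. A small self-contained sketch (the type and field names are reused purely for illustration) of what that distinction means in Go:

package main

import "fmt"

type Device struct {
	Name  string
	Value int
}

// A defined type: Cuda has the same fields as Device, but none of Device's
// methods are promoted onto it, and there is no extra Count field.
type Cuda Device

func main() {
	cpu := Device{Name: "CPU", Value: -1}
	cuda := Cuda{Name: "CUDA", Value: 0}

	// Values convert between the two types only explicitly.
	fmt.Println(cpu, Device(cuda))
}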

@@ -5,7 +5,11 @@ package libtch
//#include "stdbool.h"
//#include "torch_api.h"
import "C"
import (
	"unsafe"
)

func Atc_cuda_device_count() int {
	return C.atc_cuda_device_count()
	t := C.atc_cuda_device_count()
	return *(*int)(unsafe.Pointer(&t))
}
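
Aside, not part of the diff: a sketch of the wrapper pattern above using a stand-in C function defined in the cgo preamble (device_count is hypothetical; the real binding calls atc_cuda_device_count). It converts the C.int result with a plain Go conversion, a portable alternative to reinterpreting the value through unsafe.Pointer:

package main

//int device_count() { return 2; }
import "C"

import "fmt"

// DeviceCount wraps the stand-in C call and returns a Go int.
func DeviceCount() int {
	t := C.device_count()
	// C.int and Go int usually differ in size, so convert rather than cast.
	return int(t)
}

func main() {
	fmt.Println(DeviceCount())
}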

@@ -10,17 +10,23 @@ import (

type c_void unsafe.Pointer
type size_t uint
type c_int int32

type C_tensor struct {
	_private uint8
	private uint8
}

func NewTensor() *C_tensor {
	t := C.at_new_tensor()
	return &C_tensor{_private: *(*uint8)(unsafe.Pointer(&t))}
	return &C_tensor{private: *(*uint8)(unsafe.Pointer(&t))}
}

func AtTensorOfData(vs c_void, dims int64, ndims size_t, elt_size_in_bytes size_t, kind c_int) *C_tensor {
	t := C.at_tensor_of_data(vs, dims, ndims, elt_size_in_bytes, kind)
	return &C_tensor{_private: *(*uint8)(unsafe.Pointer(&t))}
func AtTensorOfData(vs unsafe.Pointer, dims int64, ndims uint, elt_size_in_bytes uint, kind int32) *C_tensor {
	c_dims := (*C.long)(unsafe.Pointer(&dims))
	c_ndims := *(*C.ulong)(unsafe.Pointer(&ndims))
	c_elt_size_in_bytes := *(*C.ulong)(unsafe.Pointer(&elt_size_in_bytes))
	c_kind := *(*C.int)(unsafe.Pointer(&kind))

	t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
	return &C_tensor{private: *(*uint8)(unsafe.Pointer(&t))}
}
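
Aside, not part of the diff: the new AtTensorOfData converts its Go arguments to C types by reinterpreting their memory through unsafe.Pointer. A standalone sketch of that cast pattern in isolation; the values are arbitrary, and the pattern assumes each Go type and its C counterpart share the same size and representation on the target platform:

package main

import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	var ndims uint = 1
	// Reinterpret the Go uint as a C unsigned long.
	c_ndims := *(*C.ulong)(unsafe.Pointer(&ndims))

	var dims int64 = 6
	// View the address of a Go int64 as a *C.long.
	c_dims := (*C.long)(unsafe.Pointer(&dims))

	fmt.Println(c_ndims, *c_dims)
}

Where the sizes are known to match, an explicit conversion such as C.ulong(ndims) expresses the same thing without unsafe; the pointer cast simply mirrors the style used in the wrapper above.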