fix(wrapper/error): get full detail string from C; WIP(tensor_generated_sample): cuda
This commit is contained in:
parent
07c7e0ed7d
commit
41c0cfaab2
11
device.go
11
device.go
|
@ -23,6 +23,17 @@ func CudaBuilder(v uint) Device {
|
|||
return Device{Name: "CUDA", Value: int(v)}
|
||||
}
|
||||
|
||||
// NewCuda creates a cuda device (default) if available
|
||||
// If will be panic if cuda is not available.
|
||||
func NewCuda() Device {
|
||||
var d Cuda
|
||||
if !d.IsAvailable() {
|
||||
log.Fatalf("Cuda is not available.")
|
||||
}
|
||||
|
||||
return CudaBuilder(0)
|
||||
}
|
||||
|
||||
// Cuda methods:
|
||||
// =============
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"log"
|
||||
|
||||
// gotch "github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch"
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
|
@ -59,6 +58,7 @@ func main() {
|
|||
dx := [][]int32{
|
||||
{1, 1},
|
||||
{1, 1},
|
||||
{1, 1},
|
||||
}
|
||||
|
||||
dy := [][]int32{
|
||||
|
@ -66,7 +66,7 @@ func main() {
|
|||
{1, 1, 1},
|
||||
}
|
||||
|
||||
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 2})
|
||||
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 3})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -75,6 +75,25 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
xs.Matmul(*ys)
|
||||
xs.Matmul(ys)
|
||||
|
||||
// device := gotch.NewCuda()
|
||||
//
|
||||
// // cy := ys.To(device)
|
||||
// // cx := xs.To(device)
|
||||
// zs := wrapper.NewTensor()
|
||||
// cz := zs.To(device)
|
||||
// fmt.Println(cz.Device().Name)
|
||||
|
||||
// cx.Matmul(cy)
|
||||
|
||||
// for i := 1; i < 1000000; i++ {
|
||||
// for i := 1; i < 2; i++ {
|
||||
// cx := xs.To(device)
|
||||
// cx.Print()
|
||||
// cy := ys.To(device)
|
||||
// cy.Print()
|
||||
//
|
||||
// }
|
||||
|
||||
}
|
||||
|
|
32
libtch/c-generated-sample.go
Normal file
32
libtch/c-generated-sample.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
// NOTE: this is a sample for OCaml generated code for `c-generated.go`
|
||||
package libtch
|
||||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// void atg_eq1(tensor *, tensor self, tensor other);
|
||||
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_eq1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_matmul(tensor *, tensor self, tensor other);
|
||||
func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_matmul(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_to(tensor *, tensor self, int device);
|
||||
func AtgTo(ptr *Ctensor, self Ctensor, device int) {
|
||||
cdevice := *(*C.int)(unsafe.Pointer(&device))
|
||||
C.atg_to(ptr, self, cdevice)
|
||||
}
|
||||
|
||||
// int at_device(tensor);
|
||||
func AtDevice(ts Ctensor) int {
|
||||
cint := C.at_device(ts)
|
||||
return *(*int)(unsafe.Pointer(&cint))
|
||||
}
|
12
libtch/dummy_cuda_dependency.cpp
Normal file
12
libtch/dummy_cuda_dependency.cpp
Normal file
|
@ -0,0 +1,12 @@
|
|||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
int warp_size();
|
||||
}
|
||||
}
|
||||
void dummy_cuda_dependency() {
|
||||
at::cuda::warp_size();
|
||||
}
|
|
@ -1,6 +0,0 @@
|
|||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
void dummy_cuda_dependency() {
|
||||
}
|
|
@ -11,5 +11,5 @@ package libtch
|
|||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc
|
||||
// #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=1
|
||||
// #cgo linux,amd64,!nogpu CFLAGS: -I/usr/local/cuda/include
|
||||
// #cgo linux,amd64,!nogpu LDFLAGS: -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcudnn -lcaffe2_nvrtc -lnvrtc-builtins -lnvrtc -lnvToolsExt -L/opt/libtorch/lib -lc10_cuda
|
||||
// #cgo linux,amd64,!nogpu LDFLAGS: -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcudnn -lcaffe2_nvrtc -lnvrtc-builtins -lnvrtc -lnvToolsExt -L/opt/libtorch/lib -lc10_cuda -ltorch_cuda
|
||||
import "C"
|
||||
|
|
|
@ -1,16 +0,0 @@
|
|||
// NOTE: this is a sample for OCaml generated code for `tensor_generated.go`
|
||||
package libtch
|
||||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
import "C"
|
||||
|
||||
// void atg_eq1(tensor *, tensor self, tensor other);
|
||||
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_eq1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_matmul(tensor *, tensor self, tensor other);
|
||||
func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_matmul(ptr, self, other)
|
||||
}
|
|
@ -21,9 +21,7 @@ func ptrToString(cptr *C.char) string {
|
|||
var str string = ""
|
||||
|
||||
if cptr != nil {
|
||||
// strPtr := (*string)(unsafe.Pointer(cptr))
|
||||
// TODO: get error string from pointer
|
||||
str = fmt.Sprintf("TODO: will show more detail here.")
|
||||
str = C.GoString(cptr)
|
||||
C.free(unsafe.Pointer(cptr))
|
||||
}
|
||||
|
||||
|
@ -43,7 +41,7 @@ func TorchErr() error {
|
|||
cptr := (*C.char)(lib.GetAndResetLastErr())
|
||||
errStr := ptrToString(cptr)
|
||||
if errStr != "" {
|
||||
return fmt.Errorf("Libtorch API Err: %v\n", errStr)
|
||||
return fmt.Errorf("Libtorch API Error: %v\n", errStr)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
61
wrapper/tensor-generated-sample.go
Normal file
61
wrapper/tensor-generated-sample.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
// NOTE: this is a sample for OCaml generated code for `tensor-generated.go`
|
||||
package wrapper
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR} -g -O3
|
||||
// #cgo CFLAGS: -I${SRCDIR} -O3 -Wall -Wno-unused-variable -Wno-deprecated-declarations -Wno-c++11-narrowing -g -Wno-sign-compare -Wno-unused-function
|
||||
// #cgo CFLAGS: -I/usr/local/include -I/opt/libtorch/include -I/opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo LDFLAGS: -lstdc++ -ltorch -lc10 -ltorch_cpu
|
||||
// #cgo LDFLAGS: -L/opt/libtorch/lib -L/lib64
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/lib
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc
|
||||
// #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=1
|
||||
// #cgo linux,amd64,!nogpu CFLAGS: -I/usr/local/cuda/include
|
||||
// #cgo linux,amd64,!nogpu LDFLAGS: -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcudnn -lcaffe2_nvrtc -lnvrtc-builtins -lnvrtc -lnvToolsExt -L/opt/libtorch/lib -lc10_cuda -ltorch_cuda
|
||||
// # include <cuda.h>
|
||||
import "C"
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"log"
|
||||
"unsafe"
|
||||
|
||||
gt "github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
func (ts Tensor) To(device gt.Device) Tensor {
|
||||
|
||||
// TODO: how to get pointer to CUDA memory???
|
||||
// Something like `C.cudaMalloc()`???
|
||||
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
// var cudaPtr unsafe.Pointer
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1)
|
||||
// fmt.Printf("Cuda Pointer: %v\n", &cudaPtr)
|
||||
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.cuMemAlloc(device.CInt(), 0)))
|
||||
|
||||
var ptr unsafe.Pointer
|
||||
// lib.AtgTo(ptr, ts.ctensor, int(device.CInt()))
|
||||
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
|
||||
// lib.AtgTo((*lib.Ctensor)(cudaPtr), ts.ctensor, int(device.CInt()))
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *(*lib.Ctensor)(unsafe.Pointer(&ptr))}
|
||||
// return Tensor{ctensor: (lib.Ctensor)(cudaPtr)}
|
||||
}
|
||||
|
||||
func (ts Tensor) Device() gt.Device {
|
||||
cInt := lib.AtDevice(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var device gt.Device
|
||||
|
||||
return device.OfCInt(int32(cInt))
|
||||
}
|
|
@ -178,45 +178,45 @@ func (ts Tensor) Print() {
|
|||
}
|
||||
|
||||
// NewTensorFromData creates tensor from given data and shape
|
||||
func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err error) {
|
||||
func NewTensorFromData(data interface{}, shape []int64) (retVal Tensor, err error) {
|
||||
// 1. Check whether data and shape match
|
||||
elementNum, err := DataDim(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
nflattend := FlattenDim(shape)
|
||||
|
||||
if elementNum != nflattend {
|
||||
err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend)
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// 2. Write raw data to C memory and get C pointer
|
||||
dataPtr, err := DataAsPtr(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// 3. Create tensor with pointer and shape
|
||||
dtype, err := gotch.DTypeFromData(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
|
||||
retVal = &Tensor{ctensor}
|
||||
retVal = Tensor{ctensor}
|
||||
|
||||
return retVal, nil
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user