diff --git a/example/tensor/main.go b/example/tensor/main.go new file mode 100644 index 0000000..4793889 --- /dev/null +++ b/example/tensor/main.go @@ -0,0 +1,10 @@ +package main + +import ( + t "github.com/sugarme/gotch/torch" +) + +func main() { + + t.NewTensor() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..3f67ae4 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module github.com/sugarme/gotch + +go 1.14 diff --git a/torch/dummy_cuda_dependency.cpp b/torch/dummy_cuda_dependency.cpp deleted file mode 100644 index e5c12e3..0000000 --- a/torch/dummy_cuda_dependency.cpp +++ /dev/null @@ -1,12 +0,0 @@ -extern "C" { - void dummy_cuda_dependency(); -} - -namespace at { - namespace cuda { - int warp_size(); - } -} -void dummy_cuda_dependency() { - at::cuda::warp_size(); -} diff --git a/torch/fake_cuda_dependency.cpp b/torch/fake_cuda_dependency.cpp new file mode 100644 index 0000000..9db4803 --- /dev/null +++ b/torch/fake_cuda_dependency.cpp @@ -0,0 +1,6 @@ +extern "C" { + void dummy_cuda_dependency(); +} + +void dummy_cuda_dependency() { +} diff --git a/torch/lib.go b/torch/lib.go index 063ad43..436caed 100644 --- a/torch/lib.go +++ b/torch/lib.go @@ -1,6 +1,13 @@ package torch -// #cgo CXXFLAGS: -std=c++14 -I${SRCDIR} -O3 -Wall -g -Wno-sign-compare -Wno-unused-function -I/usr/local/include -I/opt/libtorch/include -I/opt/libtorch/include/torch/csrc/api/include -// #cgo LDFLAGS: -L/opt/libtorch/lib -ltorch +// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR} -g -O3 +// #cgo CFLAGS: -I${SRCDIR} -O3 -Wall -Wno-unused-variable -Wno-deprecated-declarations -Wno-c++11-narrowing -g -Wno-sign-compare -Wno-unused-function +// #cgo CFLAGS: -I/usr/local/include -I/opt/libtorch/include -I/opt/libtorch/include/torch/csrc/api/include +// #cgo LDFLAGS: -lstdc++ -ltorch -lc10 -ltorch_cpu +// #cgo LDFLAGS: -L/opt/libtorch/lib -L/lib64 +// #cgo CXXFLAGS: -isystem /opt/libtorch/lib +// #cgo CXXFLAGS: -isystem /opt/libtorch/include +// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc/api/include +// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc // #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=1 import "C" diff --git a/torch/tensor.go b/torch/tensor.go index fd6b7cc..eb55101 100644 --- a/torch/tensor.go +++ b/torch/tensor.go @@ -1,13 +1,20 @@ package torch -//#include +//#include "stdbool.h" //#include "torch_api.h" import "C" +import ( + "fmt" + "reflect" +) + type C_tensor struct { _private []uint8 } -func NewTensor() *C_tensor { - return C.at_new_tensor() +func NewTensor() { + t := C.at_new_tensor() + fmt.Printf("Tensor Type: %v\n", reflect.TypeOf(t).Kind()) + fmt.Printf("Tensor Value: %v\n", t) }