works with libtorch API
This commit is contained in:
parent
c45ca32070
commit
b1c70b1dde
10
example/tensor/main.go
Normal file
10
example/tensor/main.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
t "github.com/sugarme/gotch/torch"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
t.NewTensor()
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
int warp_size();
|
||||
}
|
||||
}
|
||||
void dummy_cuda_dependency() {
|
||||
at::cuda::warp_size();
|
||||
}
|
6
torch/fake_cuda_dependency.cpp
Normal file
6
torch/fake_cuda_dependency.cpp
Normal file
|
@ -0,0 +1,6 @@
|
|||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
void dummy_cuda_dependency() {
|
||||
}
|
11
torch/lib.go
11
torch/lib.go
|
@ -1,6 +1,13 @@
|
|||
package torch
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++14 -I${SRCDIR} -O3 -Wall -g -Wno-sign-compare -Wno-unused-function -I/usr/local/include -I/opt/libtorch/include -I/opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo LDFLAGS: -L/opt/libtorch/lib -ltorch
|
||||
// #cgo CXXFLAGS: -std=c++17 -I${SRCDIR} -g -O3
|
||||
// #cgo CFLAGS: -I${SRCDIR} -O3 -Wall -Wno-unused-variable -Wno-deprecated-declarations -Wno-c++11-narrowing -g -Wno-sign-compare -Wno-unused-function
|
||||
// #cgo CFLAGS: -I/usr/local/include -I/opt/libtorch/include -I/opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo LDFLAGS: -lstdc++ -ltorch -lc10 -ltorch_cpu
|
||||
// #cgo LDFLAGS: -L/opt/libtorch/lib -L/lib64
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/lib
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc/api/include
|
||||
// #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc
|
||||
// #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=1
|
||||
import "C"
|
||||
|
|
|
@ -1,13 +1,20 @@
|
|||
package torch
|
||||
|
||||
//#include <stdbool.h>
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type C_tensor struct {
|
||||
_private []uint8
|
||||
}
|
||||
|
||||
func NewTensor() *C_tensor {
|
||||
return C.at_new_tensor()
|
||||
func NewTensor() {
|
||||
t := C.at_new_tensor()
|
||||
fmt.Printf("Tensor Type: %v\n", reflect.TypeOf(t).Kind())
|
||||
fmt.Printf("Tensor Value: %v\n", t)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user