feat(wrapper/cuda): add some Cuda APIs and README.md
parent 12f5eaa9d7
commit 4c25c43eb8
README.md · 55 lines · new file

@@ -0,0 +1,55 @@
# GOTCH - Libtorch Go Binding

## Overview

- **GoTch** is a Go binding for the C++ Libtorch library, for developing and implementing deep learning projects in Go.

- It is currently under heavy development and should be considered unstable until version v1.0.0 is tagged. Use it at your own risk.

- One goal of this package is to create a thin wrapper around Libtorch that exposes its tensor APIs and CUDA support while staying as close to idiomatic Go as possible.

## Dependencies

- **Libtorch**, the C++ library behind [PyTorch](https://pytorch.org/)

## How to use

### 1. Libtorch installation

- Make sure Libtorch version 1.5.0 (either the CPU or CUDA build) is installed on your system (the default location is `/opt/libtorch` on Linux/macOS).

### 2. Import the **GoTch** package

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
)

func main() {
	var d gotch.Cuda
	fmt.Printf("Cuda device count: %v\n", d.DeviceCount())
	fmt.Printf("Cuda is available: %v\n", d.IsAvailable())
	fmt.Printf("Cudnn is available: %v\n", d.CudnnIsAvailable())
}
```

- Other examples can be found in the `example` folder.

## Acknowledgement

- This project was inspired by, and borrows many concepts from, [tch-rs](https://github.com/LaurentMazare/tch-rs), the Libtorch Rust binding.
device.go · 50 changed lines

@@ -28,35 +28,33 @@ func CudaBuilder(v uint) Device {
 
 // DeviceCount returns the number of GPUs that can be used.
 func (cu Cuda) DeviceCount() int64 {
-	cInt := lib.Atc_cuda_device_count()
+	cInt := lib.AtcCudaDeviceCount()
 	return int64(cInt)
 }
 
-/*
- * // CudnnIsAvailable returns true if cuda support is available
- * func (cu Cuda) IsAvailable() bool {
- *	return lib.Atc_cuda_is_available()
- * }
- *
- * // CudnnIsAvailable return true if cudnn support is available
- * func (cu Cuda) CudnnIsAvailable() bool {
- *	return lib.Atc_cudnn_is_available()
- * }
- *
- * // CudnnSetBenchmark sets cudnn benchmark mode
- * //
- * // When set cudnn will try to optimize the generators during the first network
- * // runs and then use the optimized architecture in the following runs. This can
- * // result in significant performance improvements.
- * func (cu Cuda) CudnnSetBenchmark(b bool) {
- *	switch b {
- *	case true:
- *		lib.Atc_set_benchmark_cudnn(1)
- *	case false:
- *		lib.Act_cuda_benchmark_cudd(0)
- *	}
- * } */
+// IsAvailable returns true if CUDA support is available.
+func (cu Cuda) IsAvailable() bool {
+	return lib.AtcCudaIsAvailable()
+}
+
+// CudnnIsAvailable returns true if cuDNN support is available.
+func (cu Cuda) CudnnIsAvailable() bool {
+	return lib.AtcCudnnIsAvailable()
+}
+
+// CudnnSetBenchmark sets cuDNN benchmark mode.
+//
+// When set, cuDNN will try to optimize the generators during the first network
+// runs and then use the optimized architecture in the following runs. This can
+// result in significant performance improvements.
+func (cu Cuda) CudnnSetBenchmark(b bool) {
+	switch b {
+	case true:
+		lib.AtcSetBenchmarkCudnn(1)
+	case false:
+		lib.AtcSetBenchmarkCudnn(0)
+	}
+}
 
 // Device methods:
 //================
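Taken together with `CudaBuilder` from the hunk context above, these new methods cover basic device introspection. A minimal sketch of how they might be combined (illustrative only; it assumes the import path shown in the README, and that a `Device` value prints meaningfully with `%v`):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
)

func main() {
	var cu gotch.Cuda
	if cu.IsAvailable() && cu.CudnnIsAvailable() {
		// Benchmark mode makes cuDNN profile algorithms on the first
		// runs and reuse the fastest ones; it pays off when input
		// shapes stay fixed across iterations.
		cu.CudnnSetBenchmark(true)
	}

	// CudaBuilder (signature from device.go: func CudaBuilder(v uint) Device)
	// wraps a GPU index in a Device value.
	device := gotch.CudaBuilder(0)
	fmt.Printf("GPUs: %v, selected device: %v\n", cu.DeviceCount(), device)
}
```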
example/cuda/main.go · 16 lines · new file

@@ -0,0 +1,16 @@
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
)

func main() {
	var d gotch.Cuda
	fmt.Printf("Cuda device count: %v\n", d.DeviceCount())
	fmt.Printf("Cuda is available: %v\n", d.IsAvailable())
	fmt.Printf("Cudnn is available: %v\n", d.CudnnIsAvailable())
}
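On a machine with one CUDA-capable GPU and cuDNN installed, this program would print something along these lines (illustrative, not captured output):

```
Cuda device count: 1
Cuda is available: true
Cudnn is available: true
```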
@@ -5,6 +5,7 @@ import (
 	"log"
 
 	// gotch "github.com/sugarme/gotch"
+	"github.com/sugarme/gotch"
 	wrapper "github.com/sugarme/gotch/wrapper"
 )
@@ -10,4 +10,6 @@ package libtch
 // #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc/api/include
 // #cgo CXXFLAGS: -isystem /opt/libtorch/include/torch/csrc
 // #cgo CFLAGS: -D_GLIBCXX_USE_CXX11_ABI=1
+// #cgo linux,amd64,!nogpu CFLAGS: -I/usr/local/cuda/include
+// #cgo linux,amd64,!nogpu LDFLAGS: -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcudnn -lcaffe2_nvrtc -lnvrtc-builtins -lnvrtc -lnvToolsExt -L/opt/libtorch/lib -lc10_cuda
 import "C"
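The `linux,amd64,!nogpu` constraint means both CUDA lines are skipped whenever the `nogpu` build tag is set, so `go build -tags nogpu` can produce a CPU-only binary. For illustration only (this file is not part of the commit), a hypothetical CPU-only stub guarded by the opposite tag might look like the sketch below; the real package may organize its fallbacks differently:

```go
// +build nogpu

package libtch

// Hypothetical CPU-only fallbacks for a nogpu build, mirroring the
// cgo-backed wrappers that compile when the nogpu tag is absent.

// AtcCudaDeviceCount reports zero GPUs when CUDA support is compiled out.
func AtcCudaDeviceCount() int { return 0 }

// AtcCudaIsAvailable always reports false in a nogpu build.
func AtcCudaIsAvailable() bool { return false }

// AtcCudnnIsAvailable always reports false in a nogpu build.
func AtcCudnnIsAvailable() bool { return false }

// AtcSetBenchmarkCudnn is a no-op without cuDNN.
func AtcSetBenchmarkCudnn(b int) {}
```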
@@ -65,3 +65,27 @@ func AtScalarType(t Ctensor) int32 {
 func GetAndResetLastErr() *C.char {
 	return C.get_and_reset_last_err()
 }
+
+// int atc_cuda_device_count();
+func AtcCudaDeviceCount() int {
+	result := C.atc_cuda_device_count()
+	return *(*int)(unsafe.Pointer(&result))
+}
+
+// int atc_cuda_is_available();
+func AtcCudaIsAvailable() bool {
+	result := C.atc_cuda_is_available()
+	return *(*bool)(unsafe.Pointer(&result))
+}
+
+// int atc_cudnn_is_available();
+func AtcCudnnIsAvailable() bool {
+	result := C.atc_cudnn_is_available()
+	return *(*bool)(unsafe.Pointer(&result))
+}
+
+// void atc_set_benchmark_cudnn(int b);
+func AtcSetBenchmarkCudnn(b int) {
+	cb := *(*C.int)(unsafe.Pointer(&b))
+	C.atc_set_benchmark_cudnn(cb)
+}
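A note on the conversions in this hunk: `C.int` is 32 bits while Go's `int` is 64 bits on amd64, so `*(*int)(unsafe.Pointer(&result))` reinterprets more bytes than the value occupies, and the bool casts read only the first byte of the C int. These presumably work in practice because the adjacent bytes happen to be zero, but plain conversions such as `int(C.atc_cuda_device_count())` and `C.atc_cuda_is_available() != 0` express the intent without that assumption. A runnable pure-Go sketch of why the pointer cast is fragile:

```go
package main

import (
	"fmt"
	"unsafe"
)

func main() {
	// Two adjacent 4-byte values; pair[0] stands in for a C.int result.
	pair := [2]int32{3, -1}

	// Reinterpreting the 4-byte value as an 8-byte Go int also pulls in
	// the neighbouring bytes, so the result is not 3 on a 64-bit machine.
	bad := *(*int)(unsafe.Pointer(&pair[0]))

	// A plain conversion widens the value correctly.
	good := int(pair[0])

	fmt.Println(bad, good) // little-endian amd64: -4294967293 3
}
```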