127 lines
2.5 KiB
Go
127 lines
2.5 KiB
Go
package gotch
|
|
|
|
import (
|
|
"log"
|
|
|
|
lib "git.andr3h3nriqu3s.com/andr3/gotch/libtch"
|
|
)
|
|
|
|
// Device identifies the compute device a value lives on.
//
// Within this file the only names produced are "CPU" and "CUDA".
type Device struct {
	// Name is the device kind, e.g. "CPU" or "CUDA".
	Name string
	// Value is the device index: -1 for the package-level CPU device,
	// and the CUDA device index for devices built by CudaBuilder.
	Value int
}
|
|
|
|
// Cuda is a distinct named type sharing Device's underlying struct. It is the
// receiver for the CUDA/cuDNN query methods in this file (DeviceCount,
// IsAvailable, CudnnIsAvailable, CudnnSetBenchmark). Being a separate defined
// type, it does not carry Device's methods.
type Cuda Device
|
|
|
|
var (
|
|
CPU Device = Device{Name: "CPU", Value: -1}
|
|
CUDA Cuda = Cuda{Name: "CUDA", Value: 0}
|
|
)
|
|
|
|
func CudaBuilder(v uint) Device {
|
|
// TODO: fully initiate cuda here
|
|
return Device{Name: "CUDA", Value: int(v)}
|
|
}
|
|
|
|
// NewCuda creates a cuda device (default) if available
|
|
// If will be panic if cuda is not available.
|
|
func NewCuda() Device {
|
|
var d Cuda
|
|
if !d.IsAvailable() {
|
|
log.Fatalf("Cuda is not available.")
|
|
}
|
|
|
|
return CudaBuilder(0)
|
|
}
|
|
|
|
// Cuda methods:
|
|
// =============
|
|
|
|
// DeviceCount returns the number of GPU that can be used.
|
|
func (cu Cuda) DeviceCount() int64 {
|
|
cInt := lib.AtcCudaDeviceCount()
|
|
return int64(cInt)
|
|
}
|
|
|
|
// CudnnIsAvailable returns true if cuda support is available
|
|
func (cu Cuda) IsAvailable() bool {
|
|
return lib.AtcCudaIsAvailable()
|
|
}
|
|
|
|
// CudnnIsAvailable return true if cudnn support is available
|
|
func (cu Cuda) CudnnIsAvailable() bool {
|
|
return lib.AtcCudnnIsAvailable()
|
|
}
|
|
|
|
// CudnnSetBenchmark sets cudnn benchmark mode
|
|
//
|
|
// When set cudnn will try to optimize the generators during the first network
|
|
// runs and then use the optimized architecture in the following runs. This can
|
|
// result in significant performance improvements.
|
|
func (cu Cuda) CudnnSetBenchmark(b bool) {
|
|
switch b {
|
|
case true:
|
|
lib.AtcSetBenchmarkCudnn(1)
|
|
case false:
|
|
lib.AtcSetBenchmarkCudnn(0)
|
|
}
|
|
}
|
|
|
|
// Device methods:
|
|
//================
|
|
|
|
func (d Device) CInt() CInt {
|
|
switch {
|
|
case d.Name == "CPU":
|
|
return -1
|
|
case d.Name == "CUDA":
|
|
// TODO: create a function to retrieve cuda_index
|
|
var deviceIndex int = d.Value
|
|
return CInt(deviceIndex)
|
|
default:
|
|
log.Fatal("Not reachable")
|
|
return 0
|
|
}
|
|
}
|
|
|
|
func (d Device) OfCInt(v CInt) Device {
|
|
switch {
|
|
case v == -1:
|
|
return Device{Name: "CPU", Value: 1}
|
|
case v >= 0:
|
|
return CudaBuilder(uint(v))
|
|
default:
|
|
log.Fatalf("Unexpected device %v", v)
|
|
}
|
|
return Device{}
|
|
}
|
|
|
|
// CudaIfAvailable returns a GPU device if available, else default to CPU
|
|
func (d Device) CudaIfAvailable() Device {
|
|
switch {
|
|
case CUDA.IsAvailable():
|
|
return CudaBuilder(0)
|
|
default:
|
|
return CPU
|
|
}
|
|
}
|
|
|
|
// IsCuda returns whether device is a Cuda device
|
|
func (d Device) IsCuda() bool {
|
|
if d.Name == "CPU" {
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// CudaIfAvailable returns a GPU device if available, else CPU.
|
|
func CudaIfAvailable() Device {
|
|
switch {
|
|
case CUDA.IsAvailable():
|
|
return CudaBuilder(0)
|
|
default:
|
|
return CPU
|
|
}
|
|
}
|