added get/set cuda device for qt testing
This commit is contained in:
parent
ea87e7fa38
commit
dbbb4b97c1
|
@ -134,6 +134,24 @@ func AtcSetBenchmarkCudnn(b int) {
|
|||
C.atc_set_benchmark_cudnn(cb)
|
||||
}
|
||||
|
||||
// void atc_synchronize(int64_t device_index);
|
||||
func AtcSynchronize(deviceIndex int64) {
|
||||
cDeviceIndex := *(*C.int64_t)(unsafe.Pointer(&deviceIndex))
|
||||
C.atc_synchronize(cDeviceIndex)
|
||||
}
|
||||
|
||||
// int atc_get_device();
|
||||
func AtcGetDevice() int {
|
||||
cDeviceIndex := C.atc_get_device()
|
||||
return int(cDeviceIndex)
|
||||
}
|
||||
|
||||
// int atc_set_device(int device_index);
|
||||
func AtcSetDevice(deviceIndex int) int {
|
||||
cDeviceIndex := C.int(deviceIndex)
|
||||
return int(cDeviceIndex)
|
||||
}
|
||||
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
|
||||
ctensor := (C.tensor)(ts)
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include<torch/script.h>
|
||||
#include<stdexcept>
|
||||
#include<vector>
|
||||
#include "ATen/core/interned_strings.h"
|
||||
#include "torch_api.h"
|
||||
|
||||
|
||||
|
@ -45,9 +46,10 @@ c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
|
|||
}
|
||||
|
||||
// Maps an integer device code to an at::Device:
//   -3          -> Vulkan
//   other < 0   -> CPU
//   >= 0        -> CUDA device with that index
at::Device device_of_int(int d) {
  if (d == -3) return at::Device(at::kVulkan);
  // if (d == -2) return at::Device(at::kMPS);
  if (d < 0) return at::Device(at::kCPU);
  return at::Device(at::kCUDA, /*index=*/d);
}
|
||||
tensor at_new_tensor() {
|
||||
PROTECT(return new torch::Tensor();)
|
||||
|
@ -176,7 +178,7 @@ bool at_autocast_set_enabled(bool b) {
|
|||
// Returns -1 for a CPU tensor, the CUDA device index for a CUDA tensor,
// and -99 when the device cannot be determined (error sentinel).
int at_device(tensor t) {
  PROTECT(auto device = t->device(); if (device.type() == at::kCPU) return -1;
          if (device.type() == at::kCUDA) return device.index();)
  // Removed the unreachable duplicate `return -2;` left over from the old
  // version; -99 is the single error sentinel.
  return -99; // error
}
|
||||
|
||||
void at_backward(tensor t, int keep_graph, int create_graph) {
|
||||
|
@ -753,6 +755,31 @@ void atc_set_benchmark_cudnn(int b) {
|
|||
at::globalContext().setBenchmarkCuDNN(b);
|
||||
}
|
||||
|
||||
// Blocks until every kernel on all streams of the given CUDA device has
// finished executing.
void atc_synchronize(int64_t device_index) {
  PROTECT(torch::cuda::synchronize(device_index);)
}
|
||||
|
||||
// returns current CUDA device index.
// On any exception caught by PROTECT, returns -99 as an error sentinel.
int atc_get_device(){
  PROTECT(
    // Query through the device-guard implementation registered for the
    // CUDA backend rather than touching the CUDA runtime directly.
    at::Device d(at::kCUDA);
    auto *g = c10::impl::getDeviceGuardImpl(d.type());
    d = g->getDevice();
    return d.index();
  )
  return -99; // error
}
|
||||
|
||||
// set new cuda device with input device index.
// NOTE(review): device_of_int maps negative indices to a CPU (or Vulkan)
// device, so the guard lookup below would then be for a non-CUDA backend —
// confirm callers only pass indices >= 0.
void atc_set_device(int device_index){
  PROTECT(
    at::Device new_device(at::kCUDA);
    new_device = device_of_int(device_index);
    // Set the active device via the backend's device-guard implementation.
    auto *g = c10::impl::getDeviceGuardImpl(new_device.type());
    g->setDevice(new_device);
  )
}
|
||||
|
||||
module atm_load(char *filename) {
|
||||
PROTECT(return new torch::jit::script::Module(torch::jit::load(filename));)
|
||||
return nullptr;
|
||||
|
|
|
@ -155,6 +155,12 @@ int atc_cuda_device_count();
|
|||
int atc_cuda_is_available();
|
||||
int atc_cudnn_is_available();
|
||||
void atc_set_benchmark_cudnn(int b);
|
||||
void atc_synchronize(int64_t device_index);
|
||||
|
||||
// TT. added for testing qt
|
||||
// ref. https://github.com/pytorch/pytorch/issues/14959
|
||||
int atc_get_device();
|
||||
void atc_set_device(int device_index);
|
||||
|
||||
module atm_load(char *);
|
||||
module atm_load_on_device(char *, int device);
|
||||
|
|
48
ts/tensor.go
48
ts/tensor.go
|
@ -1379,3 +1379,51 @@ func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool)
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// TT. Added some torch.cuda APIs for handling CUDA qt
|
||||
|
||||
// CudaCurrentDevice get device index of current CUDA device.
|
||||
func CudaCurrentDevice() (int, error) {
|
||||
currentDeviceIndex := lib.AtcGetDevice()
|
||||
if err := TorchErr(); err != nil {
|
||||
err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
|
||||
return -99, err
|
||||
}
|
||||
|
||||
return currentDeviceIndex, nil
|
||||
}
|
||||
|
||||
// CudaSetDevice set new cuda device index and returns previous cuda index.
|
||||
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
|
||||
currentDeviceIndex, err := CudaCurrentDevice()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
|
||||
return -99, err
|
||||
}
|
||||
|
||||
lib.AtcSetDevice(cudaDeviceIndex)
|
||||
if err := TorchErr(); err != nil {
|
||||
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
|
||||
return -99, err
|
||||
}
|
||||
return currentDeviceIndex, nil
|
||||
}
|
||||
|
||||
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
|
||||
func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
|
||||
var cudaDeviceIndex int
|
||||
var err error
|
||||
if len(cudaDeviceIndexOpt) > 0 {
|
||||
cudaDeviceIndex = cudaDeviceIndexOpt[0]
|
||||
} else {
|
||||
cudaDeviceIndex, err = CudaCurrentDevice()
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
lib.AtcSynchronize(int64(cudaDeviceIndex))
|
||||
|
||||
return TorchErr()
|
||||
}
|
||||
|
|
|
@ -131,3 +131,28 @@ func TestOfSlice(t *testing.T) {
|
|||
t.Errorf("Got dtype: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCudaCurrentDevice(t *testing.T) {
|
||||
cudaIdx, err := ts.CudaCurrentDevice()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
t.Logf("current CUDA index: %v\n", cudaIdx) // should be 0 if having 1 GPU device
|
||||
|
||||
x := ts.MustZeros([]int64{2, 3, 4}, gotch.Float, gotch.CudaIfAvailable())
|
||||
currentCudaIndex := x.MustDevice().Value
|
||||
t.Logf("x current cuda index: %v\n", currentCudaIndex) // 0
|
||||
|
||||
previousCudaIndex, err := ts.CudaSetDevice(currentCudaIndex)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Logf("Cuda index BEFORE set: %v\n", previousCudaIndex) // 0
|
||||
|
||||
cudaIdxAfter, err := ts.CudaCurrentDevice()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user