From dbbb4b97c19e72806ba6e9eac68e9d5ea5f3b1d2 Mon Sep 17 00:00:00 2001
From: sugarme
Date: Tue, 11 Jul 2023 15:35:36 +1000
Subject: [PATCH] added get/set cuda device for qt testing

---
 libtch/tensor.go     | 18 ++++++++++++++++++
 libtch/torch_api.cpp | 35 +++++++++++++++++++++++++++++++++++++++----
 libtch/torch_api.h   |  6 ++++++
 ts/tensor.go         | 48 ++++++++++++++++++++++++++++++++++++++++++++
 ts/tensor_test.go    | 25 +++++++++++++++++++++++++
 5 files changed, 128 insertions(+), 4 deletions(-)

diff --git a/libtch/tensor.go b/libtch/tensor.go
index adac1a7..6143a0b 100644
--- a/libtch/tensor.go
+++ b/libtch/tensor.go
@@ -134,6 +134,24 @@ func AtcSetBenchmarkCudnn(b int) {
 	C.atc_set_benchmark_cudnn(cb)
 }
 
+// void atc_synchronize(int64_t device_index);
+func AtcSynchronize(deviceIndex int64) {
+	cDeviceIndex := *(*C.int64_t)(unsafe.Pointer(&deviceIndex))
+	C.atc_synchronize(cDeviceIndex)
+}
+
+// int atc_get_device();
+func AtcGetDevice() int {
+	cDeviceIndex := C.atc_get_device()
+	return int(cDeviceIndex)
+}
+
+// void atc_set_device(int device_index);
+func AtcSetDevice(deviceIndex int) {
+	cDeviceIndex := C.int(deviceIndex)
+	C.atc_set_device(cDeviceIndex)
+}
+
 // double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
 func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
 	ctensor := (C.tensor)(ts)
diff --git a/libtch/torch_api.cpp b/libtch/torch_api.cpp
index 17c536b..82c52fc 100644
--- a/libtch/torch_api.cpp
+++ b/libtch/torch_api.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include "ATen/core/interned_strings.h"
 
 #include "torch_api.h"
 
@@ -45,9 +46,10 @@ c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
 
 at::Device device_of_int(int d) {
-  if (d < 0)
-    return at::Device(at::kCPU);
-  return at::Device(at::kCUDA, /*index=*/d);
+  if (d == -3) return at::Device(at::kVulkan);
+  // if (d == -2) return at::Device(at::kMPS);
+  if (d < 0) return at::Device(at::kCPU);
+  return at::Device(at::kCUDA, /*index=*/d);
 }
 
 tensor at_new_tensor() {
   PROTECT(return new torch::Tensor();)
@@ -176,7 +178,7 @@ bool at_autocast_set_enabled(bool b) {
 int at_device(tensor t) {
   PROTECT(auto device = t->device(); if (device.type() == at::kCPU) return -1;
           if (device.type() == at::kCUDA) return device.index();)
-  return -2;
+  return -99; // error
 }
 
 void at_backward(tensor t, int keep_graph, int create_graph) {
@@ -753,6 +755,31 @@ void atc_set_benchmark_cudnn(int b) {
   at::globalContext().setBenchmarkCuDNN(b);
 }
 
+void atc_synchronize(int64_t device_index) {
+  PROTECT(return torch::cuda::synchronize(device_index);)
+}
+
+// returns the current CUDA device index.
+int atc_get_device() {
+  PROTECT(
+    at::Device d(at::kCUDA);
+    auto *g = c10::impl::getDeviceGuardImpl(d.type());
+    d = g->getDevice();
+    return d.index();
+  )
+  return -99; // error
+}
+
+// sets the current CUDA device to the given device index.
+void atc_set_device(int device_index) {
+  PROTECT(
+    at::Device new_device(at::kCUDA);
+    new_device = device_of_int(device_index);
+    auto *g = c10::impl::getDeviceGuardImpl(new_device.type());
+    g->setDevice(new_device);
+  )
+}
+
 module atm_load(char *filename) {
   PROTECT(return new torch::jit::script::Module(torch::jit::load(filename));)
   return nullptr;
diff --git a/libtch/torch_api.h b/libtch/torch_api.h
index 2902b1f..587e869 100644
--- a/libtch/torch_api.h
+++ b/libtch/torch_api.h
@@ -155,6 +155,12 @@ int atc_cuda_device_count();
 int atc_cuda_is_available();
 int atc_cudnn_is_available();
 void atc_set_benchmark_cudnn(int b);
+void atc_synchronize(int64_t device_index);
+
+// TT. added for testing qt
+// ref. https://github.com/pytorch/pytorch/issues/14959
+int atc_get_device();
+void atc_set_device(int device_index);
 
 module atm_load(char *);
 module atm_load_on_device(char *, int device);
diff --git a/ts/tensor.go b/ts/tensor.go
index 3744177..c97f4cb 100644
--- a/ts/tensor.go
+++ b/ts/tensor.go
@@ -1379,3 +1379,51 @@ func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool)
 
 	return retVal
 }
+
+// TT. Added some torch.cuda APIs for handling CUDA qt
+
+// CudaCurrentDevice returns the device index of the current CUDA device.
+func CudaCurrentDevice() (int, error) {
+	currentDeviceIndex := lib.AtcGetDevice()
+	if err := TorchErr(); err != nil {
+		err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
+		return -99, err
+	}
+
+	return currentDeviceIndex, nil
+}
+
+// CudaSetDevice sets the current CUDA device to the given index and returns the previous index.
+func CudaSetDevice(cudaDeviceIndex int) (int, error) {
+	currentDeviceIndex, err := CudaCurrentDevice()
+	if err != nil {
+		err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
+		return -99, err
+	}
+
+	lib.AtcSetDevice(cudaDeviceIndex)
+	if err := TorchErr(); err != nil {
+		err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
+		return -99, err
+	}
+	return currentDeviceIndex, nil
+}
+
+// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
+func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
+	var cudaDeviceIndex int
+	var err error
+	if len(cudaDeviceIndexOpt) > 0 {
+		cudaDeviceIndex = cudaDeviceIndexOpt[0]
+	} else {
+		cudaDeviceIndex, err = CudaCurrentDevice()
+		if err != nil {
+			err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
+			return err
+		}
+	}
+
+	lib.AtcSynchronize(int64(cudaDeviceIndex))
+
+	return TorchErr()
+}
diff --git a/ts/tensor_test.go b/ts/tensor_test.go
index 57534b6..846a99e 100644
--- a/ts/tensor_test.go
+++ b/ts/tensor_test.go
@@ -131,3 +131,28 @@ func TestOfSlice(t *testing.T) {
 		t.Errorf("Got dtype: %v\n", got)
 	}
 }
+
+func TestCudaCurrentDevice(t *testing.T) {
+	cudaIdx, err := ts.CudaCurrentDevice()
+	if err != nil {
+		t.Error(err)
+	}
+
+	t.Logf("current CUDA index: %v\n", cudaIdx) // should be 0 on a machine with a single GPU
+
+	x := ts.MustZeros([]int64{2, 3, 4}, gotch.Float, gotch.CudaIfAvailable())
+	currentCudaIndex := x.MustDevice().Value
+	t.Logf("x current cuda index: %v\n", currentCudaIndex) // 0
+
+	previousCudaIndex, err := ts.CudaSetDevice(currentCudaIndex)
+	if err != nil {
+		t.Error(err)
+	}
+	t.Logf("Cuda index BEFORE set: %v\n", previousCudaIndex) // 0
+
+	cudaIdxAfter, err := ts.CudaCurrentDevice()
+	if err != nil {
+		t.Error(err)
+	}
+	t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
+}
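
A minimal usage sketch of the ts-level helpers introduced by this patch (not part of the patch itself). It assumes the usual github.com/sugarme/gotch import paths, a CUDA-enabled libtorch build, and at least one visible GPU; error handling is collapsed to log.Fatal for brevity.

package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	// Index of the CUDA device the device guard currently points at.
	idx, err := ts.CudaCurrentDevice()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("current CUDA device: %d\n", idx)

	// Allocate a tensor on the GPU (falls back to CPU if CUDA is unavailable).
	x := ts.MustZeros([]int64{2, 3}, gotch.Float, gotch.CudaIfAvailable())
	defer x.MustDrop()

	// Switch to the device the tensor lives on; the previous index is returned.
	prev, err := ts.CudaSetDevice(x.MustDevice().Value)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("previous CUDA device: %d\n", prev)

	// Block until all queued kernels on the current device have finished.
	if err := ts.CudaSynchronize(); err != nil {
		log.Fatal(err)
	}
}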