added get/set cuda device for qt testing

This commit is contained in:
sugarme 2023-07-11 15:35:36 +10:00
parent ea87e7fa38
commit dbbb4b97c1
5 changed files with 128 additions and 4 deletions

View File

@ -134,6 +134,24 @@ func AtcSetBenchmarkCudnn(b int) {
C.atc_set_benchmark_cudnn(cb)
}
// void atc_synchronize(int64_t device_index);
//
// AtcSynchronize blocks until all queued work on the CUDA device with the
// given index has completed (thin wrapper over the C shim).
func AtcSynchronize(deviceIndex int64) {
	// Reinterpret the Go int64 in place as a C.int64_t for the cgo call.
	cDeviceIndex := *(*C.int64_t)(unsafe.Pointer(&deviceIndex))
	C.atc_synchronize(cDeviceIndex)
}
// int atc_get_device();
//
// AtcGetDevice returns the index of the current CUDA device as reported by
// the C layer. The C side returns -99 when an exception was caught, so a
// negative value signals an error condition rather than a device index.
func AtcGetDevice() int {
	cDeviceIndex := C.atc_get_device()
	return int(cDeviceIndex)
}
// void atc_set_device(int device_index);
//
// AtcSetDevice makes the CUDA device with the given index the current
// device. It returns the requested index for caller convenience (the C
// function itself returns void; errors are surfaced via TorchErr).
func AtcSetDevice(deviceIndex int) int {
	cDeviceIndex := C.int(deviceIndex)
	// BUG FIX: the original converted the index but never invoked the C
	// function, so the current device was never actually changed.
	C.atc_set_device(cDeviceIndex)
	return int(cDeviceIndex)
}
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
ctensor := (C.tensor)(ts)

View File

@ -7,6 +7,7 @@
#include<torch/script.h>
#include<stdexcept>
#include<vector>
#include "ATen/core/interned_strings.h"
#include "torch_api.h"
@ -45,9 +46,10 @@ c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
}
// Maps an integer device code to an at::Device:
//   d == -3       -> Vulkan
//   d <  0 (else) -> CPU
//   d >= 0        -> CUDA device with index d
at::Device device_of_int(int d) {
  if (d == -3) return at::Device(at::kVulkan);
  // if (d == -2) return at::Device(at::kMPS);
  // BUG FIX: stale pre-edit lines (`if (d < 0) return CPU; return CUDA`)
  // preceded these checks, unconditionally returning before the Vulkan
  // branch could ever run. Keep only the intended post-edit logic.
  if (d < 0) return at::Device(at::kCPU);
  return at::Device(at::kCUDA, /*index=*/d);
}
tensor at_new_tensor() {
PROTECT(return new torch::Tensor();)
@ -176,7 +178,7 @@ bool at_autocast_set_enabled(bool b) {
// Returns -1 for a CPU tensor, the CUDA device index for a CUDA tensor,
// and -99 on error or for any other device type (falls through PROTECT).
int at_device(tensor t) {
  PROTECT(auto device = t->device(); if (device.type() == at::kCPU) return -1;
          if (device.type() == at::kCUDA) return device.index();)
  // BUG FIX: a stale `return -2;` preceded this line, making the intended
  // -99 error sentinel unreachable. Keep only the -99 error return.
  return -99; // error
}
void at_backward(tensor t, int keep_graph, int create_graph) {
@ -753,6 +755,31 @@ void atc_set_benchmark_cudnn(int b) {
at::globalContext().setBenchmarkCuDNN(b);
}
// Blocks until all queued kernels on the given CUDA device have finished.
// NOTE(review): `return` on a void expression inside a void function is
// legal C++ but reads oddly; assumed PROTECT catches and records torch
// exceptions rather than propagating them — confirm against the macro.
void atc_synchronize(int64_t device_index) {
  PROTECT(return torch::cuda::synchronize(device_index);)
}
// returns current CUDA device index.
int atc_get_device(){
  PROTECT(
    // Ask the CUDA device-guard implementation which device is active.
    at::Device d(at::kCUDA);
    auto *g = c10::impl::getDeviceGuardImpl(d.type());
    d = g->getDevice();
    return d.index();
  )
  // Only reached when PROTECT swallowed an exception above.
  return -99; // error
}
// set new cuda device with input device index.
void atc_set_device(int device_index){
PROTECT(
at::Device new_device(at::kCUDA);
new_device = device_of_int(device_index);
auto *g = c10::impl::getDeviceGuardImpl(new_device.type());
g->setDevice(new_device);
)
}
module atm_load(char *filename) {
PROTECT(return new torch::jit::script::Module(torch::jit::load(filename));)
return nullptr;

View File

@ -155,6 +155,12 @@ int atc_cuda_device_count();
int atc_cuda_is_available();
int atc_cudnn_is_available();
void atc_set_benchmark_cudnn(int b);
void atc_synchronize(int64_t device_index);
// TT. Added for Qt integration testing: query/switch the active CUDA device.
// ref. https://github.com/pytorch/pytorch/issues/14959
int atc_get_device();
void atc_set_device(int device_index);
module atm_load(char *);
module atm_load_on_device(char *, int device);

View File

@ -1379,3 +1379,51 @@ func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool)
return retVal
}
// TT. Added some torch.cuda APIs for handling CUDA qt

// CudaCurrentDevice returns the device index of the current CUDA device,
// or -99 and a non-nil error if the underlying call failed.
func CudaCurrentDevice() (int, error) {
	currentDeviceIndex := lib.AtcGetDevice()
	if err := TorchErr(); err != nil {
		// FIX: error strings must not end with "\n" (Go convention /
		// staticcheck ST1005) — the trailing newline garbles wrapped chains.
		return -99, fmt.Errorf("ts.CudaCurrentDevice() failed: %w", err)
	}
	return currentDeviceIndex, nil
}
// CudaSetDevice set new cuda device index and returns previous cuda index.
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
	// Capture the active device first so we can report it to the caller.
	previousDeviceIndex, err := CudaCurrentDevice()
	if err != nil {
		// FIX: dropped the trailing "\n" from the wrapped error message.
		return -99, fmt.Errorf("ts.CudaSetDevice() failed: %w", err)
	}
	lib.AtcSetDevice(cudaDeviceIndex)
	if err := TorchErr(); err != nil {
		return -99, fmt.Errorf("ts.CudaSetDevice() failed: %w", err)
	}
	return previousDeviceIndex, nil
}
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
// With no argument it synchronizes the current CUDA device; an optional
// device index may be passed to target a specific device.
func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
	var cudaDeviceIndex int
	if len(cudaDeviceIndexOpt) > 0 {
		cudaDeviceIndex = cudaDeviceIndexOpt[0]
	} else {
		// FIX: the original declared `var err error` at function scope and
		// then shadowed it with `err :=` here; scope it to this branch and
		// drop the trailing "\n" from the wrapped error message.
		idx, err := CudaCurrentDevice()
		if err != nil {
			return fmt.Errorf("ts.CudaSynchronize() failed: %w", err)
		}
		cudaDeviceIndex = idx
	}
	lib.AtcSynchronize(int64(cudaDeviceIndex))
	return TorchErr()
}

View File

@ -131,3 +131,28 @@ func TestOfSlice(t *testing.T) {
t.Errorf("Got dtype: %v\n", got)
}
}
// TestCudaCurrentDevice smoke-tests the get/set CUDA device round trip:
// read the current device, allocate a tensor on CUDA-if-available, set the
// device back to the tensor's index, and read again.
// NOTE(review): assumes MustDevice().Value is an int device index usable as
// the argument to CudaSetDevice — confirm against the gotch Device type.
func TestCudaCurrentDevice(t *testing.T) {
	cudaIdx, err := ts.CudaCurrentDevice()
	if err != nil {
		t.Error(err)
	}
	t.Logf("current CUDA index: %v\n", cudaIdx) // should be 0 if having 1 GPU device

	x := ts.MustZeros([]int64{2, 3, 4}, gotch.Float, gotch.CudaIfAvailable())
	currentCudaIndex := x.MustDevice().Value
	t.Logf("x current cuda index: %v\n", currentCudaIndex) // 0

	// CudaSetDevice returns the index that was active BEFORE the call.
	previousCudaIndex, err := ts.CudaSetDevice(currentCudaIndex)
	if err != nil {
		t.Error(err)
	}
	t.Logf("Cuda index BEFORE set: %v\n", previousCudaIndex) // 0

	cudaIdxAfter, err := ts.CudaCurrentDevice()
	if err != nil {
		t.Error(err)
	}
	t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
}