From 1fc13ec55e9664309a1defeda865575b3bcb12b6 Mon Sep 17 00:00:00 2001
From: sugarme
Date: Sun, 23 Jul 2023 14:08:04 +1000
Subject: [PATCH] added tensor IsContiguous and IsMkldnn APIs

---
 CHANGELOG.md         |  2 +-
 libtch/tensor.go     | 48 +++++++++++++++++++++++++++-----------------
 libtch/torch_api.cpp |  6 ++++++
 libtch/torch_api.h   |  1 +
 ts/tensor.go         | 46 ++++++++++++++++++++++++++++++++++++++++++
 ts/tensor_test.go    | 22 ++++++++++++++++++++
 6 files changed, 106 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ffb2f9b..acd47c3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - reworked `ts.Format()`
 - Added conv2d benchmark
 - Fixed #88 memory leak at `example/char-rnn`
-- Added missing tensor `Stride()` and `MustDataPtr()`
+- Added missing tensor APIs: `Stride()`, `MustDataPtr()`, `IsMkldnn()`, `MustIsMkldnn()`, `IsContiguous()`, `MustIsContiguous()`
 
 ## [Nofix]
 - ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
diff --git a/libtch/tensor.go b/libtch/tensor.go
index 8233533..1213d9b 100644
--- a/libtch/tensor.go
+++ b/libtch/tensor.go
@@ -59,12 +59,6 @@ func NewTensor() Ctensor {
 	return C.at_new_tensor()
 }
 
-// int at_device(tensor);
-func AtDevice(ts Ctensor) int {
-	cint := C.at_device(ts)
-	return *(*int)(unsafe.Pointer(&cint))
-}
-
 // tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
 func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
 
@@ -88,6 +82,30 @@ func AtDataPtr(t Ctensor) unsafe.Pointer {
 	return C.at_data_ptr(t)
 }
 
+// int at_defined(tensor);
+func AtDefined(ts Ctensor) bool {
+	retVal := C.at_defined(ts)
+	return *(*bool)(unsafe.Pointer(&retVal))
+}
+
+// int at_is_mkldnn(tensor);
+func AtIsMkldnn(ts Ctensor) bool {
+	retVal := C.at_is_mkldnn(ts)
+	return *(*bool)(unsafe.Pointer(&retVal))
+}
+
+// int at_is_sparse(tensor);
+func AtIsSparse(ts Ctensor) bool {
+	retVal := C.at_is_sparse(ts)
+	return *(*bool)(unsafe.Pointer(&retVal))
+}
+
+// int at_device(tensor);
+func AtDevice(ts Ctensor) int {
+	cint := C.at_device(ts)
+	return *(*int)(unsafe.Pointer(&cint))
+}
+
 // size_t at_dim(tensor);
 func AtDim(t Ctensor) uint64 {
 	result := C.at_dim(t)
@@ -112,6 +130,12 @@ func AtScalarType(t Ctensor) int32 {
 	return *(*int32)(unsafe.Pointer(&result))
 }
 
+// int at_is_contiguous(tensor);
+func AtIsContiguous(ts Ctensor) bool {
+	retVal := C.at_is_contiguous(ts)
+	return *(*bool)(unsafe.Pointer(&retVal))
+}
+
 func GetAndResetLastErr() *C.char {
 	return C.get_and_reset_last_err()
 }
@@ -182,18 +206,6 @@ func AtRequiresGrad(ts Ctensor) bool {
 	return *(*bool)(unsafe.Pointer(&retVal))
 }
 
-// int at_defined(tensor);
-func AtDefined(ts Ctensor) bool {
-	retVal := C.at_defined(ts)
-	return *(*bool)(unsafe.Pointer(&retVal))
-}
-
-// int at_is_sparse(tensor);
-func AtIsSparse(ts Ctensor) bool {
-	retVal := C.at_is_sparse(ts)
-	return *(*bool)(unsafe.Pointer(&retVal))
-}
-
 // void at_backward(tensor, int, int);
 func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
 	ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
diff --git a/libtch/torch_api.cpp b/libtch/torch_api.cpp
index 82c52fc..8da827d 100644
--- a/libtch/torch_api.cpp
+++ b/libtch/torch_api.cpp
@@ -141,6 +141,12 @@ int at_scalar_type(tensor t) {
   return -1;
 }
 
+int at_is_contiguous(tensor t) {
+  PROTECT(return t->is_contiguous();)
+  return -1;
+}
+
+
 // void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor
 // inf_scale) { PROTECT( at::_amp_non_finite_check_and_unscale_(*t, *found_inf,
 // *inf_scale);
diff --git a/libtch/torch_api.h b/libtch/torch_api.h
index 587e869..8724642 100644
--- a/libtch/torch_api.h
+++ b/libtch/torch_api.h
@@ -46,6 +46,7 @@ size_t at_dim(tensor);
 void at_shape(tensor, int64_t *);
 void at_stride(tensor, int64_t *);
 int at_scalar_type(tensor);
+int at_is_contiguous(tensor);
 
 void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);
 
diff --git a/ts/tensor.go b/ts/tensor.go
index 208f1d1..14d8269 100644
--- a/ts/tensor.go
+++ b/ts/tensor.go
@@ -713,6 +713,52 @@ func (ts *Tensor) IsSparse() (bool, error) {
 	return state, nil
 }
 
+func (ts *Tensor) MustIsSparse() bool {
+	state, err := ts.IsSparse()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return state
+}
+
+// IsContiguous returns true if the tensor is contiguous in memory.
+func (ts *Tensor) IsContiguous() (bool, error) {
+	state := lib.AtIsContiguous(ts.ctensor)
+
+	if err := TorchErr(); err != nil {
+		return false, err
+	}
+
+	return state, nil
+}
+func (ts *Tensor) MustIsContiguous() bool {
+	state, err := ts.IsContiguous()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return state
+}
+
+// IsMkldnn returns true if the tensor uses the MKL-DNN (oneDNN) layout.
+func (ts *Tensor) IsMkldnn() (bool, error) {
+	state := lib.AtIsMkldnn(ts.ctensor)
+
+	if err := TorchErr(); err != nil {
+		return false, err
+	}
+
+	return state, nil
+}
+func (ts *Tensor) MustIsMkldnn() bool {
+	state, err := ts.IsMkldnn()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return state
+}
 
 // ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
 func (ts *Tensor) ZeroGrad() {
diff --git a/ts/tensor_test.go b/ts/tensor_test.go
index 3f89821..f82440b 100644
--- a/ts/tensor_test.go
+++ b/ts/tensor_test.go
@@ -167,3 +167,25 @@ func TestTensor_Stride(t *testing.T) {
 		t.Errorf("want %v, got %v\n", want, got)
 	}
 }
+
+func TestTensor_IsContiguous(t *testing.T) {
+	shape := []int64{2, 3, 4}
+	x := ts.MustRand(shape, gotch.Float, gotch.CPU)
+
+	got := x.MustIsContiguous()
+	want := true
+	if !reflect.DeepEqual(want, got) {
+		t.Errorf("want %v, got %v\n", want, got)
+	}
+}
+
+func TestTensor_IsMkldnn(t *testing.T) {
+	shape := []int64{2, 3, 4}
+	x := ts.MustRand(shape, gotch.Float, gotch.CPU)
+
+	got := x.MustIsMkldnn()
+	want := false
+	if !reflect.DeepEqual(want, got) {
+		t.Errorf("want %v, got %v\n", want, got)
+	}
+}
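
For reference, a minimal usage sketch of the new APIs, not part of the patch. It assumes the usual gotch import paths (github.com/sugarme/gotch and github.com/sugarme/gotch/ts) and the generated MustT transpose helper; a transposed view shares storage with its source, so it reports non-contiguous, while a freshly allocated dense CPU tensor is contiguous and does not use the MKL-DNN layout:

	package main

	import (
		"fmt"

		"github.com/sugarme/gotch"
		"github.com/sugarme/gotch/ts"
	)

	func main() {
		// Freshly allocated dense tensors are contiguous.
		x := ts.MustRand([]int64{2, 3}, gotch.Float, gotch.CPU)
		fmt.Println(x.MustIsContiguous()) // true

		// Transposing swaps strides without copying data, so the
		// resulting view is no longer contiguous.
		y := x.MustT(false)
		fmt.Println(y.MustIsContiguous()) // false

		// A plain dense CPU tensor does not use the MKL-DNN (oneDNN) layout.
		fmt.Println(x.MustIsMkldnn()) // false
	}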