added tensor IsContiguous and IsMkldnn APIs

This commit is contained in:
sugarme 2023-07-23 14:08:04 +10:00
parent 6aa57478f5
commit 1fc13ec55e
6 changed files with 106 additions and 19 deletions
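For orientation, a minimal usage sketch of the two new high-level checks. This program is illustrative only and not part of the commit; the import paths and the `MustDrop()` cleanup are assumptions based on the aliases used in the repository's test files.

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	// A freshly allocated dense CPU tensor is expected to be contiguous
	// and not backed by the mkldnn (MKL-DNN/oneDNN) layout.
	x := ts.MustRand([]int64{2, 3, 4}, gotch.Float, gotch.CPU)
	defer x.MustDrop()

	fmt.Println(x.MustIsContiguous()) // true
	fmt.Println(x.MustIsMkldnn())     // false
}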

View File

@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- reworked `ts.Format()`
- Added conv2d benchmark
- Fixed #88 memory leak at `example/char-rnn`
- Added missing tensor `Stride()` and `MustDataPtr()`
- Added missing tensor methods `Stride()`, `MustDataPtr()`, `IsMkldnn()`, `MustIsMkldnn()`, `IsContiguous()`, and `MustIsContiguous()`
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@@ -59,12 +59,6 @@ func NewTensor() Ctensor {
	return C.at_new_tensor()
}
// int at_device(tensor);
func AtDevice(ts Ctensor) int {
	cint := C.at_device(ts)
	return *(*int)(unsafe.Pointer(&cint))
}
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
@@ -88,6 +82,30 @@ func AtDataPtr(t Ctensor) unsafe.Pointer {
	return C.at_data_ptr(t)
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
	retVal := C.at_defined(ts)
	return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_mkldnn(tensor);
func AtIsMkldnn(ts Ctensor) bool {
	retVal := C.at_is_mkldnn(ts)
	return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
	retVal := C.at_is_sparse(ts)
	return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_device(tensor);
func AtDevice(ts Ctensor) int {
	cint := C.at_device(ts)
	return *(*int)(unsafe.Pointer(&cint))
}
// size_t at_dim(tensor);
func AtDim(t Ctensor) uint64 {
	result := C.at_dim(t)
@@ -112,6 +130,12 @@ func AtScalarType(t Ctensor) int32 {
	return *(*int32)(unsafe.Pointer(&result))
}
// int at_is_contiguous(tensor);
func AtIsContiguous(ts Ctensor) bool {
	retVal := C.at_is_contiguous(ts)
	return *(*bool)(unsafe.Pointer(&retVal))
}
func GetAndResetLastErr() *C.char {
	return C.get_and_reset_last_err()
}
@@ -182,18 +206,6 @@ func AtRequiresGrad(ts Ctensor) bool {
	return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
	retVal := C.at_defined(ts)
	return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
	retVal := C.at_is_sparse(ts)
	return *(*bool)(unsafe.Pointer(&retVal))
}
// void at_backward(tensor, int, int);
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
	ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))

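The wrappers above all share the same trick for turning libtorch's C `int` result into a Go `bool`: reinterpret the address of the returned value as a `*bool`. A standalone sketch of that conversion follows; the helper name is made up for illustration and is not part of gotch.

package main

import (
	"fmt"
	"unsafe"
)

// intAsBool mirrors the *(*bool)(unsafe.Pointer(&retVal)) pattern used by
// AtDefined, AtIsSparse, AtIsMkldnn and AtIsContiguous. It reads the first
// byte of the integer, so it relies on the C side returning exactly 0 or 1
// and on a little-endian memory layout.
func intAsBool(v int32) bool {
	return *(*bool)(unsafe.Pointer(&v))
}

func main() {
	fmt.Println(intAsBool(1)) // true
	fmt.Println(intAsBool(0)) // false
}

A plain `retVal != 0` comparison would be endian-independent; the unsafe cast simply matches the existing style of the file.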
View File

@@ -141,6 +141,12 @@ int at_scalar_type(tensor t) {
  return -1;
}
int at_is_contiguous(tensor t) {
  PROTECT(return t->is_contiguous();)
  return -1;
}
// void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor
// inf_scale) { PROTECT( at::_amp_non_finite_check_and_unscale_(*t, *found_inf,
// *inf_scale);

View File

@@ -46,6 +46,7 @@ size_t at_dim(tensor);
void at_shape(tensor, int64_t *);
void at_stride(tensor, int64_t *);
int at_scalar_type(tensor);
int at_is_contiguous(tensor);
void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);

View File

@@ -713,6 +713,52 @@ func (ts *Tensor) IsSparse() (bool, error) {
	return state, nil
}
func (ts *Tensor) MustIsSparse() bool {
	state, err := ts.IsSparse()
	if err != nil {
		log.Fatal(err)
	}
	return state
}
// IsContiguous returns true if the tensor is contiguous.
func (ts *Tensor) IsContiguous() (bool, error) {
	state := lib.AtIsContiguous(ts.ctensor)
	if err := TorchErr(); err != nil {
		return false, err
	}
	return state, nil
}
// MustIsContiguous is like IsContiguous but terminates via log.Fatal on error.
func (ts *Tensor) MustIsContiguous() bool {
	state, err := ts.IsContiguous()
	if err != nil {
		log.Fatal(err)
	}
	return state
}
// IsMkldnn returns true if the tensor is an mkldnn (MKL-DNN layout) tensor.
func (ts *Tensor) IsMkldnn() (bool, error) {
	state := lib.AtIsMkldnn(ts.ctensor)
	if err := TorchErr(); err != nil {
		return false, err
	}
	return state, nil
}
// MustIsMkldnn is like IsMkldnn but terminates via log.Fatal on error.
func (ts *Tensor) MustIsMkldnn() bool {
	state, err := ts.IsMkldnn()
	if err != nil {
		log.Fatal(err)
	}
	return state
}
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
func (ts *Tensor) ZeroGrad() {

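Following the package convention visible above, each check comes in an error-returning form and a `Must*` form that calls `log.Fatal`. A hypothetical caller-side helper that prefers explicit error handling might look like the sketch below; `describeLayout` is not part of the API, and it assumes the `fmt` package and the same tensor package alias (`ts`) as the snippet near the top of this page.

// describeLayout is a hypothetical helper, not part of gotch: it reports the
// layout flags of a tensor and propagates any libtorch error to the caller
// instead of terminating the process via log.Fatal.
func describeLayout(x *ts.Tensor) (string, error) {
	contiguous, err := x.IsContiguous()
	if err != nil {
		return "", err
	}
	mkldnn, err := x.IsMkldnn()
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("contiguous=%v mkldnn=%v", contiguous, mkldnn), nil
}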
View File

@@ -167,3 +167,25 @@ func TestTensor_Stride(t *testing.T) {
		t.Errorf("want %v, got %v\n", want, got)
	}
}
func TestTensor_IsContiguous(t *testing.T) {
	shape := []int64{2, 3, 4}
	x := ts.MustRand(shape, gotch.Float, gotch.CPU)
	got := x.MustIsContiguous()
	want := true
	if !reflect.DeepEqual(want, got) {
		t.Errorf("want %v, got %v\n", want, got)
	}
}
func TestTensor_IsMkldnn(t *testing.T) {
	shape := []int64{2, 3, 4}
	x := ts.MustRand(shape, gotch.Float, gotch.CPU)
	got := x.MustIsMkldnn()
	want := false
	if !reflect.DeepEqual(want, got) {
		t.Errorf("want %v, got %v\n", want, got)
	}
}
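A further test in the same style could exercise the error-returning variant directly. The sketch below is not part of the commit; it uses only APIs that appear in this diff and assumes the same imports as the existing test file.

func TestTensor_IsContiguous_Error(t *testing.T) {
	x := ts.MustRand([]int64{2, 3}, gotch.Float, gotch.CPU)

	// Call the error-returning variant instead of MustIsContiguous so the
	// test reports a libtorch failure rather than exiting the process.
	got, err := x.IsContiguous()
	if err != nil {
		t.Fatal(err)
	}
	if !got {
		t.Errorf("want a freshly allocated tensor to be contiguous, got %v", got)
	}
}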