added tensor IsContiguous and IsMkldnn APIs
This commit is contained in:
parent
6aa57478f5
commit
1fc13ec55e
|
@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- reworked `ts.Format()`
|
||||
- Added conv2d benchmark
|
||||
- Fixed #88 memory leak at `example/char-rnn`
|
||||
- Added missing tensor `Stride()` and `MustDataPtr()`
|
||||
- Added missing tensor `Stride()` and `MustDataPtr()`, `IsMkldnn`, `MustIsMkldnn`, `IsContiguous`, `MustIsContiguous`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
|
@ -59,12 +59,6 @@ func NewTensor() Ctensor {
|
|||
return C.at_new_tensor()
|
||||
}
|
||||
|
||||
// int at_device(tensor);
|
||||
func AtDevice(ts Ctensor) int {
|
||||
cint := C.at_device(ts)
|
||||
return *(*int)(unsafe.Pointer(&cint))
|
||||
}
|
||||
|
||||
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
|
||||
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
|
||||
|
||||
|
@ -88,6 +82,30 @@ func AtDataPtr(t Ctensor) unsafe.Pointer {
|
|||
return C.at_data_ptr(t)
|
||||
}
|
||||
|
||||
// int at_defined(tensor);
|
||||
func AtDefined(ts Ctensor) bool {
|
||||
retVal := C.at_defined(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_mkldnn(tensor);
|
||||
func AtIsMkldnn(ts Ctensor) bool {
|
||||
retVal := C.at_is_mkldnn(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_sparse(tensor);
|
||||
func AtIsSparse(ts Ctensor) bool {
|
||||
retVal := C.at_is_sparse(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_device(tensor);
|
||||
func AtDevice(ts Ctensor) int {
|
||||
cint := C.at_device(ts)
|
||||
return *(*int)(unsafe.Pointer(&cint))
|
||||
}
|
||||
|
||||
// size_t at_dim(tensor);
|
||||
func AtDim(t Ctensor) uint64 {
|
||||
result := C.at_dim(t)
|
||||
|
@ -112,6 +130,12 @@ func AtScalarType(t Ctensor) int32 {
|
|||
return *(*int32)(unsafe.Pointer(&result))
|
||||
}
|
||||
|
||||
// int at_is_contiguous(tensor);
|
||||
func AtIsContiguous(ts Ctensor) bool {
|
||||
retVal := C.at_is_contiguous(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
func GetAndResetLastErr() *C.char {
|
||||
return C.get_and_reset_last_err()
|
||||
}
|
||||
|
@ -182,18 +206,6 @@ func AtRequiresGrad(ts Ctensor) bool {
|
|||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_defined(tensor);
|
||||
func AtDefined(ts Ctensor) bool {
|
||||
retVal := C.at_defined(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_sparse(tensor);
|
||||
func AtIsSparse(ts Ctensor) bool {
|
||||
retVal := C.at_is_sparse(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// void at_backward(tensor, int, int);
|
||||
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
|
||||
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
|
||||
|
|
|
@ -141,6 +141,12 @@ int at_scalar_type(tensor t) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
int at_is_contiguous(tensor t) {
|
||||
PROTECT(return t->is_contiguous();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
// void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor
|
||||
// inf_scale) { PROTECT( at::_amp_non_finite_check_and_unscale_(*t, *found_inf,
|
||||
// *inf_scale);
|
||||
|
|
|
@ -46,6 +46,7 @@ size_t at_dim(tensor);
|
|||
void at_shape(tensor, int64_t *);
|
||||
void at_stride(tensor, int64_t *);
|
||||
int at_scalar_type(tensor);
|
||||
int at_is_contiguous(tensor);
|
||||
|
||||
void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);
|
||||
|
||||
|
|
46
ts/tensor.go
46
ts/tensor.go
|
@ -713,6 +713,52 @@ func (ts *Tensor) IsSparse() (bool, error) {
|
|||
|
||||
return state, nil
|
||||
}
|
||||
func (ts *Tensor) MustIsSparse() bool {
|
||||
state, err := ts.IsSparse()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// IsContiguous returns true is the tensor is contiguous.
|
||||
func (ts *Tensor) IsContiguous() (bool, error) {
|
||||
state := lib.AtIsContiguous(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
func (ts *Tensor) MustIsContiguous() bool {
|
||||
state, err := ts.IsContiguous()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// IsMkldnn returns true is the tensor is mkldnn.
|
||||
func (ts *Tensor) IsMkldnn() (bool, error) {
|
||||
state := lib.AtIsMkldnn(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
func (ts *Tensor) MustIsMkldnn() bool {
|
||||
state, err := ts.IsMkldnn()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
|
||||
func (ts *Tensor) ZeroGrad() {
|
||||
|
|
|
@ -167,3 +167,25 @@ func TestTensor_Stride(t *testing.T) {
|
|||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensor_IsContiguous(t *testing.T) {
|
||||
shape := []int64{2, 3, 4}
|
||||
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||
|
||||
got := x.MustIsContiguous()
|
||||
want := true
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensor_IsMkldnn(t *testing.T) {
|
||||
shape := []int64{2, 3, 4}
|
||||
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||
|
||||
got := x.MustIsMkldnn()
|
||||
want := false
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user