feat(dtype): TypeCheck func

This commit is contained in:
sugarme 2020-05-30 11:15:36 +10:00
parent 6b0d6105ae
commit 98c182cef8
3 changed files with 27 additions and 17 deletions

View File

@ -248,3 +248,23 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
return nil, err
}
}
// TypeCheck checks whether data Go type matching DType
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
dataKind := reflect.ValueOf(data).Kind()
dataType := reflect.TypeOf(data)
switch dataKind {
case reflect.Slice:
dataEleType := reflect.TypeOf(data).Elem()
matched = dataEleType == dtype.Type
msg = fmt.Sprintf("data type: %v, DType: %v", dataEleType, dtype.Kind())
default:
matched = dataType == dtype.Type
msg = fmt.Sprintf("data type: %v, DType: %v", dataType, dtype.Kind())
}
return matched, msg
}

View File

@ -12,8 +12,8 @@ func main() {
// TODO: Check Go type of data and tensor DType
// For. if data is []int and DType is Bool
// It is still running but get wrong result.
data := []bool{true, true, false}
dtype := gotch.Bool
data := []float32{1.1, 1.2, 1.1}
dtype := gotch.Int
ts := wrapper.NewTensor()
sliceTensor, err := ts.FOfSlice(data, dtype)

View File

@ -27,7 +27,7 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
dataLen := reflect.ValueOf(data).Len()
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
// eltSizeInBytes := dtype.EltSizeInBytes() // Element Size in Byte for Int dtype
eltSizeInBytes := gotch.DTypeSize(dtype)
nbytes := int(eltSizeInBytes) * int(elementNum)
@ -42,23 +42,13 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
retVal = &Tensor{ctensor}
// Read back created tensor values by C libtorch
// readDataPtr := lib.AtDataPtr(retVal.ctensor)
// readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
// // typ := typeOf(dtype, shape)
// typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
// val := reflect.New(typ)
// if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
// panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
// }
//
// tensorData := reflect.Indirect(val).Interface()
//
// fmt.Println("%v", tensorData)
return retVal, nil
}
// Print prints tensor values to console.
//
// NOTE: it is printed from C and will print ALL elements of tensor
// with no truncation at all.
func (ts Tensor) Print() {
lib.AtPrint(ts.ctensor)
}