feat(dtype): TypeCheck func
This commit is contained in:
parent
6b0d6105ae
commit
98c182cef8
20
dtype.go
20
dtype.go
|
@ -248,3 +248,23 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// TypeCheck checks whether data Go type matching DType
|
||||
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
||||
|
||||
dataKind := reflect.ValueOf(data).Kind()
|
||||
dataType := reflect.TypeOf(data)
|
||||
|
||||
switch dataKind {
|
||||
case reflect.Slice:
|
||||
dataEleType := reflect.TypeOf(data).Elem()
|
||||
matched = dataEleType == dtype.Type
|
||||
msg = fmt.Sprintf("data type: %v, DType: %v", dataEleType, dtype.Kind())
|
||||
default:
|
||||
matched = dataType == dtype.Type
|
||||
msg = fmt.Sprintf("data type: %v, DType: %v", dataType, dtype.Kind())
|
||||
}
|
||||
|
||||
return matched, msg
|
||||
|
||||
}
|
||||
|
|
|
@ -12,8 +12,8 @@ func main() {
|
|||
// TODO: Check Go type of data and tensor DType
|
||||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
data := []bool{true, true, false}
|
||||
dtype := gotch.Bool
|
||||
data := []float32{1.1, 1.2, 1.1}
|
||||
dtype := gotch.Int
|
||||
|
||||
ts := wrapper.NewTensor()
|
||||
sliceTensor, err := ts.FOfSlice(data, dtype)
|
||||
|
|
|
@ -27,7 +27,7 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
|
|||
dataLen := reflect.ValueOf(data).Len()
|
||||
shape := []int64{int64(dataLen)}
|
||||
elementNum := ElementCount(shape)
|
||||
// eltSizeInBytes := dtype.EltSizeInBytes() // Element Size in Byte for Int dtype
|
||||
|
||||
eltSizeInBytes := gotch.DTypeSize(dtype)
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
|
@ -42,23 +42,13 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
|
|||
|
||||
retVal = &Tensor{ctensor}
|
||||
|
||||
// Read back created tensor values by C libtorch
|
||||
// readDataPtr := lib.AtDataPtr(retVal.ctensor)
|
||||
// readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||
// // typ := typeOf(dtype, shape)
|
||||
// typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||
// val := reflect.New(typ)
|
||||
// if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||
// panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||
// }
|
||||
//
|
||||
// tensorData := reflect.Indirect(val).Interface()
|
||||
//
|
||||
// fmt.Println("%v", tensorData)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Print prints tensor values to console.
|
||||
//
|
||||
// NOTE: it is printed from C and will print ALL elements of tensor
|
||||
// with no truncation at all.
|
||||
func (ts Tensor) Print() {
|
||||
lib.AtPrint(ts.ctensor)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user