424 lines
9.8 KiB
Go
424 lines
9.8 KiB
Go
package gotch
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
)
|
|
|
|
// CInt is equal to C type int. Go type is int32
|
|
type CInt = int32
|
|
|
|
// DType represents different kind of element that a tensor can hold.
|
|
// Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h
|
|
type DType int
|
|
|
|
const (
|
|
Invalid DType = -1
|
|
Uint8 DType = 0
|
|
Int8 DType = 1
|
|
Int16 DType = 2
|
|
Int DType = 3
|
|
Int64 DType = 4
|
|
Half DType = 5
|
|
Float DType = 6
|
|
Double DType = 7
|
|
ComplexHalf DType = 8
|
|
ComplexFloat DType = 9
|
|
ComplexDouble DType = 10
|
|
Bool DType = 11
|
|
QInt8 DType = 12
|
|
QUInt8 DType = 13
|
|
QInt32 DType = 14
|
|
BFloat16 DType = 15
|
|
// ---not implemented ---
|
|
QUInt4x2 DType = 16
|
|
QUInt2x4 DType = 17
|
|
Bits1x8 DType = 18
|
|
Bits2x4 DType = 19
|
|
Bits4x2 DType = 20
|
|
Bits8 DType = 21
|
|
Bits16 DType = 22
|
|
)
|
|
|
|
var dtype2CKind = map[DType]CInt{
|
|
Uint8: 0,
|
|
Int8: 1,
|
|
Int16: 2,
|
|
Int: 3,
|
|
Int64: 4,
|
|
Half: 5,
|
|
Float: 6,
|
|
Double: 7,
|
|
ComplexHalf: 8,
|
|
ComplexFloat: 9,
|
|
ComplexDouble: 10,
|
|
Bool: 11,
|
|
QInt8: 12,
|
|
QUInt8: 13,
|
|
QInt32: 14,
|
|
BFloat16: 15,
|
|
// ---not implemented ---
|
|
QUInt4x2: 16,
|
|
QUInt2x4: 17,
|
|
Bits1x8: 18,
|
|
Bits2x4: 19,
|
|
Bits4x2: 20,
|
|
Bits8: 21,
|
|
Bits16: 22,
|
|
}
|
|
|
|
func (dt DType) CKind() CInt {
|
|
if cint, ok := dtype2CKind[dt]; ok {
|
|
return cint
|
|
}
|
|
|
|
if Debug {
|
|
log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt)
|
|
}
|
|
return -1 // invalid
|
|
}
|
|
|
|
// Back compat
|
|
func (dt DType) CInt() CInt {
|
|
return dt.CKind()
|
|
}
|
|
|
|
var ckind2DType map[CInt]DType = map[CInt]DType{
|
|
0: Uint8,
|
|
1: Int8,
|
|
2: Int16,
|
|
3: Int,
|
|
4: Int64,
|
|
5: Half,
|
|
6: Float,
|
|
7: Double,
|
|
8: ComplexHalf,
|
|
9: ComplexFloat,
|
|
10: ComplexDouble,
|
|
11: Bool,
|
|
12: QInt8,
|
|
13: QUInt8,
|
|
14: QInt32,
|
|
15: BFloat16,
|
|
// ---not implemented ---
|
|
16: QUInt4x2,
|
|
17: QUInt2x4,
|
|
18: Bits1x8,
|
|
19: Bits2x4,
|
|
20: Bits4x2,
|
|
21: Bits8,
|
|
22: Bits16,
|
|
}
|
|
|
|
func CKind2DType(ckind int32) DType {
|
|
if dtype, ok := ckind2DType[ckind]; ok {
|
|
return dtype
|
|
}
|
|
|
|
if Debug {
|
|
log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind)
|
|
}
|
|
return -1 // invalid
|
|
}
|
|
|
|
var dtypeSize map[DType]uint = map[DType]uint{
|
|
Uint8: 1,
|
|
Int8: 1,
|
|
Int16: 2,
|
|
Int: 4,
|
|
Int64: 8,
|
|
Half: 2,
|
|
Float: 4,
|
|
Double: 8,
|
|
ComplexHalf: 4,
|
|
ComplexFloat: 8,
|
|
ComplexDouble: 16,
|
|
Bool: 1,
|
|
QInt8: 1,
|
|
QUInt8: 1,
|
|
QInt32: 4,
|
|
BFloat16: 2,
|
|
QUInt4x2: 2,
|
|
QUInt2x4: 1,
|
|
// ---not implemented ---
|
|
Bits1x8: 1,
|
|
Bits2x4: 1,
|
|
Bits4x2: 1,
|
|
Bits8: 1,
|
|
Bits16: 2,
|
|
}
|
|
|
|
// Size returns dtype size in Bytes.
|
|
func (dt DType) Size() uint {
|
|
return dtypeSize[dt]
|
|
}
|
|
|
|
type DTypeDevice struct {
|
|
DType DType
|
|
Device Device
|
|
}
|
|
|
|
var (
|
|
FloatCPU DTypeDevice = DTypeDevice{Float, CPU}
|
|
DoubleCPU DTypeDevice = DTypeDevice{Double, CPU}
|
|
Int64CPU DTypeDevice = DTypeDevice{Int64, CPU}
|
|
|
|
FloatCUDA DTypeDevice = DTypeDevice{Float, CudaBuilder(0)}
|
|
DoubleCUDA DTypeDevice = DTypeDevice{Double, CudaBuilder(0)}
|
|
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
|
|
)
|
|
|
|
var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{
|
|
Uint8: reflect.Uint8,
|
|
Int8: reflect.Int8,
|
|
Int16: reflect.Int16,
|
|
Int: reflect.Int32,
|
|
Int64: reflect.Int64,
|
|
Half: reflect.Uint16, // <- just uint16
|
|
Float: reflect.Float32,
|
|
Double: reflect.Float64,
|
|
ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
|
ComplexFloat: reflect.Complex64,
|
|
ComplexDouble: reflect.Complex128,
|
|
Bool: reflect.Bool,
|
|
QInt8: reflect.Int8,
|
|
QUInt8: reflect.Uint8,
|
|
QInt32: reflect.Int32,
|
|
BFloat16: reflect.Uint16, // <- just uint16
|
|
// ---not implemented ---
|
|
QUInt4x2: reflect.Invalid,
|
|
QUInt2x4: reflect.Invalid,
|
|
Bits1x8: reflect.Invalid,
|
|
Bits2x4: reflect.Invalid,
|
|
Bits4x2: reflect.Invalid,
|
|
Bits8: reflect.Invalid,
|
|
Bits16: reflect.Invalid,
|
|
}
|
|
|
|
func (dt DType) GoKind() reflect.Kind {
|
|
if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid {
|
|
return kind
|
|
}
|
|
|
|
if Debug {
|
|
log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt)
|
|
}
|
|
|
|
return reflect.Invalid
|
|
}
|
|
|
|
const (
|
|
// NOTE. reflect.Kind 0-26
|
|
QUInt8Kind reflect.Kind = 27
|
|
QInt8Kind reflect.Kind = 28
|
|
QInt32Kind reflect.Kind = 29
|
|
Float16Kind reflect.Kind = 30
|
|
BFloat16Kind reflect.Kind = 31
|
|
QUInt4x2Kind reflect.Kind = 32
|
|
QUInt2x4Kind reflect.Kind = 33
|
|
Bits1x8Kind reflect.Kind = 34
|
|
Bits2x4Kind reflect.Kind = 35
|
|
Bits4x2Kind reflect.Kind = 36
|
|
Bits8Kind reflect.Kind = 37
|
|
Bits16Kind reflect.Kind = 38
|
|
ComplexHalfKind reflect.Kind = 39
|
|
)
|
|
|
|
var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{
|
|
reflect.Uint8: Uint8,
|
|
reflect.Int8: Int8,
|
|
reflect.Int16: Int16,
|
|
reflect.Int32: Int,
|
|
reflect.Int64: Int64,
|
|
reflect.Float32: Float,
|
|
reflect.Float64: Double,
|
|
reflect.Complex64: ComplexFloat,
|
|
reflect.Complex128: ComplexDouble,
|
|
reflect.Bool: Bool,
|
|
reflect.Uint16: Half,
|
|
|
|
// Added Kinds
|
|
QUInt8Kind: QUInt8,
|
|
QInt8Kind: QInt8,
|
|
QInt32Kind: QInt32,
|
|
// Float16Kind: Half,
|
|
BFloat16Kind: BFloat16,
|
|
QUInt4x2Kind: QUInt4x2,
|
|
QUInt2x4Kind: QUInt2x4,
|
|
Bits1x8Kind: Bits1x8,
|
|
Bits2x4Kind: Bits2x4,
|
|
Bits4x2Kind: Bits4x2,
|
|
Bits8Kind: Bits8,
|
|
Bits16Kind: Bits16,
|
|
ComplexHalfKind: ComplexHalf,
|
|
}
|
|
|
|
type DTypeOptions struct {
|
|
HalfDTypePref DType
|
|
Quantized bool
|
|
}
|
|
|
|
type DTypeOpt func(*DTypeOptions)
|
|
|
|
func DefaultDTypeOptions() *DTypeOptions {
|
|
return &DTypeOptions{
|
|
HalfDTypePref: Half,
|
|
Quantized: false,
|
|
}
|
|
}
|
|
|
|
func HalfDTypePref(v DType) DTypeOpt {
|
|
if v != Half && v != BFloat16 {
|
|
if Debug {
|
|
log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v)
|
|
}
|
|
}
|
|
|
|
return func(o *DTypeOptions) {
|
|
o.HalfDTypePref = v
|
|
}
|
|
}
|
|
|
|
func WithQuantized(v bool) DTypeOpt {
|
|
return func(o *DTypeOptions) {
|
|
o.Quantized = v
|
|
}
|
|
}
|
|
|
|
func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) {
|
|
o := DefaultDTypeOptions()
|
|
for _, opt := range opts {
|
|
opt(o)
|
|
}
|
|
|
|
switch {
|
|
case kind == reflect.Uint16 && o.HalfDTypePref == Half:
|
|
return Half, nil
|
|
case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16:
|
|
return BFloat16, nil
|
|
case kind == reflect.Int8 && o.Quantized:
|
|
return QInt8, nil
|
|
case kind == reflect.Uint8 && o.Quantized:
|
|
return QUInt8, nil
|
|
case kind == reflect.Int32 && o.Quantized:
|
|
return QInt32, nil
|
|
|
|
default:
|
|
dtype, ok := goKind2DType[kind]
|
|
if !ok {
|
|
err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind)
|
|
return Invalid, err
|
|
}
|
|
return dtype, nil
|
|
}
|
|
}
|
|
|
|
var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{
|
|
Uint8: reflect.TypeOf(uint8(0)),
|
|
Int8: reflect.TypeOf(int8(0)),
|
|
Int16: reflect.TypeOf(int16(0)),
|
|
Int: reflect.TypeOf(int(0)),
|
|
Int64: reflect.TypeOf(int64(0)),
|
|
Half: reflect.TypeOf(uint16(0)), // <- just uint16
|
|
Float: reflect.TypeOf(float32(0)),
|
|
Double: reflect.TypeOf(float64(0)),
|
|
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
|
ComplexFloat: reflect.TypeOf(complex64(0)),
|
|
ComplexDouble: reflect.TypeOf(complex128(0)),
|
|
Bool: reflect.TypeOf(true),
|
|
QInt8: reflect.TypeOf(int8(0)),
|
|
QUInt8: reflect.TypeOf(uint8(0)),
|
|
QInt32: reflect.TypeOf(int32(0)),
|
|
BFloat16: reflect.TypeOf(uint16(0)), // <- just uint16
|
|
// ---not implemented ---
|
|
QUInt4x2: reflect.TypeOf(int8(0)),
|
|
QUInt2x4: reflect.TypeOf(uint8(0)),
|
|
Bits1x8: reflect.TypeOf(uint8(0)),
|
|
Bits2x4: reflect.TypeOf(uint8(0)),
|
|
Bits4x2: reflect.TypeOf(uint8(0)),
|
|
Bits8: reflect.TypeOf(uint8(0)),
|
|
Bits16: reflect.TypeOf(uint16(0)),
|
|
}
|
|
|
|
func (dt DType) GoType() (reflect.Type, error) {
|
|
typ, ok := dtype2GoType[dt]
|
|
if !ok {
|
|
err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", typ.String())
|
|
return nil, err
|
|
}
|
|
|
|
return typ, nil
|
|
}
|
|
|
|
var dtypeNames map[DType]string = map[DType]string{
|
|
Uint8: "Uint8",
|
|
Int8: "Int8",
|
|
Int16: "Int16",
|
|
Int: "Int",
|
|
Int64: "Int64",
|
|
Half: "Half", // <- just uint16
|
|
Float: "Float",
|
|
Double: "Double",
|
|
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
|
ComplexFloat: "ComplexFloat",
|
|
ComplexDouble: "ComplexDouble",
|
|
Bool: "Bool",
|
|
QInt8: "QInt8",
|
|
QUInt8: "QUInt8",
|
|
QInt32: "QInt32",
|
|
BFloat16: "BFloat16", // <- just uint16
|
|
// ---not implemented ---
|
|
QUInt4x2: "QUInt4x2",
|
|
QUInt2x4: "QUInt2x4",
|
|
Bits1x8: "Bits1x8",
|
|
Bits2x4: "Bits2x4",
|
|
Bits4x2: "Bits4x2",
|
|
Bits8: "Bits8",
|
|
Bits16: "Bits16",
|
|
}
|
|
|
|
func (dt DType) String() string {
|
|
return dtypeNames[dt]
|
|
}
|
|
|
|
func DTypeFromData(data interface{}) (DType, error) {
|
|
dataKind := reflect.TypeOf(data).Kind()
|
|
|
|
// Data is a slice/array
|
|
if dataKind == reflect.Slice || dataKind == reflect.Array {
|
|
elementKind := reflect.TypeOf(data).Elem().Kind()
|
|
return GoKind2DType(elementKind)
|
|
}
|
|
|
|
// single element
|
|
return GoKind2DType(dataKind)
|
|
}
|
|
|
|
// IsFloatDType returns whether dtype is floating point data type.
|
|
func IsFloatDType(dtype DType) bool {
|
|
switch dtype {
|
|
case Double, Float, Half, BFloat16:
|
|
return true
|
|
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Default DType:
|
|
// ==============
|
|
var DefaultDType DType = Float
|
|
|
|
// SetDefaultDType set DefaultDType to new value and return the previous one.
|
|
func SetDefaultDType(dtype DType) DType {
|
|
odtype := DefaultDType
|
|
DefaultDType = dtype
|
|
|
|
if Debug {
|
|
log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change back to previous default to avoid unexpected outcome.\n", dtype)
|
|
}
|
|
|
|
return odtype
|
|
}
|