package gotch import ( "fmt" "log" "reflect" ) // CInt is equal to C type int. Go type is int32 type CInt = int32 // DType represents different kind of element that a tensor can hold. // Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h type DType int const ( Invalid DType = -1 Uint8 DType = 0 Int8 DType = 1 Int16 DType = 2 Int DType = 3 Int64 DType = 4 Half DType = 5 Float DType = 6 Double DType = 7 ComplexHalf DType = 8 ComplexFloat DType = 9 ComplexDouble DType = 10 Bool DType = 11 QInt8 DType = 12 QUInt8 DType = 13 QInt32 DType = 14 BFloat16 DType = 15 // ---not implemented --- QUInt4x2 DType = 16 QUInt2x4 DType = 17 Bits1x8 DType = 18 Bits2x4 DType = 19 Bits4x2 DType = 20 Bits8 DType = 21 Bits16 DType = 22 ) var dtype2CKind = map[DType]CInt{ Uint8: 0, Int8: 1, Int16: 2, Int: 3, Int64: 4, Half: 5, Float: 6, Double: 7, ComplexHalf: 8, ComplexFloat: 9, ComplexDouble: 10, Bool: 11, QInt8: 12, QUInt8: 13, QInt32: 14, BFloat16: 15, // ---not implemented --- QUInt4x2: 16, QUInt2x4: 17, Bits1x8: 18, Bits2x4: 19, Bits4x2: 20, Bits8: 21, Bits16: 22, } func (dt DType) CKind() CInt { if cint, ok := dtype2CKind[dt]; ok { return cint } if Debug { log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt) } return -1 // invalid } // Back compat func (dt DType) CInt() CInt { return dt.CKind() } var ckind2DType map[CInt]DType = map[CInt]DType{ 0: Uint8, 1: Int8, 2: Int16, 3: Int, 4: Int64, 5: Half, 6: Float, 7: Double, 8: ComplexHalf, 9: ComplexFloat, 10: ComplexDouble, 11: Bool, 12: QInt8, 13: QUInt8, 14: QInt32, 15: BFloat16, // ---not implemented --- 16: QUInt4x2, 17: QUInt2x4, 18: Bits1x8, 19: Bits2x4, 20: Bits4x2, 21: Bits8, 22: Bits16, } func CKind2DType(ckind int32) DType { if dtype, ok := ckind2DType[ckind]; ok { return dtype } if Debug { log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind) } return -1 // invalid } var dtypeSize map[DType]uint = map[DType]uint{ Uint8: 1, Int8: 1, Int16: 2, Int: 4, Int64: 8, Half: 2, Float: 4, Double: 8, ComplexHalf: 4, ComplexFloat: 8, ComplexDouble: 16, Bool: 1, QInt8: 1, QUInt8: 1, QInt32: 4, BFloat16: 2, QUInt4x2: 2, QUInt2x4: 1, // ---not implemented --- Bits1x8: 1, Bits2x4: 1, Bits4x2: 1, Bits8: 1, Bits16: 2, } // Size returns dtype size in Bytes. func (dt DType) Size() uint { return dtypeSize[dt] } type DTypeDevice struct { DType DType Device Device } var ( FloatCPU DTypeDevice = DTypeDevice{Float, CPU} DoubleCPU DTypeDevice = DTypeDevice{Double, CPU} Int64CPU DTypeDevice = DTypeDevice{Int64, CPU} FloatCUDA DTypeDevice = DTypeDevice{Float, CudaBuilder(0)} DoubleCUDA DTypeDevice = DTypeDevice{Double, CudaBuilder(0)} Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)} ) var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{ Uint8: reflect.Uint8, Int8: reflect.Int8, Int16: reflect.Int16, Int: reflect.Int32, Int64: reflect.Int64, Half: reflect.Uint16, // <- just uint16 Float: reflect.Float32, Double: reflect.Float64, ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64? ComplexFloat: reflect.Complex64, ComplexDouble: reflect.Complex128, Bool: reflect.Bool, QInt8: reflect.Int8, QUInt8: reflect.Uint8, QInt32: reflect.Int32, BFloat16: reflect.Uint16, // <- just uint16 // ---not implemented --- QUInt4x2: reflect.Invalid, QUInt2x4: reflect.Invalid, Bits1x8: reflect.Invalid, Bits2x4: reflect.Invalid, Bits4x2: reflect.Invalid, Bits8: reflect.Invalid, Bits16: reflect.Invalid, } func (dt DType) GoKind() reflect.Kind { if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid { return kind } if Debug { log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt) } return reflect.Invalid } const ( // NOTE. reflect.Kind 0-26 QUInt8Kind reflect.Kind = 27 QInt8Kind reflect.Kind = 28 QInt32Kind reflect.Kind = 29 Float16Kind reflect.Kind = 30 BFloat16Kind reflect.Kind = 31 QUInt4x2Kind reflect.Kind = 32 QUInt2x4Kind reflect.Kind = 33 Bits1x8Kind reflect.Kind = 34 Bits2x4Kind reflect.Kind = 35 Bits4x2Kind reflect.Kind = 36 Bits8Kind reflect.Kind = 37 Bits16Kind reflect.Kind = 38 ComplexHalfKind reflect.Kind = 39 ) var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{ reflect.Uint8: Uint8, reflect.Int8: Int8, reflect.Int16: Int16, reflect.Int32: Int, reflect.Int64: Int64, reflect.Float32: Float, reflect.Float64: Double, reflect.Complex64: ComplexFloat, reflect.Complex128: ComplexDouble, reflect.Bool: Bool, reflect.Uint16: Half, // Added Kinds QUInt8Kind: QUInt8, QInt8Kind: QInt8, QInt32Kind: QInt32, // Float16Kind: Half, BFloat16Kind: BFloat16, QUInt4x2Kind: QUInt4x2, QUInt2x4Kind: QUInt2x4, Bits1x8Kind: Bits1x8, Bits2x4Kind: Bits2x4, Bits4x2Kind: Bits4x2, Bits8Kind: Bits8, Bits16Kind: Bits16, ComplexHalfKind: ComplexHalf, } type DTypeOptions struct { HalfDTypePref DType Quantized bool } type DTypeOpt func(*DTypeOptions) func DefaultDTypeOptions() *DTypeOptions { return &DTypeOptions{ HalfDTypePref: Half, Quantized: false, } } func HalfDTypePref(v DType) DTypeOpt { if v != Half && v != BFloat16 { if Debug { log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v) } } return func(o *DTypeOptions) { o.HalfDTypePref = v } } func WithQuantized(v bool) DTypeOpt { return func(o *DTypeOptions) { o.Quantized = v } } func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) { o := DefaultDTypeOptions() for _, opt := range opts { opt(o) } switch { case kind == reflect.Uint16 && o.HalfDTypePref == Half: return Half, nil case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16: return BFloat16, nil case kind == reflect.Int8 && o.Quantized: return QInt8, nil case kind == reflect.Uint8 && o.Quantized: return QUInt8, nil case kind == reflect.Int32 && o.Quantized: return QInt32, nil default: dtype, ok := goKind2DType[kind] if !ok { err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind) return Invalid, err } return dtype, nil } } var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{ Uint8: reflect.TypeOf(uint8(0)), Int8: reflect.TypeOf(int8(0)), Int16: reflect.TypeOf(int16(0)), Int: reflect.TypeOf(int(0)), Int64: reflect.TypeOf(int64(0)), Half: reflect.TypeOf(uint16(0)), // <- just uint16 Float: reflect.TypeOf(float32(0)), Double: reflect.TypeOf(float64(0)), // ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64? ComplexFloat: reflect.TypeOf(complex64(0)), ComplexDouble: reflect.TypeOf(complex128(0)), Bool: reflect.TypeOf(true), QInt8: reflect.TypeOf(int8(0)), QUInt8: reflect.TypeOf(uint8(0)), QInt32: reflect.TypeOf(int32(0)), BFloat16: reflect.TypeOf(uint16(0)), // <- just uint16 // ---not implemented --- QUInt4x2: reflect.TypeOf(int8(0)), QUInt2x4: reflect.TypeOf(uint8(0)), Bits1x8: reflect.TypeOf(uint8(0)), Bits2x4: reflect.TypeOf(uint8(0)), Bits4x2: reflect.TypeOf(uint8(0)), Bits8: reflect.TypeOf(uint8(0)), Bits16: reflect.TypeOf(uint16(0)), } func (dt DType) GoType() (reflect.Type, error) { typ, ok := dtype2GoType[dt] if !ok { err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", typ.String()) return nil, err } return typ, nil } var dtypeNames map[DType]string = map[DType]string{ Uint8: "Uint8", Int8: "Int8", Int16: "Int16", Int: "Int", Int64: "Int64", Half: "Half", // <- just uint16 Float: "Float", Double: "Double", // ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64? ComplexFloat: "ComplexFloat", ComplexDouble: "ComplexDouble", Bool: "Bool", QInt8: "QInt8", QUInt8: "QUInt8", QInt32: "QInt32", BFloat16: "BFloat16", // <- just uint16 // ---not implemented --- QUInt4x2: "QUInt4x2", QUInt2x4: "QUInt2x4", Bits1x8: "Bits1x8", Bits2x4: "Bits2x4", Bits4x2: "Bits4x2", Bits8: "Bits8", Bits16: "Bits16", } func (dt DType) String() string { return dtypeNames[dt] } func DTypeFromData(data interface{}) (DType, error) { dataKind := reflect.TypeOf(data).Kind() // Data is a slice/array if dataKind == reflect.Slice || dataKind == reflect.Array { elementKind := reflect.TypeOf(data).Elem().Kind() return GoKind2DType(elementKind) } // single element return GoKind2DType(dataKind) } // IsFloatDType returns whether dtype is floating point data type. func IsFloatDType(dtype DType) bool { switch dtype { case Double, Float, Half, BFloat16: return true default: return false } } // Default DType: // ============== var DefaultDType DType = Float // SetDefaultDType set DefaultDType to new value and return the previous one. func SetDefaultDType(dtype DType) DType { odtype := DefaultDType DefaultDType = dtype if Debug { log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change back to previous default to avoid unexpected outcome.\n", dtype) } return odtype }