From 523061eca60017f9f69327f08d6bbbc8a4098c18 Mon Sep 17 00:00:00 2001 From: sugarme Date: Fri, 7 Jul 2023 00:01:23 +1000 Subject: [PATCH] reworked gotch.dtype with more dtypes --- dtype.go | 624 +++++++++++++------------- file-util.go | 9 +- go.sum | 24 + half/bfloat16.go | 167 +++++++ half/bfloat16_test.go | 1 + half/float16.go | 303 +++++++++++++ half/float16_bench_test.go | 88 ++++ half/float16_test.go | 798 ++++++++++++++++++++++++++++++++++ pickle/pickle_example_test.go | 26 +- pickle/serialization.go | 188 ++++---- pickle/storage.go | 66 ++- ts/iter.go | 2 +- ts/npy.go | 20 +- ts/print.go | 46 +- ts/tensor.go | 240 ++++------ ts/util.go | 64 ++- 16 files changed, 2028 insertions(+), 638 deletions(-) create mode 100644 half/bfloat16.go create mode 100644 half/bfloat16_test.go create mode 100644 half/float16.go create mode 100644 half/float16_bench_test.go create mode 100644 half/float16_test.go diff --git a/dtype.go b/dtype.go index 0716410..02d7225 100644 --- a/dtype.go +++ b/dtype.go @@ -3,8 +3,6 @@ package gotch import ( "fmt" "log" - - // "log" "reflect" ) @@ -12,151 +10,148 @@ import ( type CInt = int32 // DType represents different kind of element that a tensor can hold. -// It has an embedded `reflect.Type` for type reflection. -type DType struct { - reflect.Type -} +// Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h +type DType int -/* - * // Custom-made Float16 as not exist in Go - * // Ref: https://github.com/golang/go/issues/32022 - * type GoFloat16 = int16 // not implemented yet - * type GoComplexHalf = interface{} // not implemented yet! - * */ - -// TODO: double check these Torch DType to Go type -var ( - Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0 - Int8 DType = DType{reflect.TypeOf(int8(1))} // 1 - Int16 DType = DType{reflect.TypeOf(int16(1))} // 2 - Int DType = DType{reflect.TypeOf(int32(1))} // 3 - Int64 DType = DType{reflect.TypeOf(int64(1))} // 4 - // Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5 - Half DType = DType{reflect.TypeOf(float32(1))} // 5 - Float DType = DType{reflect.TypeOf(float32(1))} // 6 - Double DType = DType{reflect.TypeOf(float64(1))} // 7 - // ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8 - // ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9 - // ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10 - Bool DType = DType{reflect.TypeOf(true)} // 11 +const ( + Invalid DType = -1 + Uint8 DType = 0 + Int8 DType = 1 + Int16 DType = 2 + Int DType = 3 + Int64 DType = 4 + Half DType = 5 + Float DType = 6 + Double DType = 7 + ComplexHalf DType = 8 + ComplexFloat DType = 9 + ComplexDouble DType = 10 + Bool DType = 11 + QInt8 DType = 12 + QUInt8 DType = 13 + QInt32 DType = 14 + BFloat16 DType = 15 + // ---not implemented --- + QUInt4x2 DType = 16 + QUInt2x4 DType = 17 + Bits1x8 DType = 18 + Bits2x4 DType = 19 + Bits4x2 DType = 20 + Bits8 DType = 21 + Bits16 DType = 22 ) -var dtypeGoType = map[DType]reflect.Type{ - Uint8: reflect.TypeOf(uint8(1)), - Int8: reflect.TypeOf(int8(1)), - Int16: reflect.TypeOf(int16(1)), - Int: reflect.TypeOf(int32(1)), - Int64: reflect.TypeOf(int64(1)), - Half: reflect.TypeOf(float32(1)), - Float: reflect.TypeOf(float32(1)), - Double: reflect.TypeOf(float64(1)), - Bool: reflect.TypeOf(true), +var dtype2CKind = map[DType]CInt{ + Uint8: 0, + Int8: 1, + Int16: 2, + Int: 3, + Int64: 4, + Half: 5, + Float: 6, + Double: 7, + ComplexHalf: 8, + ComplexFloat: 9, + ComplexDouble: 10, + Bool: 11, + QInt8: 12, + QUInt8: 13, 
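+	// NOTE. these codes mirror c10::ScalarType in the PyTorch header referenced above and must stay in sync with the DType constants.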
+ QInt32: 14, + BFloat16: 15, + // ---not implemented --- + QUInt4x2: 16, + QUInt2x4: 17, + Bits1x8: 18, + Bits2x4: 19, + Bits4x2: 20, + Bits8: 21, + Bits16: 22, } -// ToDType infers and returns supported equivalent DType from given Go type -func ToDType(typ reflect.Type) (retVal DType, err error) { - var found = false - for key, val := range dtypeGoType { - if val == typ { - retVal = key - found = true - break - } +func (dt DType) CKind() CInt { + if cint, ok := dtype2CKind[dt]; ok { + return cint } - if !found { - err = fmt.Errorf("Unsupported Go type: %v", typ) - return DType{}, err + if Debug { + log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt) } - - return retVal, nil + return -1 // invalid } -// ToGoType infers and returns supported equivalent Go type from given DType -func ToGoType(dtype DType) (retVal reflect.Type, err error) { - if _, ok := dtypeGoType[dtype]; !ok { - err = fmt.Errorf("Unsupported DType %v", dtype) - return nil, err - } - - retVal = dtypeGoType[dtype] - - return retVal, nil +// Back compat +func (dt DType) CInt() CInt { + return dt.CKind() } -var dtypeCInt = map[DType]CInt{ - Uint8: 0, - Int8: 1, - Int16: 2, - Int: 3, - Int64: 4, - Half: 5, - Float: 6, - Double: 7, - Bool: 11, +var ckind2DType map[CInt]DType = map[CInt]DType{ + 0: Uint8, + 1: Int8, + 2: Int16, + 3: Int, + 4: Int64, + 5: Half, + 6: Float, + 7: Double, + 8: ComplexHalf, + 9: ComplexFloat, + 10: ComplexDouble, + 11: Bool, + 12: QInt8, + 13: QUInt8, + 14: QInt32, + 15: BFloat16, + // ---not implemented --- + 16: QUInt4x2, + 17: QUInt2x4, + 18: Bits1x8, + 19: Bits2x4, + 20: Bits4x2, + 21: Bits8, + 22: Bits16, } -func DType2CInt(dt DType) (retVal CInt, err error) { - if _, ok := dtypeCInt[dt]; !ok { - err = fmt.Errorf("Unsupported CInt conversion from DType: %v\n", dt) +func CKind2DType(ckind int32) DType { + if dtype, ok := ckind2DType[ckind]; ok { + return dtype } - retVal = dtypeCInt[dt] - - return retVal, nil + if Debug { + log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind) + } + return -1 // invalid } -func (dt DType) CInt() (retVal CInt) { - retVal, err := DType2CInt(dt) - if err != nil { - log.Fatal(err) - } - - return retVal +var dtypeSize map[DType]uint = map[DType]uint{ + Uint8: 1, + Int8: 1, + Int16: 2, + Int: 4, + Int64: 8, + Half: 2, + Float: 4, + Double: 8, + ComplexHalf: 4, + ComplexFloat: 8, + ComplexDouble: 16, + Bool: 1, + QInt8: 1, + QUInt8: 1, + QInt32: 4, + BFloat16: 2, + QUInt4x2: 2, + QUInt2x4: 1, + // ---not implemented --- + Bits1x8: 1, + Bits2x4: 1, + Bits4x2: 1, + Bits8: 1, + Bits16: 2, } -func CInt2DType(v CInt) (dtype DType, err error) { - var found = false - for key, val := range dtypeCInt { - if val == v { - dtype = key - found = true - break - } - } - - if !found { - err = fmt.Errorf("Unsuported DType for CInt %v\n", v) - return DType{}, err - } - - return dtype, nil - -} - -// dtypeSize is a map of DType and its size in Bytes -var dtypeSize = map[DType]uint{ - Uint8: 1, - Int8: 1, - Int16: 2, - Int: 4, - Int64: 8, - Half: 4, // Should it be? - Float: 4, - Double: 8, - Bool: 1, -} - -// DTypeSize returns DType size in Bytes -func DTypeSize(dt DType) (retVal uint, err error) { - if _, ok := dtypeSize[dt]; !ok { - err = fmt.Errorf("Unsupported conversion DType size in Byte for DType: %v\n", dt) - return 99, err - } - - retVal = dtypeSize[dt] - - return retVal, nil +// Size returns dtype size in Bytes. 
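+// For example, Float.Size() returns 4 and BFloat16.Size() returns 2 (see dtypeSize above);
+// an unknown DType returns 0, the map zero value.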
+func (dt DType) Size() uint { + return dtypeSize[dt] } type DTypeDevice struct { @@ -174,201 +169,228 @@ var ( Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)} ) -// Type Inferring: -// =============== - -// DTypeFromData infers returns equavalent DType from given data -func DTypeFromData(data interface{}) (retVal DType, err error) { - - // NOTE: call `Interface()` to get data type back to interface{} type - typ, _, err := dataCheck(reflect.ValueOf(data).Interface(), 0) - if err != nil { - return retVal, err - } - - if typ.Kind() == reflect.Slice { - return ToDType(typ.Elem()) - } - - return ToDType(typ) +var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{ + Uint8: reflect.Uint8, + Int8: reflect.Int8, + Int16: reflect.Int16, + Int: reflect.Int32, + Int64: reflect.Int64, + Half: reflect.Uint16, // <- just uint16 + Float: reflect.Float32, + Double: reflect.Float64, + ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64? + ComplexFloat: reflect.Complex64, + ComplexDouble: reflect.Complex128, + Bool: reflect.Bool, + QInt8: reflect.Int8, + QUInt8: reflect.Uint8, + QInt32: reflect.Int32, + BFloat16: reflect.Uint16, // <- just uint16 + // ---not implemented --- + QUInt4x2: reflect.Invalid, + QUInt2x4: reflect.Invalid, + Bits1x8: reflect.Invalid, + Bits2x4: reflect.Invalid, + Bits4x2: reflect.Invalid, + Bits8: reflect.Invalid, + Bits16: reflect.Invalid, } -// NOTE: 0 is reflect.Kind() of Invalid -// See: https://golang.org/pkg/reflect/#Kind -func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { - v := reflect.ValueOf(data) - var goType reflect.Type = reflect.TypeOf(data) - var total int = count - var round = 0 +func (dt DType) GoKind() reflect.Kind { + if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid { + return kind + } - switch v.Kind() { - case reflect.Slice, reflect.Array: - if round == 0 { - round = v.Len() + if Debug { + log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt) + } + + return reflect.Invalid +} + +const ( + // NOTE. 
reflect.Kind 0-26 + QUInt8Kind reflect.Kind = 27 + QInt8Kind reflect.Kind = 28 + QInt32Kind reflect.Kind = 29 + Float16Kind reflect.Kind = 30 + BFloat16Kind reflect.Kind = 31 + QUInt4x2Kind reflect.Kind = 32 + QUInt2x4Kind reflect.Kind = 33 + Bits1x8Kind reflect.Kind = 34 + Bits2x4Kind reflect.Kind = 35 + Bits4x2Kind reflect.Kind = 36 + Bits8Kind reflect.Kind = 37 + Bits16Kind reflect.Kind = 38 + ComplexHalfKind reflect.Kind = 39 +) + +var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{ + reflect.Uint8: Uint8, + reflect.Int8: Int8, + reflect.Int16: Int16, + reflect.Int32: Int, + reflect.Int64: Int64, + reflect.Float32: Float, + reflect.Float64: Double, + reflect.Complex64: ComplexFloat, + reflect.Complex128: ComplexDouble, + reflect.Bool: Bool, + reflect.Uint16: Half, + + // Added Kinds + QUInt8Kind: QUInt8, + QInt8Kind: QInt8, + QInt32Kind: QInt32, + // Float16Kind: Half, + BFloat16Kind: BFloat16, + QUInt4x2Kind: QUInt4x2, + QUInt2x4Kind: QUInt2x4, + Bits1x8Kind: Bits1x8, + Bits2x4Kind: Bits2x4, + Bits4x2Kind: Bits4x2, + Bits8Kind: Bits8, + Bits16Kind: Bits16, + ComplexHalfKind: ComplexHalf, +} + +type DTypeOptions struct { + HalfDTypePref DType + Quantized bool +} + +type DTypeOpt func(*DTypeOptions) + +func DefaultDTypeOptions() *DTypeOptions { + return &DTypeOptions{ + HalfDTypePref: Half, + Quantized: false, + } +} + +func HalfDTypePref(v DType) DTypeOpt { + if v != Half && v != BFloat16 { + if Debug { + log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v) } - for i := 0; i < v.Len(); i++ { - round-- - goType, total, err = dataCheck(v.Index(i).Interface(), total) - - if err != nil { - return reflect.TypeOf(reflect.Zero), 0, err - } - } - - return goType, total, nil - - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: - total++ - if goType.String() != "invalid" { - goType = v.Type() - } - default: - err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind()) - return reflect.TypeOf(reflect.Zero), 0, err } - return goType, total, nil -} - -// ElementGoType infers and returns Go type of element in given data -func ElementGoType(data interface{}) (retVal reflect.Type, err error) { - dataValue := reflect.ValueOf(data) - return elementType(dataValue) -} - -func elementType(data reflect.Value) (dataType reflect.Type, err error) { - dataKind := data.Kind() - switch dataKind { - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: - dataType = data.Type() - case reflect.Slice, reflect.Array: - data = data.Elem() - dataType, err = elementType(data) // recursively type inferring - default: - err = fmt.Errorf("Unsupported type for data type %v\n", dataType) - return DType{}, err + return func(o *DTypeOptions) { + o.HalfDTypePref = v } - - return dataType, nil } -// DataDType infers and returns data type of tensor data -func DataDType(v interface{}, shape []int64) (retVal DType, err error) { - // assuming that all elements in data have the same type - switch { - case len(shape) == 0: - retVal, err = ElementDType(v) - case len(shape) > 0: - return ElementDType(v.([]interface{})[0]) - default: - err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v)) - return DType{}, err +func WithQuantized(v bool) DTypeOpt { + return func(o *DTypeOptions) { + o.Quantized = v } - return DType{}, nil } -// ElementDType infers and 
returns its own tensor data type -func ElementDType(v interface{}) (retVal DType, err error) { - switch v.(type) { - case uint8: - retVal = Uint8 - case int8: - retVal = Int8 - case int16: - retVal = Int16 - case int32: - retVal = Int - case int64: - retVal = Int64 - case float32: - retVal = Float - case float64: - retVal = Double - case bool: - retVal = Bool - default: - err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v)) - } - - return retVal, nil -} - -// TypeOf infers and returns element Go type from given tensor DType and shape -func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) { - var typ reflect.Type - if typ, err = ToGoType(dt); err != nil { - return nil, err +func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) { + o := DefaultDTypeOptions() + for _, opt := range opts { + opt(o) } switch { - case len(shape) == 0: - return typ, nil - case len(shape) > 0: - return reflect.SliceOf(typ), nil + case kind == reflect.Uint16 && o.HalfDTypePref == Half: + return Half, nil + case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16: + return BFloat16, nil + case kind == reflect.Int8 && o.Quantized: + return QInt8, nil + case kind == reflect.Uint8 && o.Quantized: + return QUInt8, nil + case kind == reflect.Int32 && o.Quantized: + return QInt32, nil + default: - err = fmt.Errorf("Unsupported data type.") - return nil, err + dtype, ok := goKind2DType[kind] + if !ok { + err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind) + return Invalid, err + } + return dtype, nil } } -/* - * // TypeCheck checks whether data Go type matching DType - * func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) { - * dataValue := reflect.ValueOf(data) - * var dataType reflect.Type - * var err error - * dataType, err = elementType(dataValue) - * if err != nil { - * msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) - * msg += err.Error() - * return false, msg - * } - * - * matched = dataType == dtype.Type - * msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) - * - * return matched, msg - * } - * */ - -var supportedTypes = map[reflect.Kind]bool{ - reflect.Uint8: true, - reflect.Int8: true, - reflect.Int16: true, - reflect.Int32: true, - reflect.Int64: true, - reflect.Float32: true, - reflect.Float64: true, - reflect.Bool: true, +var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{ + Uint8: reflect.TypeOf(uint8(0)), + Int8: reflect.TypeOf(int8(0)), + Int16: reflect.TypeOf(int16(0)), + Int: reflect.TypeOf(int(0)), + Int64: reflect.TypeOf(int64(0)), + Half: reflect.TypeOf(uint16(0)), // <- just uint16 + Float: reflect.TypeOf(float32(0)), + Double: reflect.TypeOf(float64(0)), + // ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64? 
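+	// NOTE. Half and BFloat16 map to uint16 because Go has no native 16-bit float type;
+	// the raw bits can be converted with the half package added in this patch.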
+	ComplexFloat:  reflect.TypeOf(complex64(0)),
+	ComplexDouble: reflect.TypeOf(complex128(0)),
+	Bool:          reflect.TypeOf(true),
+	QInt8:         reflect.TypeOf(int8(0)),
+	QUInt8:        reflect.TypeOf(uint8(0)),
+	QInt32:        reflect.TypeOf(int32(0)),
+	BFloat16:      reflect.TypeOf(uint16(0)), // <- just uint16
+	// ---not implemented ---
+	QUInt4x2: reflect.TypeOf(int8(0)),
+	QUInt2x4: reflect.TypeOf(uint8(0)),
+	Bits1x8:  reflect.TypeOf(uint8(0)),
+	Bits2x4:  reflect.TypeOf(uint8(0)),
+	Bits4x2:  reflect.TypeOf(uint8(0)),
+	Bits8:    reflect.TypeOf(uint8(0)),
+	Bits16:   reflect.TypeOf(uint16(0)),
 }

-var scalarTypes = map[reflect.Kind]bool{
-	reflect.Bool:       true,
-	reflect.Int:        true,
-	reflect.Int8:       true,
-	reflect.Int16:      true,
-	reflect.Int32:      true,
-	reflect.Int64:      true,
-	reflect.Uint:       true,
-	reflect.Uint8:      true,
-	reflect.Uint16:     true,
-	reflect.Uint32:     true,
-	reflect.Uint64:     true,
-	reflect.Uintptr:    true,
-	reflect.Float32:    true,
-	reflect.Float64:    true,
-	reflect.Complex64:  true,
-	reflect.Complex128: true,
+func (dt DType) GoType() (reflect.Type, error) {
+	typ, ok := dtype2GoType[dt]
+	if !ok {
+		err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", dt)
+		return nil, err
+	}
+
+	return typ, nil
 }

-// IsSupportedScalar checks whether given SCALAR type is supported
-// TODO: check input is a scalar.
-func IsSupportedScalar(k reflect.Kind) bool {
-	// if _, ok := scalarTypes[k]; !ok {
-	// log.Fatalf("Input type: %v is not a Go scalar type.", k)
-	// }
-
-	_, retVal := supportedTypes[k]
-
-	return retVal
+var dtypeNames map[DType]string = map[DType]string{
+	Uint8:  "Uint8",
+	Int8:   "Int8",
+	Int16:  "Int16",
+	Int:    "Int",
+	Int64:  "Int64",
+	Half:   "Half",
+	Float:  "Float",
+	Double: "Double",
+	// ComplexHalf: no name entry yet (no Go equivalent)
+	ComplexFloat:  "ComplexFloat",
+	ComplexDouble: "ComplexDouble",
+	Bool:          "Bool",
+	QInt8:         "QInt8",
+	QUInt8:        "QUInt8",
+	QInt32:        "QInt32",
+	BFloat16:      "BFloat16",
+	// ---not implemented ---
+	QUInt4x2: "QUInt4x2",
+	QUInt2x4: "QUInt2x4",
+	Bits1x8:  "Bits1x8",
+	Bits2x4:  "Bits2x4",
+	Bits4x2:  "Bits4x2",
+	Bits8:    "Bits8",
+	Bits16:   "Bits16",
 }
+
+func (dt DType) String() string {
+	return dtypeNames[dt]
 }
+
+func DTypeFromData(data interface{}) (DType, error) {
+	dataKind := reflect.TypeOf(data).Kind()
+
+	// Data is a slice/array
+	if dataKind == reflect.Slice || dataKind == reflect.Array {
+		elementKind := reflect.TypeOf(data).Elem().Kind()
+		return GoKind2DType(elementKind)
+	}
+
+	// single element
+	return GoKind2DType(dataKind)
 }
diff --git a/file-util.go b/file-util.go
index 85b0208..6810d73 100644
--- a/file-util.go
+++ b/file-util.go
@@ -113,10 +113,15 @@ var ModelUrls map[string]string = map[string]string{
 // 1. Resolves input string to a fullpath cached filename candidate.
 // 2. Check it at `CachedDir`, if exists, then return the candidate. If not
 // 3. Retrieves and Caches data to `CachedDir` and returns path to cached data
-func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
+func CachedPath(filenameOrUrl string, folderOpt ...string) (resolvedPath string, err error) {
 	filename := path.Base(filenameOrUrl)

 	// Resolves to "candidate" filename at `CachedDir`
-	cachedFileCandidate := fmt.Sprintf("%s/%s", CachedDir, filename)
+	fullPath := CachedDir
+	if len(folderOpt) > 0 {
+		fullPath = fmt.Sprintf("%v/%v", CachedDir, folderOpt[0])
+	}
+
+	cachedFileCandidate := fmt.Sprintf("%s/%s", fullPath, filename)

 	// 1.
Cached candidate file exists if _, err := os.Stat(cachedFileCandidate); err == nil { diff --git a/go.sum b/go.sum index 6bab717..98862f5 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,31 @@ github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 h1:QelT11PB4FXiDEXucrfNckHoFxwt8USGY1ajP1ZF5lM= golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI= golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/half/bfloat16.go b/half/bfloat16.go new file mode 100644 index 0000000..c81a8c5 --- /dev/null +++ b/half/bfloat16.go @@ -0,0 +1,167 @@ +package half + +import ( + "math" + "math/bits" +) + +// A 16-bit floating point type 
implementing the bfloat16 format. +// Ref. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format +// https://github.com/starkat99/half-rs/tree/main/src/bfloat + +// The bfloat16 - Google 'brain' floating point format is a truncated 16-bit version of the IEEE 754 standard binary32. +// bfloat16 has approximately the same dynamic range as float32 (8 bits -> 3.4 × 10^38) by having a lower precision than float16. +// While float16 has a precision of 10 bits, bfloat16 has a precision of only 7 bits. +// +// +------------+------------------------+----------------------------+ +// | 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) | +// +------------+------------------------+----------------------------+ +type BFloat16 uint16 + +// Ref.https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1 +func Float32ToBFloat16(value float32) uint16 { + // convert to raw bytes + x := math.Float32bits(value) + + // Check for NaN + if (x & 0x7FFF_FFFF) > 0x7F80_0000 { + // keep high part of current mantissa but also set most significant mantissa bit + return uint16((x >> 16) | 0x0040) + } + + // Round and shift + var roundBit uint32 = 0x0000_8000 + if ((x & roundBit) != 0) && ((x & (3*roundBit - 1)) != 0) { + return uint16(x>>16) + 1 + } else { + return uint16(x >> 16) + } +} + +func Float64ToBFloat16(value float64) uint16 { + // Convert o raw bytes, truncating the last 32-bits of mantissa + // that precision will always be lost on half-precision + val := math.Float64bits(value) + x := uint32(val >> 32) + + // Extract IEEE754 components + sign := x & 0x8000_0000 + exp := x & 0x7FF0_0000 + man := x & 0x000F_FFFF + + // Check for all exponent bit being set, which is Infinity or NaN + if exp == 0x7FF0_0000 { + // Set mantissa MSB for NaN and also keep shifted mantissa bits. + // Also check the last 32 bits. 
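+		// (A NaN whose payload sits only in the low 32 bits of the float64 mantissa would
+		// otherwise collapse to an infinity bit pattern, so keep the quiet bit in that case.)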
+ var nanBit uint32 = 0x0040 + if man == 0 && (uint32(val) == 0) { + nanBit = 0 + } + + return uint16((sign >> 16) | 0x7F80 | nanBit | (man >> 13)) + } + + // The number is normalized, start assembling half precision version + halfSign := sign >> 16 + + // Unbias the exponent, then bias for bfloat16 precision + unbiasedExp := (int64(exp>>20) - 1023) + halfExp := unbiasedExp + 127 + + // Check for exponent overflow, return +infinity + if halfExp >= 0xFF { + return uint16(halfSign | 0x7F80) + } + + // Check for underflow + if halfExp <= 0 { + // Check mantissa for what we can do + if 7-halfExp > 21 { + // No rounding possibility, so this is a full underflow, return signed zero + return uint16(halfSign) + } + + // Don't forget about hidden leading mantissa bit when assembling mantissa + man = man | 0x0010_0000 + halfMan := man >> (14 - halfExp) + + // Check for rounding + var roundBit uint32 = 1 << (13 - halfExp) + if ((man & roundBit) != 0) && ((man & (3*roundBit - 1)) != 0) { + halfMan += 1 + } + + // No exponent for subnormals + return uint16(halfSign | halfMan) + } + + // Rebias the exponent + halfExp1 := uint32(halfExp) << 7 + halfMan1 := man >> 13 + + // Check for rounding + var roundBit1 uint32 = 0x0000_1000 + + if ((man & roundBit1) != 0) && ((man & (3*roundBit1 - 1)) != 0) { + // Round it + return uint16((halfSign | halfExp1 | halfMan1) + 1) + } else { + return uint16(halfSign | halfExp1 | halfMan1) + } +} + +func BFloat16ToFloat32(i uint16) float32 { + // If NaN, keep current mantissa but also set most significant mantissa bit + if i&0x7FFF > 0x7F80 { + return math.Float32frombits((uint32(i) | 0x0040) << 16) + } else { + return math.Float32frombits(uint32(i) << 16) + } +} + +func BFloat16ToFloat64(i uint16) float64 { + // Check for signed zero + if i&0x7FFF == 0 { + return math.Float64frombits(uint64(i) << 48) + } + + halfSign := uint64(i & 0x8000) + halfExp := uint64(i & 0x7F80) + halfMan := uint64(i & 0x007F) + + // Check for an infinity or NaN when all exponent bits set + if halfExp == 0x7F80 { + // Check for signed infinity if mantissa is zero + if halfMan == 0 { + return math.Float64frombits((halfSign << 48) | 0x7FF0_0000_0000_0000) + } else { + // NaN, keep current mantissa but also set most significant mantissa bit + return math.Float64frombits((halfSign << 48) | 0x7FF8_0000_0000_0000 | (halfMan << 45)) + } + } + + // Calculate double-precision components with adjusted exponent + sign := halfSign << 48 + + // Unbias exponent + unbiasedExp := (int64(halfExp) >> 7) - 127 + + // Check for subnormals, which will be normalized by adjusting exponent + if halfExp == 0 { + // Calculate how much to adjust the exponent by + // leading zeros uint16 + e := bits.LeadingZeros16(uint16(halfMan)) - 9 + + // Rebias and adjust exponent + exp := (uint64(1023-127-e) << 52) + man := (halfMan << (46 + e)) & 0xF_FFFF_FFFF_FFFF + + return math.Float64frombits(sign | exp | man) + } + + // Rebias exponent for a normalized normal + exp := uint64(unbiasedExp+1023) << 52 + man := (halfMan & 0x007F) << 45 + + return math.Float64frombits(sign | exp | man) +} diff --git a/half/bfloat16_test.go b/half/bfloat16_test.go new file mode 100644 index 0000000..b6bd3c8 --- /dev/null +++ b/half/bfloat16_test.go @@ -0,0 +1 @@ +package half diff --git a/half/float16.go b/half/float16.go new file mode 100644 index 0000000..3a5d1f9 --- /dev/null +++ b/half/float16.go @@ -0,0 +1,303 @@ +// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker +// +// Special thanks to Kathryn Long for her Rust implementation +// of 
float16 at github.com/starkat99/half-rs (MIT license) + +// Package half defines support for half-precision floating-point numbers. +package half + +import ( + "math" + "strconv" +) + +// Float16 represents IEEE 754 half-precision floating-point numbers (binary16). +type Float16 uint16 + +// Precision indicates whether the conversion to Float16 is +// exact, subnormal without dropped bits, inexact, underflow, or overflow. +type Precision int + +const ( + + // PrecisionExact is for non-subnormals that don't drop bits during conversion. + // All of these can round-trip. Should always convert to float16. + PrecisionExact Precision = iota + + // PrecisionUnknown is for subnormals that don't drop bits during conversion but + // not all of these can round-trip so precision is unknown without more effort. + // Only 2046 of these can round-trip and the rest cannot round-trip. + PrecisionUnknown + + // PrecisionInexact is for dropped significand bits and cannot round-trip. + // Some of these are subnormals. Cannot round-trip float32->float16->float32. + PrecisionInexact + + // PrecisionUnderflow is for Underflows. Cannot round-trip float32->float16->float32. + PrecisionUnderflow + + // PrecisionOverflow is for Overflows. Cannot round-trip float32->float16->float32. + PrecisionOverflow +) + +// PrecisionFromfloat32 returns Precision without performing +// the conversion. Conversions from both Infinity and NaN +// values will always report PrecisionExact even if NaN payload +// or NaN-Quiet-Bit is lost. This function is kept simple to +// allow inlining and run < 0.5 ns/op, to serve as a fast filter. +func PrecisionFromfloat32(f32 float32) Precision { + u32 := math.Float32bits(f32) + + if u32 == 0 || u32 == 0x80000000 { + // +- zero will always be exact conversion + return PrecisionExact + } + + const COEFMASK uint32 = 0x7fffff // 23 least significant bits + const EXPSHIFT uint32 = 23 + const EXPBIAS uint32 = 127 + const EXPMASK uint32 = uint32(0xff) << EXPSHIFT + const DROPMASK uint32 = COEFMASK >> 10 + + exp := int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS) + coef := u32 & COEFMASK + + if exp == 128 { + // +- infinity or NaN + // apps may want to do extra checks for NaN separately + return PrecisionExact + } + + // https://en.wikipedia.org/wiki/Half-precision_floating-point_format says, + // "Decimals between 2^−24 (minimum positive subnormal) and 2^−14 (maximum subnormal): fixed interval 2^−24" + if exp < -24 { + return PrecisionUnderflow + } + if exp > 15 { + return PrecisionOverflow + } + if (coef & DROPMASK) != uint32(0) { + // these include subnormals and non-subnormals that dropped bits + return PrecisionInexact + } + + if exp < -14 { + // Subnormals. Caller may want to test these further. + // There are 2046 subnormals that can successfully round-trip f32->f16->f32 + // and 20 of those 2046 have 32-bit input coef == 0. + // RFC 7049 and 7049bis Draft 12 don't precisely define "preserves value" + // so some protocols and libraries will choose to handle subnormals differently + // when deciding to encode them to CBOR float32 vs float16. + return PrecisionUnknown + } + + return PrecisionExact +} + +// Frombits returns the float16 number corresponding to the IEEE 754 binary16 +// representation u16, with the sign bit of u16 and the result in the same bit +// position. Frombits(Bits(x)) == x. +func Frombits(u16 uint16) Float16 { + return Float16(u16) +} + +// Fromfloat32 returns a Float16 value converted from f32. Conversion uses +// IEEE default rounding (nearest int, with ties to even). 
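+// Inputs too large for binary16 convert to ±Inf; inputs too small convert to ±0 or a subnormal.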
+func Fromfloat32(f32 float32) Float16 { + return Float16(f32bitsToF16bits(math.Float32bits(f32))) +} + +// ErrInvalidNaNValue indicates a NaN was not received. +const ErrInvalidNaNValue = float16Error("float16: invalid NaN value, expected IEEE 754 NaN") + +type float16Error string + +func (e float16Error) Error() string { return string(e) } + +// FromNaN32ps converts nan to IEEE binary16 NaN while preserving both +// signaling and payload. Unlike Fromfloat32(), which can only return +// qNaN because it sets quiet bit = 1, this can return both sNaN and qNaN. +// If the result is infinity (sNaN with empty payload), then the +// lowest bit of payload is set to make the result a NaN. +// Returns ErrInvalidNaNValue and 0x7c01 (sNaN) if nan isn't IEEE 754 NaN. +// This function was kept simple to be able to inline. +func FromNaN32ps(nan float32) (Float16, error) { + const SNAN = Float16(uint16(0x7c01)) // signaling NaN + + u32 := math.Float32bits(nan) + sign := u32 & 0x80000000 + exp := u32 & 0x7f800000 + coef := u32 & 0x007fffff + + if (exp != 0x7f800000) || (coef == 0) { + return SNAN, ErrInvalidNaNValue + } + + u16 := uint16((sign >> 16) | uint32(0x7c00) | (coef >> 13)) + + if (u16 & 0x03ff) == 0 { + // result became infinity, make it NaN by setting lowest bit in payload + u16 |= 0x0001 + } + + return Float16(u16), nil +} + +// NaN returns a Float16 of IEEE 754 binary16 not-a-number (NaN). +// Returned NaN value 0x7e01 has all exponent bits = 1 with the +// first and last bits = 1 in the significand. This is consistent +// with Go's 64-bit math.NaN(). Canonical CBOR in RFC 7049 uses 0x7e00. +func NaN() Float16 { + return Float16(0x7e01) +} + +// Inf returns a Float16 with an infinity value with the specified sign. +// A sign >= returns positive infinity. +// A sign < 0 returns negative infinity. +func Inf(sign int) Float16 { + if sign >= 0 { + return Float16(0x7c00) + } + return Float16(0x8000 | 0x7c00) +} + +// Float32 returns a float32 converted from f (Float16). +// This is a lossless conversion. +func (f Float16) Float32() float32 { + u32 := f16bitsToF32bits(uint16(f)) + return math.Float32frombits(u32) +} + +// Bits returns the IEEE 754 binary16 representation of f, with the sign bit +// of f and the result in the same bit position. Bits(Frombits(x)) == x. +func (f Float16) Bits() uint16 { + return uint16(f) +} + +// IsNaN reports whether f is an IEEE 754 binary16 “not-a-number” value. +func (f Float16) IsNaN() bool { + return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0) +} + +// IsQuietNaN reports whether f is a quiet (non-signaling) IEEE 754 binary16 +// “not-a-number” value. +func (f Float16) IsQuietNaN() bool { + return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0) && (f&0x0200 != 0) +} + +// IsInf reports whether f is an infinity (inf). +// A sign > 0 reports whether f is positive inf. +// A sign < 0 reports whether f is negative inf. +// A sign == 0 reports whether f is either inf. +func (f Float16) IsInf(sign int) bool { + return ((f == 0x7c00) && sign >= 0) || + (f == 0xfc00 && sign <= 0) +} + +// IsFinite returns true if f is neither infinite nor NaN. +func (f Float16) IsFinite() bool { + return (uint16(f) & uint16(0x7c00)) != uint16(0x7c00) +} + +// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN. +func (f Float16) IsNormal() bool { + exp := uint16(f) & uint16(0x7c00) + return (exp != uint16(0x7c00)) && (exp != 0) +} + +// Signbit reports whether f is negative or negative zero. 
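+// It only inspects the sign bit, so it also reports true for a NaN whose sign bit is set.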
+func (f Float16) Signbit() bool { + return (uint16(f) & uint16(0x8000)) != 0 +} + +// String satisfies the fmt.Stringer interface. +func (f Float16) String() string { + return strconv.FormatFloat(float64(f.Float32()), 'f', -1, 32) +} + +// f16bitsToF32bits returns uint32 (float32 bits) converted from specified uint16. +func f16bitsToF32bits(in uint16) uint32 { + // All 65536 conversions with this were confirmed to be correct + // by Montgomery Edwards⁴⁴⁸ (github.com/x448). + + sign := uint32(in&0x8000) << 16 // sign for 32-bit + exp := uint32(in&0x7c00) >> 10 // exponenent for 16-bit + coef := uint32(in&0x03ff) << 13 // significand for 32-bit + + if exp == 0x1f { + if coef == 0 { + // infinity + return sign | 0x7f800000 | coef + } + // NaN + return sign | 0x7fc00000 | coef + } + + if exp == 0 { + if coef == 0 { + // zero + return sign + } + + // normalize subnormal numbers + exp++ + for coef&0x7f800000 == 0 { + coef <<= 1 + exp-- + } + coef &= 0x007fffff + } + + return sign | ((exp + (0x7f - 0xf)) << 23) | coef +} + +// f32bitsToF16bits returns uint16 (Float16 bits) converted from the specified float32. +// Conversion rounds to nearest integer with ties to even. +func f32bitsToF16bits(u32 uint32) uint16 { + // Translated from Rust to Go by Montgomery Edwards⁴⁴⁸ (github.com/x448). + // All 4294967296 conversions with this were confirmed to be correct by x448. + // Original Rust implementation is by Kathryn Long (github.com/starkat99) with MIT license. + + sign := u32 & 0x80000000 + exp := u32 & 0x7f800000 + coef := u32 & 0x007fffff + + if exp == 0x7f800000 { + // NaN or Infinity + nanBit := uint32(0) + if coef != 0 { + nanBit = uint32(0x0200) + } + return uint16((sign >> 16) | uint32(0x7c00) | nanBit | (coef >> 13)) + } + + halfSign := sign >> 16 + + unbiasedExp := int32(exp>>23) - 127 + halfExp := unbiasedExp + 15 + + if halfExp >= 0x1f { + return uint16(halfSign | uint32(0x7c00)) + } + + if halfExp <= 0 { + if 14-halfExp > 24 { + return uint16(halfSign) + } + c := coef | uint32(0x00800000) + halfCoef := c >> uint32(14-halfExp) + roundBit := uint32(1) << uint32(13-halfExp) + if (c&roundBit) != 0 && (c&(3*roundBit-1)) != 0 { + halfCoef++ + } + return uint16(halfSign | halfCoef) + } + + uHalfExp := uint32(halfExp) << 10 + halfCoef := coef >> 13 + roundBit := uint32(0x00001000) + if (coef&roundBit) != 0 && (coef&(3*roundBit-1)) != 0 { + return uint16((halfSign | uHalfExp | halfCoef) + 1) + } + return uint16(halfSign | uHalfExp | halfCoef) +} diff --git a/half/float16_bench_test.go b/half/float16_bench_test.go new file mode 100644 index 0000000..c1ed12a --- /dev/null +++ b/half/float16_bench_test.go @@ -0,0 +1,88 @@ +// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker + +package half_test + +import ( + "math" + "testing" + + float16 "github.com/sugarme/gotch/half" +) + +// prevent compiler optimizing out code by assigning to these +var resultF16 float16.Float16 +var resultF32 float32 +var resultStr string +var pcn float16.Precision + +func BenchmarkFloat32pi(b *testing.B) { + result := float32(0) + pi32 := float32(math.Pi) + pi16 := float16.Fromfloat32(pi32) + for i := 0; i < b.N; i++ { + f16 := float16.Frombits(uint16(pi16)) + result = f16.Float32() + } + resultF32 = result +} + +func BenchmarkFrombits(b *testing.B) { + result := float16.Float16(0) + pi32 := float32(math.Pi) + pi16 := float16.Fromfloat32(pi32) + for i := 0; i < b.N; i++ { + result = float16.Frombits(uint16(pi16)) + } + resultF16 = result +} + +func BenchmarkFromFloat32pi(b *testing.B) { + result := float16.Float16(0) + 
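+	// The package-level resultF16 sink below keeps the compiler from eliding the conversion loop.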
+ pi := float32(math.Pi) + for i := 0; i < b.N; i++ { + result = float16.Fromfloat32(pi) + } + resultF16 = result +} + +func BenchmarkFromFloat32nan(b *testing.B) { + result := float16.Float16(0) + + nan := float32(math.NaN()) + for i := 0; i < b.N; i++ { + result = float16.Fromfloat32(nan) + } + resultF16 = result +} + +func BenchmarkFromFloat32subnorm(b *testing.B) { + result := float16.Float16(0) + + subnorm := math.Float32frombits(0x007fffff) + for i := 0; i < b.N; i++ { + result = float16.Fromfloat32(subnorm) + } + resultF16 = result +} + +func BenchmarkPrecisionFromFloat32(b *testing.B) { + var result float16.Precision + + for i := 0; i < b.N; i++ { + f32 := float32(0.00001) + float32(0.00001) + result = float16.PrecisionFromfloat32(f32) + } + pcn = result +} + +func BenchmarkString(b *testing.B) { + var result string + + pi32 := float32(math.Pi) + pi16 := float16.Fromfloat32(pi32) + for i := 0; i < b.N; i++ { + result = pi16.String() + } + resultStr = result +} diff --git a/half/float16_test.go b/half/float16_test.go new file mode 100644 index 0000000..073c267 --- /dev/null +++ b/half/float16_test.go @@ -0,0 +1,798 @@ +// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker + +package half_test + +import ( + "bytes" + "crypto/sha512" + "encoding/binary" + "encoding/hex" + "fmt" + "math" + "testing" + + float16 "github.com/sugarme/gotch/half" +) + +// wantF32toF16bits is a tiny subset of expected values +var wantF32toF16bits = []struct { + in float32 + out uint16 +}{ + // generated to provide 100% code coverage plus additional tests for rounding, etc. + {in: math.Float32frombits(0x00000000), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x00000001), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x00001fff), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x00002000), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x00003fff), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x00004000), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x007fffff), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x00800000), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x33000000), out: 0x0000}, // in f32=0.000000, out f16=0 + {in: math.Float32frombits(0x33000001), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645 + {in: math.Float32frombits(0x33000002), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645 + {in: math.Float32frombits(0x387fc000), out: 0x03ff}, // in f32=0.000061, out f16=0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips + {in: math.Float32frombits(0x387fffff), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156 + {in: math.Float32frombits(0x38800000), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156 + {in: math.Float32frombits(0x38801fff), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476 + {in: math.Float32frombits(0x38802000), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476 + {in: math.Float32frombits(0x38803fff), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366 + {in: math.Float32frombits(0x38804000), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366 + {in: math.Float32frombits(0x33bfffff), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645 + {in: math.Float32frombits(0x33c00000), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929 + {in: math.Float32frombits(0x33c00001), out: 0x0002}, // in 
f32=0.000000, out f16=0.00000011920929 + {in: math.Float32frombits(0x477fffff), out: 0x7c00}, // in f32=65535.996094, out f16=+Inf + {in: math.Float32frombits(0x47800000), out: 0x7c00}, // in f32=65536.000000, out f16=+Inf + {in: math.Float32frombits(0x7f7fffff), out: 0x7c00}, // in f32=340282346638528859811704183484516925440.000000, out f16=+Inf + {in: math.Float32frombits(0x7f800000), out: 0x7c00}, // in f32=+Inf, out f16=+Inf + {in: math.Float32frombits(0x7f801fff), out: 0x7e00}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0x7f802000), out: 0x7e01}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0x7f803fff), out: 0x7e01}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0x7f804000), out: 0x7e02}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0x7fffffff), out: 0x7fff}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0x80000000), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0x80001fff), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0x80002000), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0x80003fff), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0x80004000), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0x807fffff), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0x80800000), out: 0x8000}, // in f32=-0.000000, out f16=-0 + {in: math.Float32frombits(0xb87fc000), out: 0x83ff}, // in f32=-0.000061, out f16=-0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips + {in: math.Float32frombits(0xb87fffff), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156 + {in: math.Float32frombits(0xb8800000), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156 + {in: math.Float32frombits(0xb8801fff), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476 + {in: math.Float32frombits(0xb8802000), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476 + {in: math.Float32frombits(0xb8803fff), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366 + {in: math.Float32frombits(0xb8804000), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366 + {in: math.Float32frombits(0xc77fffff), out: 0xfc00}, // in f32=-65535.996094, out f16=-Inf + {in: math.Float32frombits(0xc7800000), out: 0xfc00}, // in f32=-65536.000000, out f16=-Inf + {in: math.Float32frombits(0xff7fffff), out: 0xfc00}, // in f32=-340282346638528859811704183484516925440.000000, out f16=-Inf + {in: math.Float32frombits(0xff800000), out: 0xfc00}, // in f32=-Inf, out f16=-Inf + {in: math.Float32frombits(0xff801fff), out: 0xfe00}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0xff802000), out: 0xfe01}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0xff803fff), out: 0xfe01}, // in f32=NaN, out f16=NaN + {in: math.Float32frombits(0xff804000), out: 0xfe02}, // in f32=NaN, out f16=NaN + // additional tests + {in: math.Float32frombits(0xc77ff000), out: 0xfc00}, // in f32=-65520.000000, out f16=-Inf + {in: math.Float32frombits(0xc77fef00), out: 0xfbff}, // in f32=-65519.000000, out f16=-65504 + {in: math.Float32frombits(0xc77fee00), out: 0xfbff}, // in f32=-65518.000000, out f16=-65504 + {in: math.Float32frombits(0xc5802000), out: 0xec01}, // in f32=-4100.000000, out f16=-4100 + {in: math.Float32frombits(0xc5801800), out: 0xec01}, // in f32=-4099.000000, out f16=-4100 + {in: math.Float32frombits(0xc5801000), out: 0xec00}, // in f32=-4098.000000, out f16=-4096 + {in: math.Float32frombits(0xc5800800), 
out: 0xec00}, // in f32=-4097.000000, out f16=-4096 + {in: math.Float32frombits(0xc5800000), out: 0xec00}, // in f32=-4096.000000, out f16=-4096 + {in: math.Float32frombits(0xc57ff000), out: 0xec00}, // in f32=-4095.000000, out f16=-4096 + {in: math.Float32frombits(0xc57fe000), out: 0xebff}, // in f32=-4094.000000, out f16=-4094 + {in: math.Float32frombits(0xc57fd000), out: 0xebfe}, // in f32=-4093.000000, out f16=-4092 + {in: math.Float32frombits(0xc5002000), out: 0xe801}, // in f32=-2050.000000, out f16=-2050 + {in: math.Float32frombits(0xc5001000), out: 0xe800}, // in f32=-2049.000000, out f16=-2048 + {in: math.Float32frombits(0xc5000829), out: 0xe800}, // in f32=-2048.510010, out f16=-2048 + {in: math.Float32frombits(0xc5000800), out: 0xe800}, // in f32=-2048.500000, out f16=-2048 + {in: math.Float32frombits(0xc50007d7), out: 0xe800}, // in f32=-2048.489990, out f16=-2048 + {in: math.Float32frombits(0xc5000000), out: 0xe800}, // in f32=-2048.000000, out f16=-2048 + {in: math.Float32frombits(0xc4fff052), out: 0xe800}, // in f32=-2047.510010, out f16=-2048 + {in: math.Float32frombits(0xc4fff000), out: 0xe800}, // in f32=-2047.500000, out f16=-2048 + {in: math.Float32frombits(0xc4ffefae), out: 0xe7ff}, // in f32=-2047.489990, out f16=-2047 + {in: math.Float32frombits(0xc4ffe000), out: 0xe7ff}, // in f32=-2047.000000, out f16=-2047 + {in: math.Float32frombits(0xc4ffc000), out: 0xe7fe}, // in f32=-2046.000000, out f16=-2046 + {in: math.Float32frombits(0xc4ffa000), out: 0xe7fd}, // in f32=-2045.000000, out f16=-2045 + {in: math.Float32frombits(0xbf800000), out: 0xbc00}, // in f32=-1.000000, out f16=-1 + {in: math.Float32frombits(0xbf028f5c), out: 0xb814}, // in f32=-0.510000, out f16=-0.5097656 + {in: math.Float32frombits(0xbf000000), out: 0xb800}, // in f32=-0.500000, out f16=-0.5 + {in: math.Float32frombits(0xbefae148), out: 0xb7d7}, // in f32=-0.490000, out f16=-0.48999023 + {in: math.Float32frombits(0x3efae148), out: 0x37d7}, // in f32=0.490000, out f16=0.48999023 + {in: math.Float32frombits(0x3f000000), out: 0x3800}, // in f32=0.500000, out f16=0.5 + {in: math.Float32frombits(0x3f028f5c), out: 0x3814}, // in f32=0.510000, out f16=0.5097656 + {in: math.Float32frombits(0x3f800000), out: 0x3c00}, // in f32=1.000000, out f16=1 + {in: math.Float32frombits(0x3fbeb852), out: 0x3df6}, // in f32=1.490000, out f16=1.4902344 + {in: math.Float32frombits(0x3fc00000), out: 0x3e00}, // in f32=1.500000, out f16=1.5 + {in: math.Float32frombits(0x3fc147ae), out: 0x3e0a}, // in f32=1.510000, out f16=1.5097656 + {in: math.Float32frombits(0x3fcf1bbd), out: 0x3e79}, // in f32=1.618034, out f16=1.6181641 + {in: math.Float32frombits(0x401f5c29), out: 0x40fb}, // in f32=2.490000, out f16=2.4902344 + {in: math.Float32frombits(0x40200000), out: 0x4100}, // in f32=2.500000, out f16=2.5 + {in: math.Float32frombits(0x4020a3d7), out: 0x4105}, // in f32=2.510000, out f16=2.5097656 + {in: math.Float32frombits(0x402df854), out: 0x4170}, // in f32=2.718282, out f16=2.71875 + {in: math.Float32frombits(0x40490fdb), out: 0x4248}, // in f32=3.141593, out f16=3.140625 + {in: math.Float32frombits(0x40b00000), out: 0x4580}, // in f32=5.500000, out f16=5.5 + {in: math.Float32frombits(0x44ffa000), out: 0x67fd}, // in f32=2045.000000, out f16=2045 + {in: math.Float32frombits(0x44ffc000), out: 0x67fe}, // in f32=2046.000000, out f16=2046 + {in: math.Float32frombits(0x44ffe000), out: 0x67ff}, // in f32=2047.000000, out f16=2047 + {in: math.Float32frombits(0x44ffefae), out: 0x67ff}, // in f32=2047.489990, out f16=2047 + {in: 
math.Float32frombits(0x44fff000), out: 0x6800}, // in f32=2047.500000, out f16=2048 + {in: math.Float32frombits(0x44fff052), out: 0x6800}, // in f32=2047.510010, out f16=2048 + {in: math.Float32frombits(0x45000000), out: 0x6800}, // in f32=2048.000000, out f16=2048 + {in: math.Float32frombits(0x450007d7), out: 0x6800}, // in f32=2048.489990, out f16=2048 + {in: math.Float32frombits(0x45000800), out: 0x6800}, // in f32=2048.500000, out f16=2048 + {in: math.Float32frombits(0x45000829), out: 0x6800}, // in f32=2048.510010, out f16=2048 + {in: math.Float32frombits(0x45001000), out: 0x6800}, // in f32=2049.000000, out f16=2048 + {in: math.Float32frombits(0x450017d7), out: 0x6801}, // in f32=2049.489990, out f16=2050 + {in: math.Float32frombits(0x45001800), out: 0x6801}, // in f32=2049.500000, out f16=2050 + {in: math.Float32frombits(0x45001829), out: 0x6801}, // in f32=2049.510010, out f16=2050 + {in: math.Float32frombits(0x45002000), out: 0x6801}, // in f32=2050.000000, out f16=2050 + {in: math.Float32frombits(0x45003000), out: 0x6802}, // in f32=2051.000000, out f16=2052 + {in: math.Float32frombits(0x457fd000), out: 0x6bfe}, // in f32=4093.000000, out f16=4092 + {in: math.Float32frombits(0x457fe000), out: 0x6bff}, // in f32=4094.000000, out f16=4094 + {in: math.Float32frombits(0x457ff000), out: 0x6c00}, // in f32=4095.000000, out f16=4096 + {in: math.Float32frombits(0x45800000), out: 0x6c00}, // in f32=4096.000000, out f16=4096 + {in: math.Float32frombits(0x45800800), out: 0x6c00}, // in f32=4097.000000, out f16=4096 + {in: math.Float32frombits(0x45801000), out: 0x6c00}, // in f32=4098.000000, out f16=4096 + {in: math.Float32frombits(0x45801800), out: 0x6c01}, // in f32=4099.000000, out f16=4100 + {in: math.Float32frombits(0x45802000), out: 0x6c01}, // in f32=4100.000000, out f16=4100 + {in: math.Float32frombits(0x45ad9c00), out: 0x6d6d}, // in f32=5555.500000, out f16=5556 + {in: math.Float32frombits(0x45ffe800), out: 0x6fff}, // in f32=8189.000000, out f16=8188 + {in: math.Float32frombits(0x45fff000), out: 0x7000}, // in f32=8190.000000, out f16=8192 + {in: math.Float32frombits(0x45fff800), out: 0x7000}, // in f32=8191.000000, out f16=8192 + {in: math.Float32frombits(0x46000000), out: 0x7000}, // in f32=8192.000000, out f16=8192 + {in: math.Float32frombits(0x46000400), out: 0x7000}, // in f32=8193.000000, out f16=8192 + {in: math.Float32frombits(0x46000800), out: 0x7000}, // in f32=8194.000000, out f16=8192 + {in: math.Float32frombits(0x46000c00), out: 0x7000}, // in f32=8195.000000, out f16=8192 + {in: math.Float32frombits(0x46001000), out: 0x7000}, // in f32=8196.000000, out f16=8192 + {in: math.Float32frombits(0x46001400), out: 0x7001}, // in f32=8197.000000, out f16=8200 + {in: math.Float32frombits(0x46001800), out: 0x7001}, // in f32=8198.000000, out f16=8200 + {in: math.Float32frombits(0x46001c00), out: 0x7001}, // in f32=8199.000000, out f16=8200 + {in: math.Float32frombits(0x46002000), out: 0x7001}, // in f32=8200.000000, out f16=8200 + {in: math.Float32frombits(0x46002400), out: 0x7001}, // in f32=8201.000000, out f16=8200 + {in: math.Float32frombits(0x46002800), out: 0x7001}, // in f32=8202.000000, out f16=8200 + {in: math.Float32frombits(0x46002c00), out: 0x7001}, // in f32=8203.000000, out f16=8200 + {in: math.Float32frombits(0x46003000), out: 0x7002}, // in f32=8204.000000, out f16=8208 + {in: math.Float32frombits(0x467fec00), out: 0x73ff}, // in f32=16379.000000, out f16=16376 + {in: math.Float32frombits(0x467ff000), out: 0x7400}, // in f32=16380.000000, out f16=16384 + {in: 
math.Float32frombits(0x467ff400), out: 0x7400}, // in f32=16381.000000, out f16=16384 + {in: math.Float32frombits(0x467ff800), out: 0x7400}, // in f32=16382.000000, out f16=16384 + {in: math.Float32frombits(0x467ffc00), out: 0x7400}, // in f32=16383.000000, out f16=16384 + {in: math.Float32frombits(0x46800000), out: 0x7400}, // in f32=16384.000000, out f16=16384 + {in: math.Float32frombits(0x46800200), out: 0x7400}, // in f32=16385.000000, out f16=16384 + {in: math.Float32frombits(0x46800400), out: 0x7400}, // in f32=16386.000000, out f16=16384 + {in: math.Float32frombits(0x46800600), out: 0x7400}, // in f32=16387.000000, out f16=16384 + {in: math.Float32frombits(0x46800800), out: 0x7400}, // in f32=16388.000000, out f16=16384 + {in: math.Float32frombits(0x46800a00), out: 0x7400}, // in f32=16389.000000, out f16=16384 + {in: math.Float32frombits(0x46800c00), out: 0x7400}, // in f32=16390.000000, out f16=16384 + {in: math.Float32frombits(0x46800e00), out: 0x7400}, // in f32=16391.000000, out f16=16384 + {in: math.Float32frombits(0x46801000), out: 0x7400}, // in f32=16392.000000, out f16=16384 + {in: math.Float32frombits(0x46801200), out: 0x7401}, // in f32=16393.000000, out f16=16400 + {in: math.Float32frombits(0x46801400), out: 0x7401}, // in f32=16394.000000, out f16=16400 + {in: math.Float32frombits(0x46801600), out: 0x7401}, // in f32=16395.000000, out f16=16400 + {in: math.Float32frombits(0x46801800), out: 0x7401}, // in f32=16396.000000, out f16=16400 + {in: math.Float32frombits(0x46801a00), out: 0x7401}, // in f32=16397.000000, out f16=16400 + {in: math.Float32frombits(0x46801c00), out: 0x7401}, // in f32=16398.000000, out f16=16400 + {in: math.Float32frombits(0x46801e00), out: 0x7401}, // in f32=16399.000000, out f16=16400 + {in: math.Float32frombits(0x46802000), out: 0x7401}, // in f32=16400.000000, out f16=16400 + {in: math.Float32frombits(0x46802200), out: 0x7401}, // in f32=16401.000000, out f16=16400 + {in: math.Float32frombits(0x46802400), out: 0x7401}, // in f32=16402.000000, out f16=16400 + {in: math.Float32frombits(0x46802600), out: 0x7401}, // in f32=16403.000000, out f16=16400 + {in: math.Float32frombits(0x46802800), out: 0x7401}, // in f32=16404.000000, out f16=16400 + {in: math.Float32frombits(0x46802a00), out: 0x7401}, // in f32=16405.000000, out f16=16400 + {in: math.Float32frombits(0x46802c00), out: 0x7401}, // in f32=16406.000000, out f16=16400 + {in: math.Float32frombits(0x46802e00), out: 0x7401}, // in f32=16407.000000, out f16=16400 + {in: math.Float32frombits(0x46803000), out: 0x7402}, // in f32=16408.000000, out f16=16416 + {in: math.Float32frombits(0x46ffee00), out: 0x77ff}, // in f32=32759.000000, out f16=32752 + {in: math.Float32frombits(0x46fff000), out: 0x7800}, // in f32=32760.000000, out f16=32768 + {in: math.Float32frombits(0x46fff200), out: 0x7800}, // in f32=32761.000000, out f16=32768 + {in: math.Float32frombits(0x46fff400), out: 0x7800}, // in f32=32762.000000, out f16=32768 + {in: math.Float32frombits(0x46fff600), out: 0x7800}, // in f32=32763.000000, out f16=32768 + {in: math.Float32frombits(0x46fff800), out: 0x7800}, // in f32=32764.000000, out f16=32768 + {in: math.Float32frombits(0x46fffa00), out: 0x7800}, // in f32=32765.000000, out f16=32768 + {in: math.Float32frombits(0x46fffc00), out: 0x7800}, // in f32=32766.000000, out f16=32768 + {in: math.Float32frombits(0x46fffe00), out: 0x7800}, // in f32=32767.000000, out f16=32768 + {in: math.Float32frombits(0x47000000), out: 0x7800}, // in f32=32768.000000, out f16=32768 + {in: 
math.Float32frombits(0x47000100), out: 0x7800}, // in f32=32769.000000, out f16=32768 + {in: math.Float32frombits(0x47000200), out: 0x7800}, // in f32=32770.000000, out f16=32768 + {in: math.Float32frombits(0x47000300), out: 0x7800}, // in f32=32771.000000, out f16=32768 + {in: math.Float32frombits(0x47000400), out: 0x7800}, // in f32=32772.000000, out f16=32768 + {in: math.Float32frombits(0x47000500), out: 0x7800}, // in f32=32773.000000, out f16=32768 + {in: math.Float32frombits(0x47000600), out: 0x7800}, // in f32=32774.000000, out f16=32768 + {in: math.Float32frombits(0x47000700), out: 0x7800}, // in f32=32775.000000, out f16=32768 + {in: math.Float32frombits(0x47000800), out: 0x7800}, // in f32=32776.000000, out f16=32768 + {in: math.Float32frombits(0x47000900), out: 0x7800}, // in f32=32777.000000, out f16=32768 + {in: math.Float32frombits(0x47000a00), out: 0x7800}, // in f32=32778.000000, out f16=32768 + {in: math.Float32frombits(0x47000b00), out: 0x7800}, // in f32=32779.000000, out f16=32768 + {in: math.Float32frombits(0x47000c00), out: 0x7800}, // in f32=32780.000000, out f16=32768 + {in: math.Float32frombits(0x47000d00), out: 0x7800}, // in f32=32781.000000, out f16=32768 + {in: math.Float32frombits(0x47000e00), out: 0x7800}, // in f32=32782.000000, out f16=32768 + {in: math.Float32frombits(0x47000f00), out: 0x7800}, // in f32=32783.000000, out f16=32768 + {in: math.Float32frombits(0x47001000), out: 0x7800}, // in f32=32784.000000, out f16=32768 + {in: math.Float32frombits(0x47001100), out: 0x7801}, // in f32=32785.000000, out f16=32800 + {in: math.Float32frombits(0x47001200), out: 0x7801}, // in f32=32786.000000, out f16=32800 + {in: math.Float32frombits(0x47001300), out: 0x7801}, // in f32=32787.000000, out f16=32800 + {in: math.Float32frombits(0x47001400), out: 0x7801}, // in f32=32788.000000, out f16=32800 + {in: math.Float32frombits(0x47001500), out: 0x7801}, // in f32=32789.000000, out f16=32800 + {in: math.Float32frombits(0x47001600), out: 0x7801}, // in f32=32790.000000, out f16=32800 + {in: math.Float32frombits(0x47001700), out: 0x7801}, // in f32=32791.000000, out f16=32800 + {in: math.Float32frombits(0x47001800), out: 0x7801}, // in f32=32792.000000, out f16=32800 + {in: math.Float32frombits(0x47001900), out: 0x7801}, // in f32=32793.000000, out f16=32800 + {in: math.Float32frombits(0x47001a00), out: 0x7801}, // in f32=32794.000000, out f16=32800 + {in: math.Float32frombits(0x47001b00), out: 0x7801}, // in f32=32795.000000, out f16=32800 + {in: math.Float32frombits(0x47001c00), out: 0x7801}, // in f32=32796.000000, out f16=32800 + {in: math.Float32frombits(0x47001d00), out: 0x7801}, // in f32=32797.000000, out f16=32800 + {in: math.Float32frombits(0x47001e00), out: 0x7801}, // in f32=32798.000000, out f16=32800 + {in: math.Float32frombits(0x47001f00), out: 0x7801}, // in f32=32799.000000, out f16=32800 + {in: math.Float32frombits(0x47002000), out: 0x7801}, // in f32=32800.000000, out f16=32800 + {in: math.Float32frombits(0x47002100), out: 0x7801}, // in f32=32801.000000, out f16=32800 + {in: math.Float32frombits(0x47002200), out: 0x7801}, // in f32=32802.000000, out f16=32800 + {in: math.Float32frombits(0x47002300), out: 0x7801}, // in f32=32803.000000, out f16=32800 + {in: math.Float32frombits(0x47002400), out: 0x7801}, // in f32=32804.000000, out f16=32800 + {in: math.Float32frombits(0x47002500), out: 0x7801}, // in f32=32805.000000, out f16=32800 + {in: math.Float32frombits(0x47002600), out: 0x7801}, // in f32=32806.000000, out f16=32800 + {in: 
math.Float32frombits(0x47002700), out: 0x7801}, // in f32=32807.000000, out f16=32800 + {in: math.Float32frombits(0x47002800), out: 0x7801}, // in f32=32808.000000, out f16=32800 + {in: math.Float32frombits(0x47002900), out: 0x7801}, // in f32=32809.000000, out f16=32800 + {in: math.Float32frombits(0x47002a00), out: 0x7801}, // in f32=32810.000000, out f16=32800 + {in: math.Float32frombits(0x47002b00), out: 0x7801}, // in f32=32811.000000, out f16=32800 + {in: math.Float32frombits(0x47002c00), out: 0x7801}, // in f32=32812.000000, out f16=32800 + {in: math.Float32frombits(0x47002d00), out: 0x7801}, // in f32=32813.000000, out f16=32800 + {in: math.Float32frombits(0x47002e00), out: 0x7801}, // in f32=32814.000000, out f16=32800 + {in: math.Float32frombits(0x47002f00), out: 0x7801}, // in f32=32815.000000, out f16=32800 + {in: math.Float32frombits(0x47003000), out: 0x7802}, // in f32=32816.000000, out f16=32832 + {in: math.Float32frombits(0x477fe500), out: 0x7bff}, // in f32=65509.000000, out f16=65504 + {in: math.Float32frombits(0x477fe100), out: 0x7bff}, // in f32=65505.000000, out f16=65504 + {in: math.Float32frombits(0x477fee00), out: 0x7bff}, // in f32=65518.000000, out f16=65504 + {in: math.Float32frombits(0x477fef00), out: 0x7bff}, // in f32=65519.000000, out f16=65504 + {in: math.Float32frombits(0x477feffd), out: 0x7bff}, // in f32=65519.988281, out f16=65504 + {in: math.Float32frombits(0x477ff000), out: 0x7c00}, // in f32=65520.000000, out f16=+Inf +} + +func TestPrecisionFromfloat32(t *testing.T) { + for i, v := range wantF32toF16bits { + f16 := float16.Fromfloat32(v.in) + u16 := uint16(f16) + + if u16 != v.out { + t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16) + } + + checkPrecision(t, v.in, f16, uint64(i)) + } + + f32 := float32(5.5) // value that doesn't drop any bits in the significand, is within normal exponent range + pre := float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionExact { + t.Errorf("f32bits=0x%08x, wanted=PrecisionExact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionExact, pre) + } + + f32 = math.Float32frombits(0x38000000) // subnormal value with coef = 0 that can round-trip float32->float16->float32 + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionUnknown { + t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre) + } + + f32 = math.Float32frombits(0x387fc000) // subnormal value with coef !=0 that can round-trip float32->float16->float32 + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionUnknown { + t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre) + } + + f32 = math.Float32frombits(0x33c00000) // subnormal value with no dropped bits that cannot round-trip float32->float16->float32 + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionUnknown { + t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre) + } + + f32 = math.Float32frombits(0x38000001) // subnormal value with dropped non-zero bits > 0 + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionInexact { + t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre) + } + + f32 = float32(math.Pi) // value that cannot "preserve value" because it drops bits in the significand + pre = 
float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionInexact { + t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre) + } + + f32 = math.Float32frombits(0x1) // value that will underflow + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionUnderflow { + t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre) + } + + f32 = math.Float32frombits(0x33000000) // value that will underflow + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionUnderflow { + t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre) + } + + f32 = math.Float32frombits(0x47800000) // value that will overflow + pre = float16.PrecisionFromfloat32(f32) + if pre != float16.PrecisionOverflow { + t.Errorf("f32bits=0x%08x, wanted=PrecisionOverflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionOverflow, pre) + } + +} + +func TestFromNaN32ps(t *testing.T) { + for i, v := range wantF32toF16bits { + f16 := float16.Fromfloat32(v.in) + u16 := uint16(f16) + + if u16 != v.out { + t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16) + } + + checkFromNaN32ps(t, v.in, f16) + } + + // since checkFromNaN32ps rejects non-NaN input, try one here + nan, err := float16.FromNaN32ps(float32(math.Pi)) + if err != float16.ErrInvalidNaNValue { + t.Errorf("FromNaN32ps: in float32(math.Pi) wanted err float16.ErrInvalidNaNValue, got err = %q", err) + } + if err.Error() != "float16: invalid NaN value, expected IEEE 754 NaN" { + t.Errorf("unexpected string value returned by err.Error() for ErrInvalidNaNValue: %s", err.Error()) + } + if uint16(nan) != 0x7c01 { // signaling NaN + t.Errorf("FromNaN32ps: in float32(math.Pi) wanted nan = 0x7c01, got nan = 0x%04x", uint16(nan)) + } + +} + +// Test a small subset of possible conversions from float32 to Float16. +// TestSomeFromFloat32 runs in under 1 second while TestAllFromFloat32 takes about 45 seconds. +func TestSomeFromFloat32(t *testing.T) { + + for i, v := range wantF32toF16bits { + f16 := float16.Fromfloat32(v.in) + u16 := uint16(f16) + + if u16 != v.out { + t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16) + } + } +} + +// Test all possible 4294967296 float32 input values and results for +// Fromfloat32(), FromNaN32ps(), and PrecisionFromfloat32(). 
+func TestAllFromFloat32(t *testing.T) { + + if testing.Short() { + t.Skip("skipping TestAllFromFloat32 in short mode.") + } + + fmt.Printf("WARNING: TestAllFromFloat32 should take about 1-2 minutes to run on amd64, other platforms may take longer...\n") + + // Blake2b is "3f310bc5608a087462d361644fe66feeb4c68145f6f18eb6f1439cd7914888b6df9e30ae5350dce0635162cc6a2f23b31b3e4353ca132a3c552bdbd58baa54e6" + const wantSHA512 = "08670429a475164d6c4a080969e35231c77ef7069b430b5f38af22e013796b7818bbe8f5942a6ddf26de0e1dfc67d02243f483d85729ebc3762fc2948a5ca1f8" + + const batchSize uint32 = 16384 + results := make([]uint16, batchSize) + buf := new(bytes.Buffer) + h := sha512.New() + + for i := uint64(0); i < uint64(0xFFFFFFFF); i += uint64(batchSize) { + // fill results + for j := uint32(0); j < batchSize; j++ { + inF32 := math.Float32frombits(uint32(i) + j) + f16 := float16.Fromfloat32(inF32) + results[j] = uint16(f16) + checkPrecision(t, inF32, f16, i) + checkFromNaN32ps(t, inF32, f16) + } + + // convert results to []byte + err := binary.Write(buf, binary.LittleEndian, results) + if err != nil { + panic(err) + } + + // update hash with []byte of results + _, err = h.Write(buf.Bytes()) + if err != nil { + panic(err) + } + + buf.Reset() + } + + // display hash digest in hex + digest := h.Sum(nil) + gotSHA512hex := hex.EncodeToString(digest) + if gotSHA512hex != wantSHA512 { + t.Errorf("gotSHA512hex = %s", gotSHA512hex) + } +} + +// Test all 65536 conversions from float16 to float32. +// TestAllToFloat32 runs in under 1 second. +func TestAllToFloat32(t *testing.T) { + // Blake2b is "078d8e3fac9480de1493f22c8f9bfc1eb2051537c536f00f621557d70eed1af057a487c3e252f6d593769f5288d5ab66d8e9cd1adba359838802944bdb731f4d" + const wantSHA512 = "1a4ccec9fd7b6e83310c6b4958a25778cd95f8d4f88b19950e4b8d6932a955f7fbd96b1c9bd9b2a79c3a9d34d653f55e671f8f86e6a5a876660cd38479001aa6" + const batchSize uint32 = 16384 + results := make([]float32, batchSize) + buf := new(bytes.Buffer) + h := sha512.New() + + for i := uint64(0); i < uint64(0xFFFF); i += uint64(batchSize) { + // fill results + for j := uint32(0); j < batchSize; j++ { + inU16 := uint16(i) + uint16(j) + f16 := float16.Float16(inU16) + results[j] = f16.Float32() + } + + // convert results to []byte + err := binary.Write(buf, binary.LittleEndian, results) + if err != nil { + panic(err) + } + + // update hash with []byte of results + _, err = h.Write(buf.Bytes()) + if err != nil { + panic(err) + } + + buf.Reset() + } + + // display hash digest in hex + digest := h.Sum(nil) + gotSHA512hex := hex.EncodeToString(digest) + if gotSHA512hex != wantSHA512 { + t.Errorf("Float16toFloat32: gotSHA512hex = %s", gotSHA512hex) + } + +} + +func TestFrombits(t *testing.T) { + x := uint16(0x1234) + f16 := float16.Frombits(x) + if uint16(f16) != f16.Bits() || uint16(f16) != x { + t.Errorf("float16.Frombits(0x7fff) returned %04x, wanted %04x", uint16(f16), x) + } +} + +func TestNaN(t *testing.T) { + nan := float16.NaN() + if !nan.IsNaN() { + t.Errorf("nan.IsNaN() returned false, wanted true") + } +} + +func TestInf(t *testing.T) { + posInf := float16.Inf(0) + if uint16(posInf) != 0x7c00 { + t.Errorf("float16.Inf(0) returned %04x, wanted %04x", uint16(posInf), 0x7c00) + } + + posInf = float16.Inf(1) + if uint16(posInf) != 0x7c00 { + t.Errorf("float16.Inf(1) returned %04x, wanted %04x", uint16(posInf), 0x7c00) + } + + negInf := float16.Inf(-1) + if uint16(negInf) != 0xfc00 { + t.Errorf("float16.Inf(-1) returned %04x, wanted %04x", uint16(negInf), 0xfc00) + } +} + +func TestBits(t *testing.T) 
{ + x := uint16(0x1234) + f16 := float16.Frombits(x) + if uint16(f16) != f16.Bits() || f16.Bits() != x { + t.Errorf("Bits() returned %04x, wanted %04x", uint16(f16), x) + } +} + +func TestIsFinite(t *testing.T) { + // IsFinite returns true if f is neither infinite nor NaN. + + finite := float16.Fromfloat32(float32(1.5)) + if !finite.IsFinite() { + t.Errorf("finite.Infinite() returned false, wanted true") + } + + posInf := float16.Inf(0) + if posInf.IsFinite() { + t.Errorf("posInf.Infinite() returned true, wanted false") + } + + negInf := float16.Inf(-1) + if negInf.IsFinite() { + t.Errorf("negInf.Infinite() returned true, wanted false") + } + + nan := float16.NaN() + if nan.IsFinite() { + t.Errorf("nan.Infinite() returned true, wanted false") + } +} + +func TestIsNaN(t *testing.T) { + + f16 := float16.Float16(0) + if f16.IsNaN() { + t.Errorf("Float16(0).IsNaN() returned true, wanted false") + } + + f16 = float16.Float16(0x7e00) + if !f16.IsNaN() { + t.Errorf("Float16(0x7e00).IsNaN() returned false, wanted true") + } +} + +func TestIsQuietNaN(t *testing.T) { + + f16 := float16.Float16(0) + if f16.IsQuietNaN() { + t.Errorf("Float16(0).IsQuietNaN() returned true, wanted false") + } + + f16 = float16.Float16(0x7e00) + if !f16.IsQuietNaN() { + t.Errorf("Float16(0x7e00).IsQuietNaN() returned false, wanted true") + } + + f16 = float16.Float16(0x7e00 ^ 0x0200) + if f16.IsQuietNaN() { + t.Errorf("Float16(0x7e00 ^ 0x0200).IsQuietNaN() returned true, wanted false") + } +} + +func TestIsNormal(t *testing.T) { + // IsNormal returns true if f is neither zero, infinite, subnormal, or NaN. + + zero := float16.Frombits(0) + if zero.IsNormal() { + t.Errorf("zero.IsNormal() returned true, wanted false") + } + + posInf := float16.Inf(0) + if posInf.IsNormal() { + t.Errorf("posInf.IsNormal() returned true, wanted false") + } + + negInf := float16.Inf(-1) + if negInf.IsNormal() { + t.Errorf("negInf.IsNormal() returned true, wanted false") + } + + nan := float16.NaN() + if nan.IsNormal() { + t.Errorf("nan.IsNormal() returned true, wanted false") + } + + subnormal := float16.Frombits(0x0001) + if subnormal.IsNormal() { + t.Errorf("subnormal.IsNormal() returned true, wanted false") + } + + normal := float16.Fromfloat32(float32(1.5)) + if !normal.IsNormal() { + t.Errorf("normal.IsNormal() returned false, wanted true") + } + +} + +func TestSignbit(t *testing.T) { + + f16 := float16.Fromfloat32(float32(0.0)) + if f16.Signbit() { + t.Errorf("float16.Fromfloat32(float32(0)).Signbit() returned true, wanted false") + } + + f16 = float16.Fromfloat32(float32(2.0)) + if f16.Signbit() { + t.Errorf("float16.Fromfloat32(float32(2)).Signbit() returned true, wanted false") + } + + f16 = float16.Fromfloat32(float32(-2.0)) + if !f16.Signbit() { + t.Errorf("float16.Fromfloat32(float32(-2)).Signbit() returned false, wanted true") + } + +} + +func TestString(t *testing.T) { + f16 := float16.Fromfloat32(1.5) + s := f16.String() + if s != "1.5" { + t.Errorf("Float16(1.5).String() returned %s, wanted 1.5", s) + } + + f16 = float16.Fromfloat32(3.141593) + s = f16.String() + if s != "3.140625" { + t.Errorf("Float16(3.141593).String() returned %s, wanted 3.140625", s) + } + +} + +func TestIsInf(t *testing.T) { + + f16 := float16.Float16(0) + if f16.IsInf(0) { + t.Errorf("Float16(0).IsInf(0) returned true, wanted false") + } + + f16 = float16.Float16(0x7c00) + if !f16.IsInf(0) { + t.Errorf("Float16(0x7c00).IsInf(0) returned false, wanted true") + } + + f16 = float16.Float16(0x7c00) + if !f16.IsInf(1) { + t.Errorf("Float16(0x7c00).IsInf(1) 
returned false, wanted true") + } + + f16 = float16.Float16(0x7c00) + if f16.IsInf(-1) { + t.Errorf("Float16(0x7c00).IsInf(-1) returned true, wanted false") + } + + f16 = float16.Float16(0xfc00) + if !f16.IsInf(0) { + t.Errorf("Float16(0xfc00).IsInf(0) returned false, wanted true") + } + + f16 = float16.Float16(0xfc00) + if f16.IsInf(1) { + t.Errorf("Float16(0xfc00).IsInf(1) returned true, wanted false") + } + + f16 = float16.Float16(0xfc00) + if !f16.IsInf(-1) { + t.Errorf("Float16(0xfc00).IsInf(-1) returned false, wanted true") + } +} + +func float32parts(f32 float32) (exp int32, coef uint32, dropped uint32) { + const COEFMASK uint32 = 0x7fffff // 23 least significant bits + const EXPSHIFT uint32 = 23 + const EXPBIAS uint32 = 127 + const EXPMASK uint32 = uint32(0xff) << EXPSHIFT + const DROPMASK uint32 = COEFMASK >> 10 + u32 := math.Float32bits(f32) + exp = int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS) + coef = u32 & COEFMASK + dropped = coef & DROPMASK + return exp, coef, dropped +} + +func isNaN32(f32 float32) bool { + exp, coef, _ := float32parts(f32) + return (exp == 128) && (coef != 0) +} + +func isQuietNaN32(f32 float32) bool { + exp, coef, _ := float32parts(f32) + return (exp == 128) && (coef != 0) && ((coef & 0x00400000) != 0) +} + +func checkFromNaN32ps(t *testing.T, f32 float32, f16 float16.Float16) { + + if !isNaN32(f32) { + return + } + + u32 := math.Float32bits(f32) + nan16, err := float16.FromNaN32ps(f32) + + if isQuietNaN32(f32) { + // result should be the same + if err != nil { + t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err) + } + if uint16(nan16) != uint16(f16) { + t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted nan16 = %v, got nan16 = %v", u32, f32, f16, nan16) + } + } else { + // result should differ only by the signaling/quiet bit unless payload is empty + if err != nil { + t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err) + } + + coef := uint16(f16) & uint16(0x03ff) + payload := uint16(f16) & uint16(0x01ff) + diff := uint16(nan16 ^ f16) + + if payload == 0 { + // the lowest bit needed to be set to prevent turning sNaN into infinity, so 2 bits differ + if diff != 0x0201 { + t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0201, got 0x%04x", u32, f32, diff) + } + } else { + // only the quiet bit was restored, so 1 bit differs + if diff != 0x0200 { + t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0200, got 0x%04x. 
f16=0x%04x n16=0x%04x coef=0x%04x", u32, f32, diff, uint16(f16), uint16(nan16), coef) + } + } + } +} + +func checkPrecision(t *testing.T, f32 float32, f16 float16.Float16, i uint64) { + // TODO: rewrite this test when time allows + + u32 := math.Float32bits(f32) + u16 := f16.Bits() + f32bis := f16.Float32() + u32bis := math.Float32bits(f32bis) + pre := float16.PrecisionFromfloat32(f32) + roundtripped := u32 == u32bis + exp32, coef32, dropped32 := float32parts(f32) + + if roundtripped { + checkRoundTrippedPrecision(t, u32, u16, u32bis, exp32, coef32, dropped32) + return + } + + if pre == float16.PrecisionExact { + // this should only happen if both input and output are NaN + if !(f16.IsNaN() && isNaN32(f32)) { + t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionExact when roundtrip failed with non-special value", i, u32, f32, u16, u32bis, f32bis) + } + + } else if pre == float16.PrecisionUnknown { + if exp32 < -24 { + t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionUnderflow", i, u32, f32, u16, u32bis, f32bis) + } + if dropped32 != 0 { + t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionInexact", i, u32, f32, u16, u32bis, f32bis) + } + } else if pre == float16.PrecisionInexact { + checkPrecisionInexact(t, u32, u16, u32bis, exp32, coef32, dropped32) + } else if pre == float16.PrecisionUnderflow { + if exp32 >= -14 { + t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnderflow when exp32 is >= -14", i, u32, f32, u16, u32bis, f32bis) + } + } else if pre == float16.PrecisionOverflow { + if exp32 <= 15 { + t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionOverflow when exp32 is <= 15", i, u32, f32, u16, u32bis, f32bis) + } + } +} + +func checkPrecisionInexact(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) { + f32 := math.Float32frombits(u32) + f32bis := math.Float32frombits(u32bis) + + if exp32 < -24 { + t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionUnderflow", u32, f32, u16, u32bis, f32bis) + } + if exp32 > 15 { + t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionOverflow", u32, f32, u16, u32bis, f32bis) + } + if coef32 == 0 { + t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when coef32 is 0", u32, f32, u16, u32bis, f32bis) + } + if dropped32 == 0 { + t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when dropped32 is 0", u32, f32, u16, u32bis, f32bis) + } +} + +func checkRoundTrippedPrecision(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) { + f32 := math.Float32frombits(u32) + f32bis := math.Float32frombits(u32bis) + pre := float16.PrecisionFromfloat32(f32) + f16 := float16.Frombits(u16) + + if dropped32 != 0 { + t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), dropped32 != 0 with successful roundtrip", u32, f32, u16, u32bis, f32bis) + } + + if pre != float16.PrecisionExact { + // there are 2046 
values that are subnormal and can round-trip float32->float16->float32 + if pre != float16.PrecisionUnknown { + t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%032b) (%f), out f16bits=0x%04x (%v), back=0x%08x (%f), got %v, wanted PrecisionExact, exp=%d, coef=%d, drpd=%d", u32, u32, f32, u16, f16, u32bis, f32bis, pre, exp32, coef32, dropped32) + } + } + +} diff --git a/pickle/pickle_example_test.go b/pickle/pickle_example_test.go index 421951b..8a34868 100644 --- a/pickle/pickle_example_test.go +++ b/pickle/pickle_example_test.go @@ -1,6 +1,7 @@ package pickle_test import ( + "fmt" "log" "github.com/sugarme/gotch" @@ -18,11 +19,13 @@ func ExampleLoadInfo() { panic(err) } - err = pickle.LoadInfo(modelFile) + m, err := pickle.LoadModelInfo(modelFile) if err != nil { log.Fatal(err) } + fmt.Println(m) + // Output: // classifier.0.bias - [4096] // classifier.0.weight - [4096 25088] @@ -57,4 +60,25 @@ func ExampleLoadInfo() { // features.7.bias - [128] // features.7.weight - [128 128 3 3] // Num of variables: 32 + // Tensor DType: Float +} + +func ExampleModelFloat16() { + modelName := "HuggingFaceH4/tiny-random-LlamaForCausalLM" + url := "https://huggingface.co/HuggingFaceH4/tiny-random-LlamaForCausalLM/resolve/main/pytorch_model.bin" + + modelFile, err := gotch.CachedPath(url, modelName) + if err != nil { + panic(err) + } + + m, err := pickle.LoadModelInfo(modelFile) + if err != nil { + log.Fatal(err) + } + + fmt.Printf("Model DType: %v\n", m.DType()) + + // Output: + // Model DType: Half } diff --git a/pickle/serialization.go b/pickle/serialization.go index c5d9617..2a00484 100644 --- a/pickle/serialization.go +++ b/pickle/serialization.go @@ -20,6 +20,7 @@ import ( "reflect" "sort" + "github.com/sugarme/gotch" "github.com/sugarme/gotch/nn" "github.com/sugarme/gotch/ts" ) @@ -60,83 +61,41 @@ func Decode(filename string) (map[string]*ts.Tensor, error) { dictResult := *result.(*Dict) for _, item := range dictResult { name := item.Key - itemTyp := reflect.TypeOf(item.Value).String() - switch itemTyp { - case "*pickle.Dict": // Nested *pickle.Dict case - subResult := *item.Value.(*Dict) - for _, subItem := range subResult { - subName := subItem.Key - x, ok := subItem.Value.(*StorageTensor) - if !ok { - log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName) - continue - } - data := x.Source.GetData() - size := x.Size - dtype := x.Source.DType() - device := x.Source.Device() - stride := x.Stride - storageOffset := x.StorageOffset - if reflect.ValueOf(data).Len() == 0 { - log.Printf("INFO: skip weight %q with zero data length.\n", name.(string)) - continue - } - - // TODO. should we just skip them? - if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { - size = []int64{1} - stride = []int64{1} - } - - x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) - if x.RequiresGrad { - x1.MustRequiresGrad_(x.RequiresGrad) - } - - namedTensors[name.(string)] = x1 - } - - default: - sx, isStorageTensor := item.Value.(*StorageTensor) - - // if !isStorageTensor { - // err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String()) - // return nil, err - // } - if !isStorageTensor { - log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. 
Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name) - continue - } - - data := sx.Source.GetData() - size := sx.Size - dtype := sx.Source.DType() - device := sx.Source.Device() - stride := sx.Stride - storageOffset := sx.StorageOffset - - // log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) - // log.Printf("data: %v\n", data) - - // Dealing with Pytorch `..._tracked` variables. - if reflect.ValueOf(data).Len() == 0 { - log.Printf("INFO: skip weight %q with zero data length.\n", name.(string)) - continue - } - - // TODO. should we just skip them? - if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { - size = []int64{1} - stride = []int64{1} - } - - x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) - if sx.RequiresGrad { - x.MustRequiresGrad_(sx.RequiresGrad) - } - - namedTensors[name.(string)] = x + sx, isStorageTensor := item.Value.(*StorageTensor) + if !isStorageTensor { + err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String()) + return nil, err } + + data := sx.Source.GetData() + size := sx.Size + dtype := sx.Source.DType() + + device := sx.Source.Device() + stride := sx.Stride + storageOffset := sx.StorageOffset + + // log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset) + // log.Printf("data: %v\n", data) + + // Dealing with Pytorch `..._tracked` variables. + if reflect.ValueOf(data).Len() == 0 { + log.Printf("INFO: skip weight %q with zero data length.\n", name.(string)) + continue + } + + // TODO. should we just skip them? + if reflect.ValueOf(data).Len() == 1 && len(size) == 0 { + size = []int64{1} + stride = []int64{1} + } + + x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true) + if sx.RequiresGrad { + x.MustRequiresGrad_(sx.RequiresGrad) + } + + namedTensors[name.(string)] = x } case "*pickle.OrderedDict": dictResult := result.(*OrderedDict) @@ -581,35 +540,74 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) { return missingVariables, nil } -// LoadInfo loads pretrained weights and prints out name and shape of weights. 
-func LoadInfo(modelFile string) error { - weights, err := Decode(modelFile) - if err != nil { - err = fmt.Errorf("LoadInfo() failed: %w", err) - return err - } +type ModelInfor struct { + weights map[string][]int64 + dtype gotch.DType +} - layers := make([]string, 0, len(weights)) - for tsName := range weights { +func NewModelInfor(weights map[string][]int64, dtype gotch.DType) *ModelInfor { + return &ModelInfor{ + weights: weights, + dtype: dtype, + } +} + +func (m *ModelInfor) String() string { + var summary string + layers := make([]string, 0, len(m.weights)) + for tsName := range m.weights { layers = append(layers, tsName) } sort.Strings(layers) for _, l := range layers { - var x *ts.Tensor - for tsName, tsVal := range weights { + var x []int64 + for tsName, shape := range m.weights { if tsName == l { - x = tsVal + x = shape break } } - fmt.Printf("%s - %+v\n", l, x.MustSize()) + + summary += fmt.Sprintf("%s - %+v\n", l, x) } - fmt.Printf("Num of variables: %v\n", len(weights)) + summary += fmt.Sprintf("Num of variables: %v\n", len(m.weights)) + summary += fmt.Sprintf("Tensor DType: %v\n", m.dtype) - for _, x := range weights { - x.MustDrop() - } - - return nil + return summary +} + +func (m *ModelInfor) DType() gotch.DType { + return m.dtype +} + +func (m *ModelInfor) Parameters() int { + return len(m.weights) +} + +// LoadModelInfo loads pretrained weights and reports the name, shape, and DType of each weight. +func LoadModelInfo(modelFile string) (*ModelInfor, error) { + weights, err := Decode(modelFile) + if err != nil { + err = fmt.Errorf("LoadModelInfo() failed: %w", err) + return nil, err + } + + w := make(map[string][]int64) + var dtype gotch.DType + isFirst := true + for n, x := range weights { + w[n] = x.MustSize() + + if isFirst { + dtype = x.DType() + isFirst = false + } + } + + m := NewModelInfor(w, dtype) + + ts.CleanUp() + + return m, nil } diff --git a/pickle/storage.go b/pickle/storage.go index ce6aba0..45e35e6 100644 --- a/pickle/storage.go +++ b/pickle/storage.go @@ -7,6 +7,7 @@ import ( "math" "github.com/sugarme/gotch" + "github.com/sugarme/gotch/half" ) // This file implements Pytorch storage data types.
@@ -67,7 +68,8 @@ func (s *HalfStorageClass) New(size int, location string) Storage { type HalfStorage struct { BaseStorage - Data []float32 + // Data []float32 + Data []half.Float16 } var _ Storage = &HalfStorage{} @@ -77,7 +79,7 @@ func (s *HalfStorage) SetFromFile(r io.Reader) error { } func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error { - data := make([]float32, size) + data := make([]half.Float16, size) br := NewLimitedBufferReader(r, size, 2, 512) for i := 0; i < size; i++ { bytes, err := br.ReadNext() @@ -85,7 +87,7 @@ func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error { return err } u16 := binary.LittleEndian.Uint16(bytes) - data[i] = math.Float32frombits(FloatBits16to32(u16)) + data[i] = half.Float16(u16) } s.Data = data return nil @@ -96,7 +98,7 @@ func (s *HalfStorage) GetData() interface{} { } func (s *HalfStorage) DType() gotch.DType { - return gotch.Float + return gotch.Half } func (s *HalfStorage) Device() gotch.Device { @@ -108,6 +110,62 @@ func (s *HalfStorage) Device() gotch.Device { } } +// BFloat16Storage: +// ================ +type BFloat16StorageClass struct{} + +var _ StorageClass = &BFloat16StorageClass{} + +func (s *BFloat16StorageClass) New(size int, location string) Storage { + return &BFloat16Storage{ + BaseStorage: BaseStorage{Size: size, Location: location}, + Data: nil, + } +} + +type BFloat16Storage struct { + BaseStorage + Data []half.BFloat16 +} + +var _ Storage = &BFloat16Storage{} + +func (s *BFloat16Storage) SetFromFile(r io.Reader) error { + return setFromFile(s, r) +} + +func (s *BFloat16Storage) SetFromFileWithSize(r io.Reader, size int) error { + data := make([]half.BFloat16, size) + br := NewLimitedBufferReader(r, size, 2, 512) + for i := 0; i < size; i++ { + bytes, err := br.ReadNext() + if err != nil { + return err + } + u16 := binary.LittleEndian.Uint16(bytes) + data[i] = half.BFloat16(u16) + } + s.Data = data + return nil +} + +func (s *BFloat16Storage) GetData() interface{} { + return s.Data +} + +func (s *BFloat16Storage) DType() gotch.DType { + return gotch.BFloat16 +} + +func (s *BFloat16Storage) Device() gotch.Device { + switch s.Location { + case "cuda": + return gotch.CudaIfAvailable() + default: + return gotch.CPU + } +} + // FloatStorage: // ============= diff --git a/ts/iter.go b/ts/iter.go index 694f740..cc81ce3 100644 --- a/ts/iter.go +++ b/ts/iter.go @@ -26,7 +26,7 @@ func (it *Iterable) Next() (item interface{}, ok bool) { } var err error - switch it.ItemKind.Kind().String() { + switch it.ItemKind.GoKind().String() { case "int64": item, err = it.Content.Int64Value([]int64{it.Index}) if err != nil { diff --git a/ts/npy.go b/ts/npy.go index 8c97a6a..e621f84 100644 --- a/ts/npy.go +++ b/ts/npy.go @@ -101,22 +101,22 @@ func (h *NpyHeader) ToString() (string, error) { shape := strings.Join(shapeStr, ",") var descr string - switch h.descr.Kind().String() { - // case "float16": // NOTE. No float16 in Go primary types. TODO. 
implement - // descr = "f2" - case "float32": + switch h.descr { + case gotch.Half: + descr = "f2" + case gotch.Float: descr = "f4" - case "float64": + case gotch.Double: descr = "f8" - case "int": + case gotch.Int: descr = "i4" - case "int64": + case gotch.Int64: descr = "i8" - case "int16": + case gotch.Int16: descr = "i2" - case "int8": + case gotch.Int8: descr = "i1" - case "uint8": + case gotch.Uint8: descr = "u1" default: err := fmt.Errorf("Unsupported kind: %v\n", h.descr) diff --git a/ts/print.go b/ts/print.go index fd4dd38..c2e2b99 100644 --- a/ts/print.go +++ b/ts/print.go @@ -11,41 +11,7 @@ import ( ) func (ts *Tensor) ValueGo() interface{} { - dtype := ts.DType() - numel := ts.Numel() - var dst interface{} - switch dtype { - case gotch.Uint8: - dst = make([]uint8, numel) - case gotch.Int8: - dst = make([]int8, numel) - case gotch.Int16: - dst = make([]int16, numel) - case gotch.Int: - dst = make([]int32, numel) - case gotch.Int64: - dst = make([]int64, numel) - case gotch.Float: - dst = make([]float32, numel) - case gotch.Double: - dst = make([]float64, numel) - case gotch.Bool: - dst = make([]bool, numel) - default: - err := fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType()) - log.Fatal(err) - } - err := ts.CopyData(dst, ts.Numel()) - if err != nil { - log.Fatal(err) - } - - // convert []int32 -> int - if reflect.TypeOf(dst).String() == "[]int32" { - dst = sliceInt32ToInt(dst.([]int32)) - } - - return dst + return ts.Vals() } func shapeToSize(shape []int64) int { @@ -250,8 +216,8 @@ func (f *fmtState) fmtVerb(ts *Tensor) { // var typ T typ := ts.DType() - switch typ.String() { - case "float32", "float64": + switch typ { + case gotch.Half, gotch.BFloat16, gotch.Float, gotch.Double: switch f.verb { case 'f', 'e', 'E', 'G', 'b': // accepted. Do nothing @@ -259,7 +225,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) { f.verb = 'g' } - case "uint8", "int8", "int16", "int32", "int64": + case gotch.Uint8, gotch.Int8, gotch.Int16, gotch.Int, gotch.Int64: switch f.verb { case 'b': f.base = 2 @@ -273,7 +239,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) { f.base = 10 f.verb = 'd' } - case "bool": + case gotch.Bool: f.verb = 't' default: f.verb = 'v' @@ -318,7 +284,7 @@ func (ts *Tensor) Format(s fmt.State, verb rune) { shape := toSliceInt(ts.MustSize()) strides := shapeToStrides(shape) device := ts.MustDevice() - dtype := ts.DType().String() + dtype := ts.DType() defined := ts.MustDefined() if verb == 'i' { fmt.Fprintf( diff --git a/ts/tensor.go b/ts/tensor.go index bc8c67c..d0bdbd5 100644 --- a/ts/tensor.go +++ b/ts/tensor.go @@ -273,27 +273,13 @@ func (ts *Tensor) nbytes() int64 { if numel == 0 { return 0 // ts.None } - dtype := ts.DType() - eltSizeInBytes, err := gotch.DTypeSize(dtype) - if err != nil { - log.Fatal(err) - } - nbytes := int64(numel) * int64(eltSizeInBytes) - - return nbytes + return int64(numel * ts.DType().Size()) } func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 { - // Decode sz - // 1. Count number of elements in data - elementNum := nsize - // 2. 
Element size in bytes - eltSizeInBytes, err := gotch.DTypeSize(gotch.Int64) - if err != nil { - log.Fatal(err) - } - nbytes := int(eltSizeInBytes) * int(elementNum) + dtype := gotch.Int64 // tensor size dtype = int64 + nbytes := int(nsize) * int(dtype.Size()) dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes] r := bytes.NewReader(dataSlice) dataIn := make([]int64, nsize) @@ -304,8 +290,49 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 { return dataIn } +// TensorOptions constructs options to build/rebuild tensor. +type TensorOptions struct { + Name string + DType gotch.DType + Quantized bool + // TODO. can expand as needed +} + +type TensorOpt func(*TensorOptions) + +func DefaultTensorOptions() *TensorOptions { + return &TensorOptions{ + Name: "", + DType: gotch.Float, + Quantized: false, + } +} + +func WithName(v string) TensorOpt { + return func(o *TensorOptions) { + o.Name = v + } +} + +func WithDType(v gotch.DType) TensorOpt { + return func(o *TensorOptions) { + o.DType = v + } +} + +func WithQuantized(v bool) TensorOpt { + return func(o *TensorOptions) { + o.Quantized = v + } +} + // OfSlice creates tensor from a slice data -func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) { +func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) { + o := DefaultTensorOptions() + for _, opt := range opts { + opt(o) + } + // convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size! if reflect.TypeOf(data).String() == "[]int" { data = sliceIntToInt32(data.([]int)) @@ -318,10 +345,10 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) { return nil, err } - typ := reflect.TypeOf(data).Elem() + elementKind := reflect.TypeOf(data).Elem().Kind() dataLen := v.Len() - dtype, err := gotch.ToDType(typ) + dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized)) if err != nil { return nil, err } @@ -329,12 +356,7 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) { shape := []int64{int64(dataLen)} elementNum := ElementCount(shape) - eltSizeInBytes, err := gotch.DTypeSize(dtype) - if err != nil { - return nil, err - } - - nbytes := int(eltSizeInBytes) * int(elementNum) + nbytes := int(dtype.Size()) * elementNum dataPtr, buff := CMalloc(nbytes) defer C.free(unsafe.Pointer(dataPtr)) @@ -343,30 +365,25 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) { return nil, err } - cint, err := gotch.DType2CInt(dtype) - if err != nil { - return nil, err - } - - ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint)) + ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind())) if err = TorchErr(); err != nil { return nil, err } - return newTensor(ctensor, nameOpt...), nil + return newTensor(ctensor, o.Name), nil // return newTensor(ctensor), nil } // OfDataSize creates Tensor from input byte data, shape and dtype. 
-func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string) (*Tensor, error) { - - elementNum := ElementCount(shape) - eltSizeInBytes, err := gotch.DTypeSize(dtype) - if err != nil { - return nil, err +func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) { + o := DefaultTensorOptions() + for _, opt := range opts { + opt(o) } - nbytes := int(eltSizeInBytes) * int(elementNum) + elementNum := ElementCount(shape) + + nbytes := elementNum * int(dtype.Size()) if nbytes != len(data) { err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape) @@ -380,24 +397,19 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string return nil, err } - cint, err := gotch.DType2CInt(dtype) - if err != nil { + ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind())) + if err := TorchErr(); err != nil { return nil, err } - ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint)) - if err = TorchErr(); err != nil { - return nil, err - } - - return newTensor(ctensor, nameOpt...), nil + return newTensor(ctensor, o.Name), nil // return newTensor(ctensor), nil } // MustOfDataSize create Tensor from input byte data and specified shape and dtype // or panic if error -func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor { - ts, err := OfDataSize(data, size, dtype) +func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor { + ts, err := OfDataSize(data, size, dtype, opts...) if err != nil { log.Fatal(err) } @@ -406,8 +418,8 @@ func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor { } // MustOfSlice create a tensor from slice of data. It will be panic if error. -func MustOfSlice(data interface{}) *Tensor { - ts, err := OfSlice(data) +func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor { + ts, err := OfSlice(data, opts...) if err != nil { log.Fatal(err) } @@ -416,8 +428,8 @@ func MustOfSlice(data interface{}) *Tensor { } // TensorFrom create a tensor from slice of data. It will be panic if error. -func TensorFrom(data interface{}, nameOpt ...string) *Tensor { - ts, err := OfSlice(data, nameOpt...) +func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor { + ts, err := OfSlice(data, opts...) if err != nil { log.Fatal(err) } @@ -436,7 +448,12 @@ func (ts *Tensor) Print() { } // NewTensorFromData creates tensor from given data and shape -func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Tensor, error) { +func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) { + o := DefaultTensorOptions() + for _, opt := range opts { + opt(o) + } + // 1. 
Check whether data and shape match elementNum, err := DataDim(data) if err != nil { @@ -463,33 +480,18 @@ func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Ten return nil, err } - eltSizeInBytes, err := gotch.DTypeSize(dtype) - if err != nil { - return nil, err - } - - cint, err := gotch.DType2CInt(dtype) - if err != nil { - return nil, err - } - - ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint)) + ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind())) if err = TorchErr(); err != nil { return nil, err } - return newTensor(ctensor, nameOpt...), nil + return newTensor(ctensor, o.Name), nil } func (ts *Tensor) DType() gotch.DType { cint := lib.AtScalarType(ts.ctensor) - dtype, err := gotch.CInt2DType(cint) - if err != nil { - log.Fatalf("Tensor DType error: %v\n", err) - } - - return dtype + return gotch.CKind2DType(cint) } func (ts *Tensor) Device() (gotch.Device, error) { @@ -545,6 +547,7 @@ func (ts *Tensor) MustDevice() gotch.Device { * return retVal * } * */ + // Float64Value returns a float value on tensors holding a single element. // An error is returned otherwise. // double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len); @@ -748,7 +751,6 @@ func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGra // // NOTE: `dst` located in Go memory. Should it be? func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error { - // NOTE: we must make sure that `dst` has same len as `numel`. Otherwise, // there will be memory leak and or out of range error. if len(dst) < int(numel) { @@ -757,12 +759,9 @@ func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error { } vs := unsafe.Pointer(&dst[0]) - elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8) - if err != nil { - return err - } - lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes) - if err = TorchErr(); err != nil { + dtype := gotch.Uint8 + lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size()) + if err := TorchErr(); err != nil { return err } @@ -784,55 +783,20 @@ func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) { // and number of elements to C land. This may break in the future // if Go policy changes. 
func (ts *Tensor) CopyData(dst interface{}, numel uint) error { - - gotype, dlen, err := DataCheck(dst) - if err != nil { - return err - } - - dtype, err := gotch.ToDType(gotype) - if err != nil { - return err - } - + dtype, dlen, err := DataCheck(dst) + if err != nil { + return err + } if dlen < int(numel) { - err = fmt.Errorf("CopyData Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel) + err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel) return err } - if ts.DType() != dtype { - err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType()) + err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType()) return err } - var vs unsafe.Pointer - switch dtype { - case gotch.Uint8: - vs = unsafe.Pointer(&dst.([]uint8)[0]) - case gotch.Int8: - vs = unsafe.Pointer(&dst.([]int8)[0]) - case gotch.Int16: - vs = unsafe.Pointer(&dst.([]int16)[0]) - case gotch.Int: - vs = unsafe.Pointer(&dst.([]int32)[0]) - case gotch.Int64: - vs = unsafe.Pointer(&dst.([]int64)[0]) - case gotch.Float: - vs = unsafe.Pointer(&dst.([]float32)[0]) - case gotch.Double: - vs = unsafe.Pointer(&dst.([]float64)[0]) - case gotch.Bool: - vs = unsafe.Pointer(&dst.([]bool)[0]) - default: - err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType()) - return err - } + // Get data pointer + dataPtr := reflect.ValueOf(dst).UnsafePointer() - elt_size_in_bytes, err := gotch.DTypeSize(dtype) - if err != nil { - return err - } - lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes) + lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size()) if err = TorchErr(); err != nil { return err } @@ -863,7 +827,6 @@ func (ts *Tensor) Numel() uint { // ShallowClone returns a new tensor that share storage with the input tensor. func (ts *Tensor) ShallowClone() (*Tensor, error) { - ctensor := lib.AtShallowClone(ts.ctensor) if err := TorchErr(); err != nil { @@ -1309,33 +1272,18 @@ func (ts *Tensor) Int64Values(delOpt ...bool) []int64 { // E.g. res := xs.Vals().([]int64) func (ts *Tensor) Vals() interface{} { dtype := ts.DType() - numel := ts.Numel() + numel := int(ts.Numel()) - var retVal interface{} - - switch dtype.Name() { - case "uint8": - retVal = make([]uint8, numel) - case "int8": - retVal = make([]int8, numel) - case "int16": - retVal = make([]int16, numel) - case "int32": - retVal = make([]int32, numel) - case "int64": - retVal = make([]int64, numel) - case "float32": - retVal = make([]float32, numel) - case "float64": - retVal = make([]float64, numel) - case "bool": - retVal = make([]bool, numel) - default: - log.Fatalf("Unsupported dtype (%v)", dtype) + typ, err := dtype.GoType() + if err != nil { + log.Fatal(err) } - ts.CopyData(retVal, numel) - return retVal + dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface() + + if err := ts.CopyData(dataSlice, uint(numel)); err != nil { + log.Fatal(err) + } + + return dataSlice } // FlatView flattens a tensor.
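The ts/tensor.go changes above replace the old name-only variadic argument with functional options. The following is a minimal usage sketch, not part of the patch itself: it assumes half.Float16 is the uint16-backed type added in half/float16.go and that GoKind2DType resolves a uint16-backed slice to Half when WithDType(gotch.Half) is given, mirroring the Decode() call in pickle/serialization.go.

	// Illustrative sketch only (not applied by this patch).
	package main

	import (
		"fmt"

		"github.com/sugarme/gotch"
		"github.com/sugarme/gotch/half"
		"github.com/sugarme/gotch/ts"
	)

	func main() {
		// 1.0, 2.0, 3.0 encoded as IEEE 754 binary16 bit patterns.
		raw := []half.Float16{0x3C00, 0x4000, 0x4200}

		// WithDType tells OfSlice how to interpret the uint16-backed slice;
		// WithName attaches an optional name to the new tensor.
		x := ts.MustOfSlice(raw, ts.WithDType(gotch.Half), ts.WithName("w"))

		fmt.Println(x.DType()) // expected: Half

		x.MustDrop()
	}

Passing options this way keeps OfSlice and friends backward compatible while letting callers override the DType, name, or quantization flag without adding new entry points.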
diff --git a/ts/util.go b/ts/util.go index 0df8bd6..ec6efb2 100644 --- a/ts/util.go +++ b/ts/util.go @@ -75,7 +75,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { if err := w.WriteByte(b); err != nil { return err } - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { return err } @@ -86,14 +86,14 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { if v.Kind() == reflect.Slice { expected := int(shape[0]) if v.Len() != expected { - return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected) + return fmt.Errorf("EncodeTensor() failed: mismatched slice lengths: %d and %d", v.Len(), expected) } } // Optimisation: if only one dimension is left we can use binary.Write() directly for this slice if len(shape) == 1 && v.Len() > 0 { switch v.Index(0).Kind() { - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: return binary.Write(w, nativeEndian, v.Interface()) } } @@ -107,7 +107,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { } default: - return fmt.Errorf("unsupported type %v", v.Type()) + return fmt.Errorf("EncodeTensor() failed: unsupported type %v", v.Type()) } return nil } @@ -122,7 +122,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. return err } ptr.Elem().SetBool(b == 1) - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { return err } @@ -134,7 +134,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. // Optimization: if only one dimension is left we can use binary.Read() directly for this slice if len(shape) == 1 && val.Len() > 0 { switch val.Index(0).Kind() { - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: return binary.Read(r, nativeEndian, val.Interface()) } } @@ -152,10 +152,10 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. } // ElementCount counts number of element in the tensor given a shape -func ElementCount(shape []int64) int64 { - n := int64(1) +func ElementCount(shape []int64) int { + n := 1 for _, d := range shape { - n *= d + n *= int(d) } return n } @@ -163,54 +163,48 @@ func ElementCount(shape []int64) int64 { // DataDim returns number of elements in data // NOTE: only support scalar and (nested) slice/array of scalar type func DataDim(data interface{}) (retVal int, err error) { - _, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0) return count, err } // DataCheck checks the input data for element Go type and number of elements. -// It will return errors if element type is not supported. 
-func DataCheck(data interface{}) (k reflect.Type, n int, err error) { - +// It will return errors if element dtype is not supported. +func DataCheck(data interface{}) (dtype gotch.DType, n int, err error) { return dataCheck(reflect.ValueOf(data).Interface(), 0) } // NOTE: 0 is reflect.Kind() of Invalid // See: https://golang.org/pkg/reflect/#Kind -func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { +func dataCheck(data interface{}, count int) (dtype gotch.DType, n int, err error) { v := reflect.ValueOf(data) - var goType reflect.Type = reflect.TypeOf(data) var total int = count var round = 0 - switch v.Kind() { - case reflect.Slice, reflect.Array: + if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { if round == 0 { round = v.Len() } for i := 0; i < v.Len(); i++ { round-- - goType, total, err = dataCheck(v.Index(i).Interface(), total) + dtype, total, err = dataCheck(v.Index(i).Interface(), total) if err != nil { - return reflect.TypeOf(reflect.Zero), 0, err + return gotch.Invalid, 0, err } } - return goType, total, nil - - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: - total++ - if goType.String() != "invalid" { - goType = v.Type() - } - default: - err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind()) - return reflect.TypeOf(reflect.Zero), 0, err + return dtype, total, nil } - return goType, total, nil + total += 1 + dtype, err = gotch.GoKind2DType(v.Kind()) + if err != nil { + err = fmt.Errorf("DataCheck() failed: unsupported data structure or type: %v\n", v.Kind()) + return gotch.Invalid, 0, err + } + + return dtype, total, nil } // DataAsPtr write to C memory and returns a C pointer. @@ -219,25 +213,19 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { // Supported data types are scalar, slice/array of scalar type equivalent to // DType. func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) { - // 1. Count number of elements in data elementNum, err := DataDim(data) if err != nil { return nil, err } - // 2. Element size in bytes + // 2. Number of bytes dtype, err := gotch.DTypeFromData(data) if err != nil { return nil, err } - eltSizeInBytes, err := gotch.DTypeSize(dtype) - if err != nil { - return nil, err - } - - nbytes := int(eltSizeInBytes) * int(elementNum) + nbytes := int(dtype.Size()) * int(elementNum) // 3. Get C pointer and prepare C memory buffer for writing dataPtr, buff := CMalloc(nbytes)
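With the reworked ts/util.go, DataCheck reports a gotch.DType directly instead of a reflect.Type. A small sketch of the new return values follows; it is illustrative only and not part of the patch, and the concrete dtype and count noted in the comment are assumptions based on the dataCheck implementation above.

	// Illustrative sketch only (not applied by this patch).
	package main

	import (
		"fmt"
		"log"

		"github.com/sugarme/gotch/ts"
	)

	func main() {
		data := [][]float32{{1, 2, 3}, {4, 5, 6}}

		// DataCheck walks the (nested) slice and returns the element dtype
		// together with the total number of scalar elements.
		dtype, n, err := ts.DataCheck(data)
		if err != nil {
			log.Fatal(err)
		}

		// Expected: dtype = Float (float32 elements), n = 6.
		fmt.Println(dtype, n)
	}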