reworked gotch.dtype with more dtypes

This commit is contained in:
sugarme 2023-07-07 00:01:23 +10:00
parent 640af9d2df
commit 523061eca6
16 changed files with 2028 additions and 638 deletions

624
dtype.go
View File

@ -3,8 +3,6 @@ package gotch
import (
"fmt"
"log"
// "log"
"reflect"
)
@ -12,151 +10,148 @@ import (
type CInt = int32
// DType represents different kind of element that a tensor can hold.
// It has an embedded `reflect.Type` for type reflection.
type DType struct {
reflect.Type
}
// Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h
type DType int
/*
* // Custom-made Float16 as not exist in Go
* // Ref: https://github.com/golang/go/issues/32022
* type GoFloat16 = int16 // not implemented yet
* type GoComplexHalf = interface{} // not implemented yet!
* */
// TODO: double check these Torch DType to Go type
var (
Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0
Int8 DType = DType{reflect.TypeOf(int8(1))} // 1
Int16 DType = DType{reflect.TypeOf(int16(1))} // 2
Int DType = DType{reflect.TypeOf(int32(1))} // 3
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
Half DType = DType{reflect.TypeOf(float32(1))} // 5
Float DType = DType{reflect.TypeOf(float32(1))} // 6
Double DType = DType{reflect.TypeOf(float64(1))} // 7
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
// ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9
// ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10
Bool DType = DType{reflect.TypeOf(true)} // 11
const (
Invalid DType = -1
Uint8 DType = 0
Int8 DType = 1
Int16 DType = 2
Int DType = 3
Int64 DType = 4
Half DType = 5
Float DType = 6
Double DType = 7
ComplexHalf DType = 8
ComplexFloat DType = 9
ComplexDouble DType = 10
Bool DType = 11
QInt8 DType = 12
QUInt8 DType = 13
QInt32 DType = 14
BFloat16 DType = 15
// ---not implemented ---
QUInt4x2 DType = 16
QUInt2x4 DType = 17
Bits1x8 DType = 18
Bits2x4 DType = 19
Bits4x2 DType = 20
Bits8 DType = 21
Bits16 DType = 22
)
var dtypeGoType = map[DType]reflect.Type{
Uint8: reflect.TypeOf(uint8(1)),
Int8: reflect.TypeOf(int8(1)),
Int16: reflect.TypeOf(int16(1)),
Int: reflect.TypeOf(int32(1)),
Int64: reflect.TypeOf(int64(1)),
Half: reflect.TypeOf(float32(1)),
Float: reflect.TypeOf(float32(1)),
Double: reflect.TypeOf(float64(1)),
Bool: reflect.TypeOf(true),
var dtype2CKind = map[DType]CInt{
Uint8: 0,
Int8: 1,
Int16: 2,
Int: 3,
Int64: 4,
Half: 5,
Float: 6,
Double: 7,
ComplexHalf: 8,
ComplexFloat: 9,
ComplexDouble: 10,
Bool: 11,
QInt8: 12,
QUInt8: 13,
QInt32: 14,
BFloat16: 15,
// ---not implemented ---
QUInt4x2: 16,
QUInt2x4: 17,
Bits1x8: 18,
Bits2x4: 19,
Bits4x2: 20,
Bits8: 21,
Bits16: 22,
}
// ToDType infers and returns supported equivalent DType from given Go type
func ToDType(typ reflect.Type) (retVal DType, err error) {
var found = false
for key, val := range dtypeGoType {
if val == typ {
retVal = key
found = true
break
}
func (dt DType) CKind() CInt {
if cint, ok := dtype2CKind[dt]; ok {
return cint
}
if !found {
err = fmt.Errorf("Unsupported Go type: %v", typ)
return DType{}, err
if Debug {
log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt)
}
return retVal, nil
return -1 // invalid
}
// ToGoType infers and returns supported equivalent Go type from given DType
func ToGoType(dtype DType) (retVal reflect.Type, err error) {
if _, ok := dtypeGoType[dtype]; !ok {
err = fmt.Errorf("Unsupported DType %v", dtype)
return nil, err
}
retVal = dtypeGoType[dtype]
return retVal, nil
// Back compat
func (dt DType) CInt() CInt {
return dt.CKind()
}
var dtypeCInt = map[DType]CInt{
Uint8: 0,
Int8: 1,
Int16: 2,
Int: 3,
Int64: 4,
Half: 5,
Float: 6,
Double: 7,
Bool: 11,
var ckind2DType map[CInt]DType = map[CInt]DType{
0: Uint8,
1: Int8,
2: Int16,
3: Int,
4: Int64,
5: Half,
6: Float,
7: Double,
8: ComplexHalf,
9: ComplexFloat,
10: ComplexDouble,
11: Bool,
12: QInt8,
13: QUInt8,
14: QInt32,
15: BFloat16,
// ---not implemented ---
16: QUInt4x2,
17: QUInt2x4,
18: Bits1x8,
19: Bits2x4,
20: Bits4x2,
21: Bits8,
22: Bits16,
}
func DType2CInt(dt DType) (retVal CInt, err error) {
if _, ok := dtypeCInt[dt]; !ok {
err = fmt.Errorf("Unsupported CInt conversion from DType: %v\n", dt)
func CKind2DType(ckind int32) DType {
if dtype, ok := ckind2DType[ckind]; ok {
return dtype
}
retVal = dtypeCInt[dt]
return retVal, nil
if Debug {
log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind)
}
return -1 // invalid
}
func (dt DType) CInt() (retVal CInt) {
retVal, err := DType2CInt(dt)
if err != nil {
log.Fatal(err)
}
return retVal
var dtypeSize map[DType]uint = map[DType]uint{
Uint8: 1,
Int8: 1,
Int16: 2,
Int: 4,
Int64: 8,
Half: 2,
Float: 4,
Double: 8,
ComplexHalf: 4,
ComplexFloat: 8,
ComplexDouble: 16,
Bool: 1,
QInt8: 1,
QUInt8: 1,
QInt32: 4,
BFloat16: 2,
QUInt4x2: 2,
QUInt2x4: 1,
// ---not implemented ---
Bits1x8: 1,
Bits2x4: 1,
Bits4x2: 1,
Bits8: 1,
Bits16: 2,
}
func CInt2DType(v CInt) (dtype DType, err error) {
var found = false
for key, val := range dtypeCInt {
if val == v {
dtype = key
found = true
break
}
}
if !found {
err = fmt.Errorf("Unsuported DType for CInt %v\n", v)
return DType{}, err
}
return dtype, nil
}
// dtypeSize is a map of DType and its size in Bytes
var dtypeSize = map[DType]uint{
Uint8: 1,
Int8: 1,
Int16: 2,
Int: 4,
Int64: 8,
Half: 4, // Should it be?
Float: 4,
Double: 8,
Bool: 1,
}
// DTypeSize returns DType size in Bytes
func DTypeSize(dt DType) (retVal uint, err error) {
if _, ok := dtypeSize[dt]; !ok {
err = fmt.Errorf("Unsupported conversion DType size in Byte for DType: %v\n", dt)
return 99, err
}
retVal = dtypeSize[dt]
return retVal, nil
// Size returns dtype size in Bytes.
func (dt DType) Size() uint {
return dtypeSize[dt]
}
type DTypeDevice struct {
@ -174,201 +169,228 @@ var (
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
)
// Type Inferring:
// ===============
// DTypeFromData infers returns equavalent DType from given data
func DTypeFromData(data interface{}) (retVal DType, err error) {
// NOTE: call `Interface()` to get data type back to interface{} type
typ, _, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
if err != nil {
return retVal, err
}
if typ.Kind() == reflect.Slice {
return ToDType(typ.Elem())
}
return ToDType(typ)
var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{
Uint8: reflect.Uint8,
Int8: reflect.Int8,
Int16: reflect.Int16,
Int: reflect.Int32,
Int64: reflect.Int64,
Half: reflect.Uint16, // <- just uint16
Float: reflect.Float32,
Double: reflect.Float64,
ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
ComplexFloat: reflect.Complex64,
ComplexDouble: reflect.Complex128,
Bool: reflect.Bool,
QInt8: reflect.Int8,
QUInt8: reflect.Uint8,
QInt32: reflect.Int32,
BFloat16: reflect.Uint16, // <- just uint16
// ---not implemented ---
QUInt4x2: reflect.Invalid,
QUInt2x4: reflect.Invalid,
Bits1x8: reflect.Invalid,
Bits2x4: reflect.Invalid,
Bits4x2: reflect.Invalid,
Bits8: reflect.Invalid,
Bits16: reflect.Invalid,
}
// NOTE: 0 is reflect.Kind() of Invalid
// See: https://golang.org/pkg/reflect/#Kind
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
v := reflect.ValueOf(data)
var goType reflect.Type = reflect.TypeOf(data)
var total int = count
var round = 0
func (dt DType) GoKind() reflect.Kind {
if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid {
return kind
}
switch v.Kind() {
case reflect.Slice, reflect.Array:
if round == 0 {
round = v.Len()
if Debug {
log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt)
}
return reflect.Invalid
}
const (
// NOTE. reflect.Kind 0-26
QUInt8Kind reflect.Kind = 27
QInt8Kind reflect.Kind = 28
QInt32Kind reflect.Kind = 29
Float16Kind reflect.Kind = 30
BFloat16Kind reflect.Kind = 31
QUInt4x2Kind reflect.Kind = 32
QUInt2x4Kind reflect.Kind = 33
Bits1x8Kind reflect.Kind = 34
Bits2x4Kind reflect.Kind = 35
Bits4x2Kind reflect.Kind = 36
Bits8Kind reflect.Kind = 37
Bits16Kind reflect.Kind = 38
ComplexHalfKind reflect.Kind = 39
)
var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{
reflect.Uint8: Uint8,
reflect.Int8: Int8,
reflect.Int16: Int16,
reflect.Int32: Int,
reflect.Int64: Int64,
reflect.Float32: Float,
reflect.Float64: Double,
reflect.Complex64: ComplexFloat,
reflect.Complex128: ComplexDouble,
reflect.Bool: Bool,
reflect.Uint16: Half,
// Added Kinds
QUInt8Kind: QUInt8,
QInt8Kind: QInt8,
QInt32Kind: QInt32,
// Float16Kind: Half,
BFloat16Kind: BFloat16,
QUInt4x2Kind: QUInt4x2,
QUInt2x4Kind: QUInt2x4,
Bits1x8Kind: Bits1x8,
Bits2x4Kind: Bits2x4,
Bits4x2Kind: Bits4x2,
Bits8Kind: Bits8,
Bits16Kind: Bits16,
ComplexHalfKind: ComplexHalf,
}
type DTypeOptions struct {
HalfDTypePref DType
Quantized bool
}
type DTypeOpt func(*DTypeOptions)
func DefaultDTypeOptions() *DTypeOptions {
return &DTypeOptions{
HalfDTypePref: Half,
Quantized: false,
}
}
func HalfDTypePref(v DType) DTypeOpt {
if v != Half && v != BFloat16 {
if Debug {
log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v)
}
for i := 0; i < v.Len(); i++ {
round--
goType, total, err = dataCheck(v.Index(i).Interface(), total)
if err != nil {
return reflect.TypeOf(reflect.Zero), 0, err
}
}
return goType, total, nil
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
total++
if goType.String() != "invalid" {
goType = v.Type()
}
default:
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
return reflect.TypeOf(reflect.Zero), 0, err
}
return goType, total, nil
}
// ElementGoType infers and returns Go type of element in given data
func ElementGoType(data interface{}) (retVal reflect.Type, err error) {
dataValue := reflect.ValueOf(data)
return elementType(dataValue)
}
func elementType(data reflect.Value) (dataType reflect.Type, err error) {
dataKind := data.Kind()
switch dataKind {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
dataType = data.Type()
case reflect.Slice, reflect.Array:
data = data.Elem()
dataType, err = elementType(data) // recursively type inferring
default:
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
return DType{}, err
return func(o *DTypeOptions) {
o.HalfDTypePref = v
}
return dataType, nil
}
// DataDType infers and returns data type of tensor data
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
// assuming that all elements in data have the same type
switch {
case len(shape) == 0:
retVal, err = ElementDType(v)
case len(shape) > 0:
return ElementDType(v.([]interface{})[0])
default:
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
return DType{}, err
func WithQuantized(v bool) DTypeOpt {
return func(o *DTypeOptions) {
o.Quantized = v
}
return DType{}, nil
}
// ElementDType infers and returns its own tensor data type
func ElementDType(v interface{}) (retVal DType, err error) {
switch v.(type) {
case uint8:
retVal = Uint8
case int8:
retVal = Int8
case int16:
retVal = Int16
case int32:
retVal = Int
case int64:
retVal = Int64
case float32:
retVal = Float
case float64:
retVal = Double
case bool:
retVal = Bool
default:
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
}
return retVal, nil
}
// TypeOf infers and returns element Go type from given tensor DType and shape
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
var typ reflect.Type
if typ, err = ToGoType(dt); err != nil {
return nil, err
func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) {
o := DefaultDTypeOptions()
for _, opt := range opts {
opt(o)
}
switch {
case len(shape) == 0:
return typ, nil
case len(shape) > 0:
return reflect.SliceOf(typ), nil
case kind == reflect.Uint16 && o.HalfDTypePref == Half:
return Half, nil
case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16:
return BFloat16, nil
case kind == reflect.Int8 && o.Quantized:
return QInt8, nil
case kind == reflect.Uint8 && o.Quantized:
return QUInt8, nil
case kind == reflect.Int32 && o.Quantized:
return QInt32, nil
default:
err = fmt.Errorf("Unsupported data type.")
return nil, err
dtype, ok := goKind2DType[kind]
if !ok {
err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind)
return Invalid, err
}
return dtype, nil
}
}
/*
* // TypeCheck checks whether data Go type matching DType
* func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
* dataValue := reflect.ValueOf(data)
* var dataType reflect.Type
* var err error
* dataType, err = elementType(dataValue)
* if err != nil {
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
* msg += err.Error()
* return false, msg
* }
*
* matched = dataType == dtype.Type
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
*
* return matched, msg
* }
* */
var supportedTypes = map[reflect.Kind]bool{
reflect.Uint8: true,
reflect.Int8: true,
reflect.Int16: true,
reflect.Int32: true,
reflect.Int64: true,
reflect.Float32: true,
reflect.Float64: true,
reflect.Bool: true,
var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{
Uint8: reflect.TypeOf(uint8(0)),
Int8: reflect.TypeOf(int8(0)),
Int16: reflect.TypeOf(int16(0)),
Int: reflect.TypeOf(int(0)),
Int64: reflect.TypeOf(int64(0)),
Half: reflect.TypeOf(uint16(0)), // <- just uint16
Float: reflect.TypeOf(float32(0)),
Double: reflect.TypeOf(float64(0)),
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
ComplexFloat: reflect.TypeOf(complex64(0)),
ComplexDouble: reflect.TypeOf(complex128(0)),
Bool: reflect.TypeOf(true),
QInt8: reflect.TypeOf(int8(0)),
QUInt8: reflect.TypeOf(uint8(0)),
QInt32: reflect.TypeOf(int32(0)),
BFloat16: reflect.TypeOf(uint16(0)), // <- just uint16
// ---not implemented ---
QUInt4x2: reflect.TypeOf(int8(0)),
QUInt2x4: reflect.TypeOf(uint8(0)),
Bits1x8: reflect.TypeOf(uint8(0)),
Bits2x4: reflect.TypeOf(uint8(0)),
Bits4x2: reflect.TypeOf(uint8(0)),
Bits8: reflect.TypeOf(uint8(0)),
Bits16: reflect.TypeOf(uint16(0)),
}
var scalarTypes = map[reflect.Kind]bool{
reflect.Bool: true,
reflect.Int: true,
reflect.Int8: true,
reflect.Int16: true,
reflect.Int32: true,
reflect.Int64: true,
reflect.Uint: true,
reflect.Uint8: true,
reflect.Uint16: true,
reflect.Uint32: true,
reflect.Uint64: true,
reflect.Uintptr: true,
reflect.Float32: true,
reflect.Float64: true,
reflect.Complex64: true,
reflect.Complex128: true,
func (dt DType) GoType() (reflect.Type, error) {
typ, ok := dtype2GoType[dt]
if !ok {
err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", typ.String())
return nil, err
}
return typ, nil
}
// IsSupportedScalar checks whether given SCALAR type is supported
// TODO: check input is a scalar.
func IsSupportedScalar(k reflect.Kind) bool {
// if _, ok := scalarTypes[k]; !ok {
// log.Fatalf("Input type: %v is not a Go scalar type.", k)
// }
_, retVal := supportedTypes[k]
return retVal
var dtypeNames map[DType]string = map[DType]string{
Uint8: "Uint8",
Int8: "Int8",
Int16: "Int16",
Int: "Int",
Int64: "Int64",
Half: "Half", // <- just uint16
Float: "Float",
Double: "Double",
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
ComplexFloat: "ComplexFloat",
ComplexDouble: "ComplexDouble",
Bool: "Bool",
QInt8: "QInt8",
QUInt8: "QUInt8",
QInt32: "QInt32",
BFloat16: "BFloat16", // <- just uint16
// ---not implemented ---
QUInt4x2: "QUInt4x2",
QUInt2x4: "QUInt2x4",
Bits1x8: "Bits1x8",
Bits2x4: "Bits2x4",
Bits4x2: "Bits4x2",
Bits8: "Bits8",
Bits16: "Bits16",
}
func (dt DType) String() string {
return dtypeNames[dt]
}
func DTypeFromData(data interface{}) (DType, error) {
dataKind := reflect.TypeOf(data).Kind()
// Data is a slice/array
if dataKind == reflect.Slice || dataKind == reflect.Array {
elementKind := reflect.TypeOf(data).Elem().Kind()
return GoKind2DType(elementKind)
}
// single element
return GoKind2DType(dataKind)
}

View File

@ -113,10 +113,15 @@ var ModelUrls map[string]string = map[string]string{
// 1. Resolves input string to a fullpath cached filename candidate.
// 2. Check it at `CachedDir`, if exists, then return the candidate. If not
// 3. Retrieves and Caches data to `CachedDir` and returns path to cached data
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
func CachedPath(filenameOrUrl string, folderOpt ...string) (resolvedPath string, err error) {
filename := path.Base(filenameOrUrl)
// Resolves to "candidate" filename at `CachedDir`
cachedFileCandidate := fmt.Sprintf("%s/%s", CachedDir, filename)
fullPath := CachedDir
if len(folderOpt) > 0 {
fullPath = fmt.Sprintf("%v/%v", CachedDir, folderOpt[0])
}
cachedFileCandidate := fmt.Sprintf("%s/%s", fullPath, filename)
// 1. Cached candidate file exists
if _, err := os.Stat(cachedFileCandidate); err == nil {

24
go.sum
View File

@ -1,7 +1,31 @@
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 h1:QelT11PB4FXiDEXucrfNckHoFxwt8USGY1ajP1ZF5lM=
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI=
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

167
half/bfloat16.go Normal file
View File

@ -0,0 +1,167 @@
package half
import (
"math"
"math/bits"
)
// A 16-bit floating point type implementing the bfloat16 format.
// Ref. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
// https://github.com/starkat99/half-rs/tree/main/src/bfloat
// The bfloat16 - Google 'brain' floating point format is a truncated 16-bit version of the IEEE 754 standard binary32.
// bfloat16 has approximately the same dynamic range as float32 (8 bits -> 3.4 × 10^38) by having a lower precision than float16.
// While float16 has a precision of 10 bits, bfloat16 has a precision of only 7 bits.
//
// +------------+------------------------+----------------------------+
// | 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) |
// +------------+------------------------+----------------------------+
type BFloat16 uint16
// Ref.https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1
func Float32ToBFloat16(value float32) uint16 {
// convert to raw bytes
x := math.Float32bits(value)
// Check for NaN
if (x & 0x7FFF_FFFF) > 0x7F80_0000 {
// keep high part of current mantissa but also set most significant mantissa bit
return uint16((x >> 16) | 0x0040)
}
// Round and shift
var roundBit uint32 = 0x0000_8000
if ((x & roundBit) != 0) && ((x & (3*roundBit - 1)) != 0) {
return uint16(x>>16) + 1
} else {
return uint16(x >> 16)
}
}
func Float64ToBFloat16(value float64) uint16 {
// Convert o raw bytes, truncating the last 32-bits of mantissa
// that precision will always be lost on half-precision
val := math.Float64bits(value)
x := uint32(val >> 32)
// Extract IEEE754 components
sign := x & 0x8000_0000
exp := x & 0x7FF0_0000
man := x & 0x000F_FFFF
// Check for all exponent bit being set, which is Infinity or NaN
if exp == 0x7FF0_0000 {
// Set mantissa MSB for NaN and also keep shifted mantissa bits.
// Also check the last 32 bits.
var nanBit uint32 = 0x0040
if man == 0 && (uint32(val) == 0) {
nanBit = 0
}
return uint16((sign >> 16) | 0x7F80 | nanBit | (man >> 13))
}
// The number is normalized, start assembling half precision version
halfSign := sign >> 16
// Unbias the exponent, then bias for bfloat16 precision
unbiasedExp := (int64(exp>>20) - 1023)
halfExp := unbiasedExp + 127
// Check for exponent overflow, return +infinity
if halfExp >= 0xFF {
return uint16(halfSign | 0x7F80)
}
// Check for underflow
if halfExp <= 0 {
// Check mantissa for what we can do
if 7-halfExp > 21 {
// No rounding possibility, so this is a full underflow, return signed zero
return uint16(halfSign)
}
// Don't forget about hidden leading mantissa bit when assembling mantissa
man = man | 0x0010_0000
halfMan := man >> (14 - halfExp)
// Check for rounding
var roundBit uint32 = 1 << (13 - halfExp)
if ((man & roundBit) != 0) && ((man & (3*roundBit - 1)) != 0) {
halfMan += 1
}
// No exponent for subnormals
return uint16(halfSign | halfMan)
}
// Rebias the exponent
halfExp1 := uint32(halfExp) << 7
halfMan1 := man >> 13
// Check for rounding
var roundBit1 uint32 = 0x0000_1000
if ((man & roundBit1) != 0) && ((man & (3*roundBit1 - 1)) != 0) {
// Round it
return uint16((halfSign | halfExp1 | halfMan1) + 1)
} else {
return uint16(halfSign | halfExp1 | halfMan1)
}
}
func BFloat16ToFloat32(i uint16) float32 {
// If NaN, keep current mantissa but also set most significant mantissa bit
if i&0x7FFF > 0x7F80 {
return math.Float32frombits((uint32(i) | 0x0040) << 16)
} else {
return math.Float32frombits(uint32(i) << 16)
}
}
func BFloat16ToFloat64(i uint16) float64 {
// Check for signed zero
if i&0x7FFF == 0 {
return math.Float64frombits(uint64(i) << 48)
}
halfSign := uint64(i & 0x8000)
halfExp := uint64(i & 0x7F80)
halfMan := uint64(i & 0x007F)
// Check for an infinity or NaN when all exponent bits set
if halfExp == 0x7F80 {
// Check for signed infinity if mantissa is zero
if halfMan == 0 {
return math.Float64frombits((halfSign << 48) | 0x7FF0_0000_0000_0000)
} else {
// NaN, keep current mantissa but also set most significant mantissa bit
return math.Float64frombits((halfSign << 48) | 0x7FF8_0000_0000_0000 | (halfMan << 45))
}
}
// Calculate double-precision components with adjusted exponent
sign := halfSign << 48
// Unbias exponent
unbiasedExp := (int64(halfExp) >> 7) - 127
// Check for subnormals, which will be normalized by adjusting exponent
if halfExp == 0 {
// Calculate how much to adjust the exponent by
// leading zeros uint16
e := bits.LeadingZeros16(uint16(halfMan)) - 9
// Rebias and adjust exponent
exp := (uint64(1023-127-e) << 52)
man := (halfMan << (46 + e)) & 0xF_FFFF_FFFF_FFFF
return math.Float64frombits(sign | exp | man)
}
// Rebias exponent for a normalized normal
exp := uint64(unbiasedExp+1023) << 52
man := (halfMan & 0x007F) << 45
return math.Float64frombits(sign | exp | man)
}

1
half/bfloat16_test.go Normal file
View File

@ -0,0 +1 @@
package half

303
half/float16.go Normal file
View File

@ -0,0 +1,303 @@
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
//
// Special thanks to Kathryn Long for her Rust implementation
// of float16 at github.com/starkat99/half-rs (MIT license)
// Package half defines support for half-precision floating-point numbers.
package half
import (
"math"
"strconv"
)
// Float16 represents IEEE 754 half-precision floating-point numbers (binary16).
type Float16 uint16
// Precision indicates whether the conversion to Float16 is
// exact, subnormal without dropped bits, inexact, underflow, or overflow.
type Precision int
const (
// PrecisionExact is for non-subnormals that don't drop bits during conversion.
// All of these can round-trip. Should always convert to float16.
PrecisionExact Precision = iota
// PrecisionUnknown is for subnormals that don't drop bits during conversion but
// not all of these can round-trip so precision is unknown without more effort.
// Only 2046 of these can round-trip and the rest cannot round-trip.
PrecisionUnknown
// PrecisionInexact is for dropped significand bits and cannot round-trip.
// Some of these are subnormals. Cannot round-trip float32->float16->float32.
PrecisionInexact
// PrecisionUnderflow is for Underflows. Cannot round-trip float32->float16->float32.
PrecisionUnderflow
// PrecisionOverflow is for Overflows. Cannot round-trip float32->float16->float32.
PrecisionOverflow
)
// PrecisionFromfloat32 returns Precision without performing
// the conversion. Conversions from both Infinity and NaN
// values will always report PrecisionExact even if NaN payload
// or NaN-Quiet-Bit is lost. This function is kept simple to
// allow inlining and run < 0.5 ns/op, to serve as a fast filter.
func PrecisionFromfloat32(f32 float32) Precision {
u32 := math.Float32bits(f32)
if u32 == 0 || u32 == 0x80000000 {
// +- zero will always be exact conversion
return PrecisionExact
}
const COEFMASK uint32 = 0x7fffff // 23 least significant bits
const EXPSHIFT uint32 = 23
const EXPBIAS uint32 = 127
const EXPMASK uint32 = uint32(0xff) << EXPSHIFT
const DROPMASK uint32 = COEFMASK >> 10
exp := int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS)
coef := u32 & COEFMASK
if exp == 128 {
// +- infinity or NaN
// apps may want to do extra checks for NaN separately
return PrecisionExact
}
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format says,
// "Decimals between 2^24 (minimum positive subnormal) and 2^14 (maximum subnormal): fixed interval 2^24"
if exp < -24 {
return PrecisionUnderflow
}
if exp > 15 {
return PrecisionOverflow
}
if (coef & DROPMASK) != uint32(0) {
// these include subnormals and non-subnormals that dropped bits
return PrecisionInexact
}
if exp < -14 {
// Subnormals. Caller may want to test these further.
// There are 2046 subnormals that can successfully round-trip f32->f16->f32
// and 20 of those 2046 have 32-bit input coef == 0.
// RFC 7049 and 7049bis Draft 12 don't precisely define "preserves value"
// so some protocols and libraries will choose to handle subnormals differently
// when deciding to encode them to CBOR float32 vs float16.
return PrecisionUnknown
}
return PrecisionExact
}
// Frombits returns the float16 number corresponding to the IEEE 754 binary16
// representation u16, with the sign bit of u16 and the result in the same bit
// position. Frombits(Bits(x)) == x.
func Frombits(u16 uint16) Float16 {
return Float16(u16)
}
// Fromfloat32 returns a Float16 value converted from f32. Conversion uses
// IEEE default rounding (nearest int, with ties to even).
func Fromfloat32(f32 float32) Float16 {
return Float16(f32bitsToF16bits(math.Float32bits(f32)))
}
// ErrInvalidNaNValue indicates a NaN was not received.
const ErrInvalidNaNValue = float16Error("float16: invalid NaN value, expected IEEE 754 NaN")
type float16Error string
func (e float16Error) Error() string { return string(e) }
// FromNaN32ps converts nan to IEEE binary16 NaN while preserving both
// signaling and payload. Unlike Fromfloat32(), which can only return
// qNaN because it sets quiet bit = 1, this can return both sNaN and qNaN.
// If the result is infinity (sNaN with empty payload), then the
// lowest bit of payload is set to make the result a NaN.
// Returns ErrInvalidNaNValue and 0x7c01 (sNaN) if nan isn't IEEE 754 NaN.
// This function was kept simple to be able to inline.
func FromNaN32ps(nan float32) (Float16, error) {
const SNAN = Float16(uint16(0x7c01)) // signaling NaN
u32 := math.Float32bits(nan)
sign := u32 & 0x80000000
exp := u32 & 0x7f800000
coef := u32 & 0x007fffff
if (exp != 0x7f800000) || (coef == 0) {
return SNAN, ErrInvalidNaNValue
}
u16 := uint16((sign >> 16) | uint32(0x7c00) | (coef >> 13))
if (u16 & 0x03ff) == 0 {
// result became infinity, make it NaN by setting lowest bit in payload
u16 |= 0x0001
}
return Float16(u16), nil
}
// NaN returns a Float16 of IEEE 754 binary16 not-a-number (NaN).
// Returned NaN value 0x7e01 has all exponent bits = 1 with the
// first and last bits = 1 in the significand. This is consistent
// with Go's 64-bit math.NaN(). Canonical CBOR in RFC 7049 uses 0x7e00.
func NaN() Float16 {
return Float16(0x7e01)
}
// Inf returns a Float16 with an infinity value with the specified sign.
// A sign >= returns positive infinity.
// A sign < 0 returns negative infinity.
func Inf(sign int) Float16 {
if sign >= 0 {
return Float16(0x7c00)
}
return Float16(0x8000 | 0x7c00)
}
// Float32 returns a float32 converted from f (Float16).
// This is a lossless conversion.
func (f Float16) Float32() float32 {
u32 := f16bitsToF32bits(uint16(f))
return math.Float32frombits(u32)
}
// Bits returns the IEEE 754 binary16 representation of f, with the sign bit
// of f and the result in the same bit position. Bits(Frombits(x)) == x.
func (f Float16) Bits() uint16 {
return uint16(f)
}
// IsNaN reports whether f is an IEEE 754 binary16 “not-a-number” value.
func (f Float16) IsNaN() bool {
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0)
}
// IsQuietNaN reports whether f is a quiet (non-signaling) IEEE 754 binary16
// “not-a-number” value.
func (f Float16) IsQuietNaN() bool {
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0) && (f&0x0200 != 0)
}
// IsInf reports whether f is an infinity (inf).
// A sign > 0 reports whether f is positive inf.
// A sign < 0 reports whether f is negative inf.
// A sign == 0 reports whether f is either inf.
func (f Float16) IsInf(sign int) bool {
return ((f == 0x7c00) && sign >= 0) ||
(f == 0xfc00 && sign <= 0)
}
// IsFinite returns true if f is neither infinite nor NaN.
func (f Float16) IsFinite() bool {
return (uint16(f) & uint16(0x7c00)) != uint16(0x7c00)
}
// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN.
func (f Float16) IsNormal() bool {
exp := uint16(f) & uint16(0x7c00)
return (exp != uint16(0x7c00)) && (exp != 0)
}
// Signbit reports whether f is negative or negative zero.
func (f Float16) Signbit() bool {
return (uint16(f) & uint16(0x8000)) != 0
}
// String satisfies the fmt.Stringer interface.
func (f Float16) String() string {
return strconv.FormatFloat(float64(f.Float32()), 'f', -1, 32)
}
// f16bitsToF32bits returns uint32 (float32 bits) converted from specified uint16.
func f16bitsToF32bits(in uint16) uint32 {
// All 65536 conversions with this were confirmed to be correct
// by Montgomery Edwards⁴⁴⁸ (github.com/x448).
sign := uint32(in&0x8000) << 16 // sign for 32-bit
exp := uint32(in&0x7c00) >> 10 // exponenent for 16-bit
coef := uint32(in&0x03ff) << 13 // significand for 32-bit
if exp == 0x1f {
if coef == 0 {
// infinity
return sign | 0x7f800000 | coef
}
// NaN
return sign | 0x7fc00000 | coef
}
if exp == 0 {
if coef == 0 {
// zero
return sign
}
// normalize subnormal numbers
exp++
for coef&0x7f800000 == 0 {
coef <<= 1
exp--
}
coef &= 0x007fffff
}
return sign | ((exp + (0x7f - 0xf)) << 23) | coef
}
// f32bitsToF16bits returns uint16 (Float16 bits) converted from the specified float32.
// Conversion rounds to nearest integer with ties to even.
func f32bitsToF16bits(u32 uint32) uint16 {
// Translated from Rust to Go by Montgomery Edwards⁴⁴⁸ (github.com/x448).
// All 4294967296 conversions with this were confirmed to be correct by x448.
// Original Rust implementation is by Kathryn Long (github.com/starkat99) with MIT license.
sign := u32 & 0x80000000
exp := u32 & 0x7f800000
coef := u32 & 0x007fffff
if exp == 0x7f800000 {
// NaN or Infinity
nanBit := uint32(0)
if coef != 0 {
nanBit = uint32(0x0200)
}
return uint16((sign >> 16) | uint32(0x7c00) | nanBit | (coef >> 13))
}
halfSign := sign >> 16
unbiasedExp := int32(exp>>23) - 127
halfExp := unbiasedExp + 15
if halfExp >= 0x1f {
return uint16(halfSign | uint32(0x7c00))
}
if halfExp <= 0 {
if 14-halfExp > 24 {
return uint16(halfSign)
}
c := coef | uint32(0x00800000)
halfCoef := c >> uint32(14-halfExp)
roundBit := uint32(1) << uint32(13-halfExp)
if (c&roundBit) != 0 && (c&(3*roundBit-1)) != 0 {
halfCoef++
}
return uint16(halfSign | halfCoef)
}
uHalfExp := uint32(halfExp) << 10
halfCoef := coef >> 13
roundBit := uint32(0x00001000)
if (coef&roundBit) != 0 && (coef&(3*roundBit-1)) != 0 {
return uint16((halfSign | uHalfExp | halfCoef) + 1)
}
return uint16(halfSign | uHalfExp | halfCoef)
}

View File

@ -0,0 +1,88 @@
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
package half_test
import (
"math"
"testing"
float16 "github.com/sugarme/gotch/half"
)
// prevent compiler optimizing out code by assigning to these
var resultF16 float16.Float16
var resultF32 float32
var resultStr string
var pcn float16.Precision
func BenchmarkFloat32pi(b *testing.B) {
result := float32(0)
pi32 := float32(math.Pi)
pi16 := float16.Fromfloat32(pi32)
for i := 0; i < b.N; i++ {
f16 := float16.Frombits(uint16(pi16))
result = f16.Float32()
}
resultF32 = result
}
func BenchmarkFrombits(b *testing.B) {
result := float16.Float16(0)
pi32 := float32(math.Pi)
pi16 := float16.Fromfloat32(pi32)
for i := 0; i < b.N; i++ {
result = float16.Frombits(uint16(pi16))
}
resultF16 = result
}
func BenchmarkFromFloat32pi(b *testing.B) {
result := float16.Float16(0)
pi := float32(math.Pi)
for i := 0; i < b.N; i++ {
result = float16.Fromfloat32(pi)
}
resultF16 = result
}
func BenchmarkFromFloat32nan(b *testing.B) {
result := float16.Float16(0)
nan := float32(math.NaN())
for i := 0; i < b.N; i++ {
result = float16.Fromfloat32(nan)
}
resultF16 = result
}
func BenchmarkFromFloat32subnorm(b *testing.B) {
result := float16.Float16(0)
subnorm := math.Float32frombits(0x007fffff)
for i := 0; i < b.N; i++ {
result = float16.Fromfloat32(subnorm)
}
resultF16 = result
}
func BenchmarkPrecisionFromFloat32(b *testing.B) {
var result float16.Precision
for i := 0; i < b.N; i++ {
f32 := float32(0.00001) + float32(0.00001)
result = float16.PrecisionFromfloat32(f32)
}
pcn = result
}
func BenchmarkString(b *testing.B) {
var result string
pi32 := float32(math.Pi)
pi16 := float16.Fromfloat32(pi32)
for i := 0; i < b.N; i++ {
result = pi16.String()
}
resultStr = result
}

798
half/float16_test.go Normal file
View File

@ -0,0 +1,798 @@
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
package half_test
import (
"bytes"
"crypto/sha512"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"testing"
float16 "github.com/sugarme/gotch/half"
)
// wantF32toF16bits is a tiny subset of expected values
var wantF32toF16bits = []struct {
in float32
out uint16
}{
// generated to provide 100% code coverage plus additional tests for rounding, etc.
{in: math.Float32frombits(0x00000000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00000001), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00001fff), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00002000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00003fff), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00004000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x007fffff), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00800000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x33000000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x33000001), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
{in: math.Float32frombits(0x33000002), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
{in: math.Float32frombits(0x387fc000), out: 0x03ff}, // in f32=0.000061, out f16=0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
{in: math.Float32frombits(0x387fffff), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
{in: math.Float32frombits(0x38800000), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
{in: math.Float32frombits(0x38801fff), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
{in: math.Float32frombits(0x38802000), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
{in: math.Float32frombits(0x38803fff), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
{in: math.Float32frombits(0x38804000), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
{in: math.Float32frombits(0x33bfffff), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
{in: math.Float32frombits(0x33c00000), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
{in: math.Float32frombits(0x33c00001), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
{in: math.Float32frombits(0x477fffff), out: 0x7c00}, // in f32=65535.996094, out f16=+Inf
{in: math.Float32frombits(0x47800000), out: 0x7c00}, // in f32=65536.000000, out f16=+Inf
{in: math.Float32frombits(0x7f7fffff), out: 0x7c00}, // in f32=340282346638528859811704183484516925440.000000, out f16=+Inf
{in: math.Float32frombits(0x7f800000), out: 0x7c00}, // in f32=+Inf, out f16=+Inf
{in: math.Float32frombits(0x7f801fff), out: 0x7e00}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7f802000), out: 0x7e01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7f803fff), out: 0x7e01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7f804000), out: 0x7e02}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7fffffff), out: 0x7fff}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x80000000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80001fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80002000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80003fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80004000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x807fffff), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80800000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0xb87fc000), out: 0x83ff}, // in f32=-0.000061, out f16=-0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
{in: math.Float32frombits(0xb87fffff), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
{in: math.Float32frombits(0xb8800000), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
{in: math.Float32frombits(0xb8801fff), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
{in: math.Float32frombits(0xb8802000), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
{in: math.Float32frombits(0xb8803fff), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
{in: math.Float32frombits(0xb8804000), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
{in: math.Float32frombits(0xc77fffff), out: 0xfc00}, // in f32=-65535.996094, out f16=-Inf
{in: math.Float32frombits(0xc7800000), out: 0xfc00}, // in f32=-65536.000000, out f16=-Inf
{in: math.Float32frombits(0xff7fffff), out: 0xfc00}, // in f32=-340282346638528859811704183484516925440.000000, out f16=-Inf
{in: math.Float32frombits(0xff800000), out: 0xfc00}, // in f32=-Inf, out f16=-Inf
{in: math.Float32frombits(0xff801fff), out: 0xfe00}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0xff802000), out: 0xfe01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0xff803fff), out: 0xfe01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0xff804000), out: 0xfe02}, // in f32=NaN, out f16=NaN
// additional tests
{in: math.Float32frombits(0xc77ff000), out: 0xfc00}, // in f32=-65520.000000, out f16=-Inf
{in: math.Float32frombits(0xc77fef00), out: 0xfbff}, // in f32=-65519.000000, out f16=-65504
{in: math.Float32frombits(0xc77fee00), out: 0xfbff}, // in f32=-65518.000000, out f16=-65504
{in: math.Float32frombits(0xc5802000), out: 0xec01}, // in f32=-4100.000000, out f16=-4100
{in: math.Float32frombits(0xc5801800), out: 0xec01}, // in f32=-4099.000000, out f16=-4100
{in: math.Float32frombits(0xc5801000), out: 0xec00}, // in f32=-4098.000000, out f16=-4096
{in: math.Float32frombits(0xc5800800), out: 0xec00}, // in f32=-4097.000000, out f16=-4096
{in: math.Float32frombits(0xc5800000), out: 0xec00}, // in f32=-4096.000000, out f16=-4096
{in: math.Float32frombits(0xc57ff000), out: 0xec00}, // in f32=-4095.000000, out f16=-4096
{in: math.Float32frombits(0xc57fe000), out: 0xebff}, // in f32=-4094.000000, out f16=-4094
{in: math.Float32frombits(0xc57fd000), out: 0xebfe}, // in f32=-4093.000000, out f16=-4092
{in: math.Float32frombits(0xc5002000), out: 0xe801}, // in f32=-2050.000000, out f16=-2050
{in: math.Float32frombits(0xc5001000), out: 0xe800}, // in f32=-2049.000000, out f16=-2048
{in: math.Float32frombits(0xc5000829), out: 0xe800}, // in f32=-2048.510010, out f16=-2048
{in: math.Float32frombits(0xc5000800), out: 0xe800}, // in f32=-2048.500000, out f16=-2048
{in: math.Float32frombits(0xc50007d7), out: 0xe800}, // in f32=-2048.489990, out f16=-2048
{in: math.Float32frombits(0xc5000000), out: 0xe800}, // in f32=-2048.000000, out f16=-2048
{in: math.Float32frombits(0xc4fff052), out: 0xe800}, // in f32=-2047.510010, out f16=-2048
{in: math.Float32frombits(0xc4fff000), out: 0xe800}, // in f32=-2047.500000, out f16=-2048
{in: math.Float32frombits(0xc4ffefae), out: 0xe7ff}, // in f32=-2047.489990, out f16=-2047
{in: math.Float32frombits(0xc4ffe000), out: 0xe7ff}, // in f32=-2047.000000, out f16=-2047
{in: math.Float32frombits(0xc4ffc000), out: 0xe7fe}, // in f32=-2046.000000, out f16=-2046
{in: math.Float32frombits(0xc4ffa000), out: 0xe7fd}, // in f32=-2045.000000, out f16=-2045
{in: math.Float32frombits(0xbf800000), out: 0xbc00}, // in f32=-1.000000, out f16=-1
{in: math.Float32frombits(0xbf028f5c), out: 0xb814}, // in f32=-0.510000, out f16=-0.5097656
{in: math.Float32frombits(0xbf000000), out: 0xb800}, // in f32=-0.500000, out f16=-0.5
{in: math.Float32frombits(0xbefae148), out: 0xb7d7}, // in f32=-0.490000, out f16=-0.48999023
{in: math.Float32frombits(0x3efae148), out: 0x37d7}, // in f32=0.490000, out f16=0.48999023
{in: math.Float32frombits(0x3f000000), out: 0x3800}, // in f32=0.500000, out f16=0.5
{in: math.Float32frombits(0x3f028f5c), out: 0x3814}, // in f32=0.510000, out f16=0.5097656
{in: math.Float32frombits(0x3f800000), out: 0x3c00}, // in f32=1.000000, out f16=1
{in: math.Float32frombits(0x3fbeb852), out: 0x3df6}, // in f32=1.490000, out f16=1.4902344
{in: math.Float32frombits(0x3fc00000), out: 0x3e00}, // in f32=1.500000, out f16=1.5
{in: math.Float32frombits(0x3fc147ae), out: 0x3e0a}, // in f32=1.510000, out f16=1.5097656
{in: math.Float32frombits(0x3fcf1bbd), out: 0x3e79}, // in f32=1.618034, out f16=1.6181641
{in: math.Float32frombits(0x401f5c29), out: 0x40fb}, // in f32=2.490000, out f16=2.4902344
{in: math.Float32frombits(0x40200000), out: 0x4100}, // in f32=2.500000, out f16=2.5
{in: math.Float32frombits(0x4020a3d7), out: 0x4105}, // in f32=2.510000, out f16=2.5097656
{in: math.Float32frombits(0x402df854), out: 0x4170}, // in f32=2.718282, out f16=2.71875
{in: math.Float32frombits(0x40490fdb), out: 0x4248}, // in f32=3.141593, out f16=3.140625
{in: math.Float32frombits(0x40b00000), out: 0x4580}, // in f32=5.500000, out f16=5.5
{in: math.Float32frombits(0x44ffa000), out: 0x67fd}, // in f32=2045.000000, out f16=2045
{in: math.Float32frombits(0x44ffc000), out: 0x67fe}, // in f32=2046.000000, out f16=2046
{in: math.Float32frombits(0x44ffe000), out: 0x67ff}, // in f32=2047.000000, out f16=2047
{in: math.Float32frombits(0x44ffefae), out: 0x67ff}, // in f32=2047.489990, out f16=2047
{in: math.Float32frombits(0x44fff000), out: 0x6800}, // in f32=2047.500000, out f16=2048
{in: math.Float32frombits(0x44fff052), out: 0x6800}, // in f32=2047.510010, out f16=2048
{in: math.Float32frombits(0x45000000), out: 0x6800}, // in f32=2048.000000, out f16=2048
{in: math.Float32frombits(0x450007d7), out: 0x6800}, // in f32=2048.489990, out f16=2048
{in: math.Float32frombits(0x45000800), out: 0x6800}, // in f32=2048.500000, out f16=2048
{in: math.Float32frombits(0x45000829), out: 0x6800}, // in f32=2048.510010, out f16=2048
{in: math.Float32frombits(0x45001000), out: 0x6800}, // in f32=2049.000000, out f16=2048
{in: math.Float32frombits(0x450017d7), out: 0x6801}, // in f32=2049.489990, out f16=2050
{in: math.Float32frombits(0x45001800), out: 0x6801}, // in f32=2049.500000, out f16=2050
{in: math.Float32frombits(0x45001829), out: 0x6801}, // in f32=2049.510010, out f16=2050
{in: math.Float32frombits(0x45002000), out: 0x6801}, // in f32=2050.000000, out f16=2050
{in: math.Float32frombits(0x45003000), out: 0x6802}, // in f32=2051.000000, out f16=2052
{in: math.Float32frombits(0x457fd000), out: 0x6bfe}, // in f32=4093.000000, out f16=4092
{in: math.Float32frombits(0x457fe000), out: 0x6bff}, // in f32=4094.000000, out f16=4094
{in: math.Float32frombits(0x457ff000), out: 0x6c00}, // in f32=4095.000000, out f16=4096
{in: math.Float32frombits(0x45800000), out: 0x6c00}, // in f32=4096.000000, out f16=4096
{in: math.Float32frombits(0x45800800), out: 0x6c00}, // in f32=4097.000000, out f16=4096
{in: math.Float32frombits(0x45801000), out: 0x6c00}, // in f32=4098.000000, out f16=4096
{in: math.Float32frombits(0x45801800), out: 0x6c01}, // in f32=4099.000000, out f16=4100
{in: math.Float32frombits(0x45802000), out: 0x6c01}, // in f32=4100.000000, out f16=4100
{in: math.Float32frombits(0x45ad9c00), out: 0x6d6d}, // in f32=5555.500000, out f16=5556
{in: math.Float32frombits(0x45ffe800), out: 0x6fff}, // in f32=8189.000000, out f16=8188
{in: math.Float32frombits(0x45fff000), out: 0x7000}, // in f32=8190.000000, out f16=8192
{in: math.Float32frombits(0x45fff800), out: 0x7000}, // in f32=8191.000000, out f16=8192
{in: math.Float32frombits(0x46000000), out: 0x7000}, // in f32=8192.000000, out f16=8192
{in: math.Float32frombits(0x46000400), out: 0x7000}, // in f32=8193.000000, out f16=8192
{in: math.Float32frombits(0x46000800), out: 0x7000}, // in f32=8194.000000, out f16=8192
{in: math.Float32frombits(0x46000c00), out: 0x7000}, // in f32=8195.000000, out f16=8192
{in: math.Float32frombits(0x46001000), out: 0x7000}, // in f32=8196.000000, out f16=8192
{in: math.Float32frombits(0x46001400), out: 0x7001}, // in f32=8197.000000, out f16=8200
{in: math.Float32frombits(0x46001800), out: 0x7001}, // in f32=8198.000000, out f16=8200
{in: math.Float32frombits(0x46001c00), out: 0x7001}, // in f32=8199.000000, out f16=8200
{in: math.Float32frombits(0x46002000), out: 0x7001}, // in f32=8200.000000, out f16=8200
{in: math.Float32frombits(0x46002400), out: 0x7001}, // in f32=8201.000000, out f16=8200
{in: math.Float32frombits(0x46002800), out: 0x7001}, // in f32=8202.000000, out f16=8200
{in: math.Float32frombits(0x46002c00), out: 0x7001}, // in f32=8203.000000, out f16=8200
{in: math.Float32frombits(0x46003000), out: 0x7002}, // in f32=8204.000000, out f16=8208
{in: math.Float32frombits(0x467fec00), out: 0x73ff}, // in f32=16379.000000, out f16=16376
{in: math.Float32frombits(0x467ff000), out: 0x7400}, // in f32=16380.000000, out f16=16384
{in: math.Float32frombits(0x467ff400), out: 0x7400}, // in f32=16381.000000, out f16=16384
{in: math.Float32frombits(0x467ff800), out: 0x7400}, // in f32=16382.000000, out f16=16384
{in: math.Float32frombits(0x467ffc00), out: 0x7400}, // in f32=16383.000000, out f16=16384
{in: math.Float32frombits(0x46800000), out: 0x7400}, // in f32=16384.000000, out f16=16384
{in: math.Float32frombits(0x46800200), out: 0x7400}, // in f32=16385.000000, out f16=16384
{in: math.Float32frombits(0x46800400), out: 0x7400}, // in f32=16386.000000, out f16=16384
{in: math.Float32frombits(0x46800600), out: 0x7400}, // in f32=16387.000000, out f16=16384
{in: math.Float32frombits(0x46800800), out: 0x7400}, // in f32=16388.000000, out f16=16384
{in: math.Float32frombits(0x46800a00), out: 0x7400}, // in f32=16389.000000, out f16=16384
{in: math.Float32frombits(0x46800c00), out: 0x7400}, // in f32=16390.000000, out f16=16384
{in: math.Float32frombits(0x46800e00), out: 0x7400}, // in f32=16391.000000, out f16=16384
{in: math.Float32frombits(0x46801000), out: 0x7400}, // in f32=16392.000000, out f16=16384
{in: math.Float32frombits(0x46801200), out: 0x7401}, // in f32=16393.000000, out f16=16400
{in: math.Float32frombits(0x46801400), out: 0x7401}, // in f32=16394.000000, out f16=16400
{in: math.Float32frombits(0x46801600), out: 0x7401}, // in f32=16395.000000, out f16=16400
{in: math.Float32frombits(0x46801800), out: 0x7401}, // in f32=16396.000000, out f16=16400
{in: math.Float32frombits(0x46801a00), out: 0x7401}, // in f32=16397.000000, out f16=16400
{in: math.Float32frombits(0x46801c00), out: 0x7401}, // in f32=16398.000000, out f16=16400
{in: math.Float32frombits(0x46801e00), out: 0x7401}, // in f32=16399.000000, out f16=16400
{in: math.Float32frombits(0x46802000), out: 0x7401}, // in f32=16400.000000, out f16=16400
{in: math.Float32frombits(0x46802200), out: 0x7401}, // in f32=16401.000000, out f16=16400
{in: math.Float32frombits(0x46802400), out: 0x7401}, // in f32=16402.000000, out f16=16400
{in: math.Float32frombits(0x46802600), out: 0x7401}, // in f32=16403.000000, out f16=16400
{in: math.Float32frombits(0x46802800), out: 0x7401}, // in f32=16404.000000, out f16=16400
{in: math.Float32frombits(0x46802a00), out: 0x7401}, // in f32=16405.000000, out f16=16400
{in: math.Float32frombits(0x46802c00), out: 0x7401}, // in f32=16406.000000, out f16=16400
{in: math.Float32frombits(0x46802e00), out: 0x7401}, // in f32=16407.000000, out f16=16400
{in: math.Float32frombits(0x46803000), out: 0x7402}, // in f32=16408.000000, out f16=16416
{in: math.Float32frombits(0x46ffee00), out: 0x77ff}, // in f32=32759.000000, out f16=32752
{in: math.Float32frombits(0x46fff000), out: 0x7800}, // in f32=32760.000000, out f16=32768
{in: math.Float32frombits(0x46fff200), out: 0x7800}, // in f32=32761.000000, out f16=32768
{in: math.Float32frombits(0x46fff400), out: 0x7800}, // in f32=32762.000000, out f16=32768
{in: math.Float32frombits(0x46fff600), out: 0x7800}, // in f32=32763.000000, out f16=32768
{in: math.Float32frombits(0x46fff800), out: 0x7800}, // in f32=32764.000000, out f16=32768
{in: math.Float32frombits(0x46fffa00), out: 0x7800}, // in f32=32765.000000, out f16=32768
{in: math.Float32frombits(0x46fffc00), out: 0x7800}, // in f32=32766.000000, out f16=32768
{in: math.Float32frombits(0x46fffe00), out: 0x7800}, // in f32=32767.000000, out f16=32768
{in: math.Float32frombits(0x47000000), out: 0x7800}, // in f32=32768.000000, out f16=32768
{in: math.Float32frombits(0x47000100), out: 0x7800}, // in f32=32769.000000, out f16=32768
{in: math.Float32frombits(0x47000200), out: 0x7800}, // in f32=32770.000000, out f16=32768
{in: math.Float32frombits(0x47000300), out: 0x7800}, // in f32=32771.000000, out f16=32768
{in: math.Float32frombits(0x47000400), out: 0x7800}, // in f32=32772.000000, out f16=32768
{in: math.Float32frombits(0x47000500), out: 0x7800}, // in f32=32773.000000, out f16=32768
{in: math.Float32frombits(0x47000600), out: 0x7800}, // in f32=32774.000000, out f16=32768
{in: math.Float32frombits(0x47000700), out: 0x7800}, // in f32=32775.000000, out f16=32768
{in: math.Float32frombits(0x47000800), out: 0x7800}, // in f32=32776.000000, out f16=32768
{in: math.Float32frombits(0x47000900), out: 0x7800}, // in f32=32777.000000, out f16=32768
{in: math.Float32frombits(0x47000a00), out: 0x7800}, // in f32=32778.000000, out f16=32768
{in: math.Float32frombits(0x47000b00), out: 0x7800}, // in f32=32779.000000, out f16=32768
{in: math.Float32frombits(0x47000c00), out: 0x7800}, // in f32=32780.000000, out f16=32768
{in: math.Float32frombits(0x47000d00), out: 0x7800}, // in f32=32781.000000, out f16=32768
{in: math.Float32frombits(0x47000e00), out: 0x7800}, // in f32=32782.000000, out f16=32768
{in: math.Float32frombits(0x47000f00), out: 0x7800}, // in f32=32783.000000, out f16=32768
{in: math.Float32frombits(0x47001000), out: 0x7800}, // in f32=32784.000000, out f16=32768
{in: math.Float32frombits(0x47001100), out: 0x7801}, // in f32=32785.000000, out f16=32800
{in: math.Float32frombits(0x47001200), out: 0x7801}, // in f32=32786.000000, out f16=32800
{in: math.Float32frombits(0x47001300), out: 0x7801}, // in f32=32787.000000, out f16=32800
{in: math.Float32frombits(0x47001400), out: 0x7801}, // in f32=32788.000000, out f16=32800
{in: math.Float32frombits(0x47001500), out: 0x7801}, // in f32=32789.000000, out f16=32800
{in: math.Float32frombits(0x47001600), out: 0x7801}, // in f32=32790.000000, out f16=32800
{in: math.Float32frombits(0x47001700), out: 0x7801}, // in f32=32791.000000, out f16=32800
{in: math.Float32frombits(0x47001800), out: 0x7801}, // in f32=32792.000000, out f16=32800
{in: math.Float32frombits(0x47001900), out: 0x7801}, // in f32=32793.000000, out f16=32800
{in: math.Float32frombits(0x47001a00), out: 0x7801}, // in f32=32794.000000, out f16=32800
{in: math.Float32frombits(0x47001b00), out: 0x7801}, // in f32=32795.000000, out f16=32800
{in: math.Float32frombits(0x47001c00), out: 0x7801}, // in f32=32796.000000, out f16=32800
{in: math.Float32frombits(0x47001d00), out: 0x7801}, // in f32=32797.000000, out f16=32800
{in: math.Float32frombits(0x47001e00), out: 0x7801}, // in f32=32798.000000, out f16=32800
{in: math.Float32frombits(0x47001f00), out: 0x7801}, // in f32=32799.000000, out f16=32800
{in: math.Float32frombits(0x47002000), out: 0x7801}, // in f32=32800.000000, out f16=32800
{in: math.Float32frombits(0x47002100), out: 0x7801}, // in f32=32801.000000, out f16=32800
{in: math.Float32frombits(0x47002200), out: 0x7801}, // in f32=32802.000000, out f16=32800
{in: math.Float32frombits(0x47002300), out: 0x7801}, // in f32=32803.000000, out f16=32800
{in: math.Float32frombits(0x47002400), out: 0x7801}, // in f32=32804.000000, out f16=32800
{in: math.Float32frombits(0x47002500), out: 0x7801}, // in f32=32805.000000, out f16=32800
{in: math.Float32frombits(0x47002600), out: 0x7801}, // in f32=32806.000000, out f16=32800
{in: math.Float32frombits(0x47002700), out: 0x7801}, // in f32=32807.000000, out f16=32800
{in: math.Float32frombits(0x47002800), out: 0x7801}, // in f32=32808.000000, out f16=32800
{in: math.Float32frombits(0x47002900), out: 0x7801}, // in f32=32809.000000, out f16=32800
{in: math.Float32frombits(0x47002a00), out: 0x7801}, // in f32=32810.000000, out f16=32800
{in: math.Float32frombits(0x47002b00), out: 0x7801}, // in f32=32811.000000, out f16=32800
{in: math.Float32frombits(0x47002c00), out: 0x7801}, // in f32=32812.000000, out f16=32800
{in: math.Float32frombits(0x47002d00), out: 0x7801}, // in f32=32813.000000, out f16=32800
{in: math.Float32frombits(0x47002e00), out: 0x7801}, // in f32=32814.000000, out f16=32800
{in: math.Float32frombits(0x47002f00), out: 0x7801}, // in f32=32815.000000, out f16=32800
{in: math.Float32frombits(0x47003000), out: 0x7802}, // in f32=32816.000000, out f16=32832
{in: math.Float32frombits(0x477fe500), out: 0x7bff}, // in f32=65509.000000, out f16=65504
{in: math.Float32frombits(0x477fe100), out: 0x7bff}, // in f32=65505.000000, out f16=65504
{in: math.Float32frombits(0x477fee00), out: 0x7bff}, // in f32=65518.000000, out f16=65504
{in: math.Float32frombits(0x477fef00), out: 0x7bff}, // in f32=65519.000000, out f16=65504
{in: math.Float32frombits(0x477feffd), out: 0x7bff}, // in f32=65519.988281, out f16=65504
{in: math.Float32frombits(0x477ff000), out: 0x7c00}, // in f32=65520.000000, out f16=+Inf
}
func TestPrecisionFromfloat32(t *testing.T) {
for i, v := range wantF32toF16bits {
f16 := float16.Fromfloat32(v.in)
u16 := uint16(f16)
if u16 != v.out {
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
}
checkPrecision(t, v.in, f16, uint64(i))
}
f32 := float32(5.5) // value that doesn't drop any bits in the significand, is within normal exponent range
pre := float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionExact {
t.Errorf("f32bits=0x%08x, wanted=PrecisionExact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionExact, pre)
}
f32 = math.Float32frombits(0x38000000) // subnormal value with coef = 0 that can round-trip float32->float16->float32
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnknown {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
}
f32 = math.Float32frombits(0x387fc000) // subnormal value with coef !=0 that can round-trip float32->float16->float32
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnknown {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
}
f32 = math.Float32frombits(0x33c00000) // subnormal value with no dropped bits that cannot round-trip float32->float16->float32
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnknown {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
}
f32 = math.Float32frombits(0x38000001) // subnormal value with dropped non-zero bits > 0
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionInexact {
t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
}
f32 = float32(math.Pi) // value that cannot "preserve value" because it drops bits in the significand
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionInexact {
t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
}
f32 = math.Float32frombits(0x1) // value that will underflow
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnderflow {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
}
f32 = math.Float32frombits(0x33000000) // value that will underflow
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnderflow {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
}
f32 = math.Float32frombits(0x47800000) // value that will overflow
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionOverflow {
t.Errorf("f32bits=0x%08x, wanted=PrecisionOverflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionOverflow, pre)
}
}
func TestFromNaN32ps(t *testing.T) {
for i, v := range wantF32toF16bits {
f16 := float16.Fromfloat32(v.in)
u16 := uint16(f16)
if u16 != v.out {
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
}
checkFromNaN32ps(t, v.in, f16)
}
// since checkFromNaN32ps rejects non-NaN input, try one here
nan, err := float16.FromNaN32ps(float32(math.Pi))
if err != float16.ErrInvalidNaNValue {
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted err float16.ErrInvalidNaNValue, got err = %q", err)
}
if err.Error() != "float16: invalid NaN value, expected IEEE 754 NaN" {
t.Errorf("unexpected string value returned by err.Error() for ErrInvalidNaNValue: %s", err.Error())
}
if uint16(nan) != 0x7c01 { // signaling NaN
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted nan = 0x7c01, got nan = 0x%04x", uint16(nan))
}
}
// Test a small subset of possible conversions from float32 to Float16.
// TestSomeFromFloat32 runs in under 1 second while TestAllFromFloat32 takes about 45 seconds.
func TestSomeFromFloat32(t *testing.T) {
for i, v := range wantF32toF16bits {
f16 := float16.Fromfloat32(v.in)
u16 := uint16(f16)
if u16 != v.out {
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
}
}
}
// Test all possible 4294967296 float32 input values and results for
// Fromfloat32(), FromNaN32ps(), and PrecisionFromfloat32().
func TestAllFromFloat32(t *testing.T) {
if testing.Short() {
t.Skip("skipping TestAllFromFloat32 in short mode.")
}
fmt.Printf("WARNING: TestAllFromFloat32 should take about 1-2 minutes to run on amd64, other platforms may take longer...\n")
// Blake2b is "3f310bc5608a087462d361644fe66feeb4c68145f6f18eb6f1439cd7914888b6df9e30ae5350dce0635162cc6a2f23b31b3e4353ca132a3c552bdbd58baa54e6"
const wantSHA512 = "08670429a475164d6c4a080969e35231c77ef7069b430b5f38af22e013796b7818bbe8f5942a6ddf26de0e1dfc67d02243f483d85729ebc3762fc2948a5ca1f8"
const batchSize uint32 = 16384
results := make([]uint16, batchSize)
buf := new(bytes.Buffer)
h := sha512.New()
for i := uint64(0); i < uint64(0xFFFFFFFF); i += uint64(batchSize) {
// fill results
for j := uint32(0); j < batchSize; j++ {
inF32 := math.Float32frombits(uint32(i) + j)
f16 := float16.Fromfloat32(inF32)
results[j] = uint16(f16)
checkPrecision(t, inF32, f16, i)
checkFromNaN32ps(t, inF32, f16)
}
// convert results to []byte
err := binary.Write(buf, binary.LittleEndian, results)
if err != nil {
panic(err)
}
// update hash with []byte of results
_, err = h.Write(buf.Bytes())
if err != nil {
panic(err)
}
buf.Reset()
}
// display hash digest in hex
digest := h.Sum(nil)
gotSHA512hex := hex.EncodeToString(digest)
if gotSHA512hex != wantSHA512 {
t.Errorf("gotSHA512hex = %s", gotSHA512hex)
}
}
// Test all 65536 conversions from float16 to float32.
// TestAllToFloat32 runs in under 1 second.
func TestAllToFloat32(t *testing.T) {
// Blake2b is "078d8e3fac9480de1493f22c8f9bfc1eb2051537c536f00f621557d70eed1af057a487c3e252f6d593769f5288d5ab66d8e9cd1adba359838802944bdb731f4d"
const wantSHA512 = "1a4ccec9fd7b6e83310c6b4958a25778cd95f8d4f88b19950e4b8d6932a955f7fbd96b1c9bd9b2a79c3a9d34d653f55e671f8f86e6a5a876660cd38479001aa6"
const batchSize uint32 = 16384
results := make([]float32, batchSize)
buf := new(bytes.Buffer)
h := sha512.New()
for i := uint64(0); i < uint64(0xFFFF); i += uint64(batchSize) {
// fill results
for j := uint32(0); j < batchSize; j++ {
inU16 := uint16(i) + uint16(j)
f16 := float16.Float16(inU16)
results[j] = f16.Float32()
}
// convert results to []byte
err := binary.Write(buf, binary.LittleEndian, results)
if err != nil {
panic(err)
}
// update hash with []byte of results
_, err = h.Write(buf.Bytes())
if err != nil {
panic(err)
}
buf.Reset()
}
// display hash digest in hex
digest := h.Sum(nil)
gotSHA512hex := hex.EncodeToString(digest)
if gotSHA512hex != wantSHA512 {
t.Errorf("Float16toFloat32: gotSHA512hex = %s", gotSHA512hex)
}
}
func TestFrombits(t *testing.T) {
x := uint16(0x1234)
f16 := float16.Frombits(x)
if uint16(f16) != f16.Bits() || uint16(f16) != x {
t.Errorf("float16.Frombits(0x7fff) returned %04x, wanted %04x", uint16(f16), x)
}
}
func TestNaN(t *testing.T) {
nan := float16.NaN()
if !nan.IsNaN() {
t.Errorf("nan.IsNaN() returned false, wanted true")
}
}
func TestInf(t *testing.T) {
posInf := float16.Inf(0)
if uint16(posInf) != 0x7c00 {
t.Errorf("float16.Inf(0) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
}
posInf = float16.Inf(1)
if uint16(posInf) != 0x7c00 {
t.Errorf("float16.Inf(1) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
}
negInf := float16.Inf(-1)
if uint16(negInf) != 0xfc00 {
t.Errorf("float16.Inf(-1) returned %04x, wanted %04x", uint16(negInf), 0xfc00)
}
}
func TestBits(t *testing.T) {
x := uint16(0x1234)
f16 := float16.Frombits(x)
if uint16(f16) != f16.Bits() || f16.Bits() != x {
t.Errorf("Bits() returned %04x, wanted %04x", uint16(f16), x)
}
}
func TestIsFinite(t *testing.T) {
// IsFinite returns true if f is neither infinite nor NaN.
finite := float16.Fromfloat32(float32(1.5))
if !finite.IsFinite() {
t.Errorf("finite.Infinite() returned false, wanted true")
}
posInf := float16.Inf(0)
if posInf.IsFinite() {
t.Errorf("posInf.Infinite() returned true, wanted false")
}
negInf := float16.Inf(-1)
if negInf.IsFinite() {
t.Errorf("negInf.Infinite() returned true, wanted false")
}
nan := float16.NaN()
if nan.IsFinite() {
t.Errorf("nan.Infinite() returned true, wanted false")
}
}
func TestIsNaN(t *testing.T) {
f16 := float16.Float16(0)
if f16.IsNaN() {
t.Errorf("Float16(0).IsNaN() returned true, wanted false")
}
f16 = float16.Float16(0x7e00)
if !f16.IsNaN() {
t.Errorf("Float16(0x7e00).IsNaN() returned false, wanted true")
}
}
func TestIsQuietNaN(t *testing.T) {
f16 := float16.Float16(0)
if f16.IsQuietNaN() {
t.Errorf("Float16(0).IsQuietNaN() returned true, wanted false")
}
f16 = float16.Float16(0x7e00)
if !f16.IsQuietNaN() {
t.Errorf("Float16(0x7e00).IsQuietNaN() returned false, wanted true")
}
f16 = float16.Float16(0x7e00 ^ 0x0200)
if f16.IsQuietNaN() {
t.Errorf("Float16(0x7e00 ^ 0x0200).IsQuietNaN() returned true, wanted false")
}
}
func TestIsNormal(t *testing.T) {
// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN.
zero := float16.Frombits(0)
if zero.IsNormal() {
t.Errorf("zero.IsNormal() returned true, wanted false")
}
posInf := float16.Inf(0)
if posInf.IsNormal() {
t.Errorf("posInf.IsNormal() returned true, wanted false")
}
negInf := float16.Inf(-1)
if negInf.IsNormal() {
t.Errorf("negInf.IsNormal() returned true, wanted false")
}
nan := float16.NaN()
if nan.IsNormal() {
t.Errorf("nan.IsNormal() returned true, wanted false")
}
subnormal := float16.Frombits(0x0001)
if subnormal.IsNormal() {
t.Errorf("subnormal.IsNormal() returned true, wanted false")
}
normal := float16.Fromfloat32(float32(1.5))
if !normal.IsNormal() {
t.Errorf("normal.IsNormal() returned false, wanted true")
}
}
func TestSignbit(t *testing.T) {
f16 := float16.Fromfloat32(float32(0.0))
if f16.Signbit() {
t.Errorf("float16.Fromfloat32(float32(0)).Signbit() returned true, wanted false")
}
f16 = float16.Fromfloat32(float32(2.0))
if f16.Signbit() {
t.Errorf("float16.Fromfloat32(float32(2)).Signbit() returned true, wanted false")
}
f16 = float16.Fromfloat32(float32(-2.0))
if !f16.Signbit() {
t.Errorf("float16.Fromfloat32(float32(-2)).Signbit() returned false, wanted true")
}
}
func TestString(t *testing.T) {
f16 := float16.Fromfloat32(1.5)
s := f16.String()
if s != "1.5" {
t.Errorf("Float16(1.5).String() returned %s, wanted 1.5", s)
}
f16 = float16.Fromfloat32(3.141593)
s = f16.String()
if s != "3.140625" {
t.Errorf("Float16(3.141593).String() returned %s, wanted 3.140625", s)
}
}
func TestIsInf(t *testing.T) {
f16 := float16.Float16(0)
if f16.IsInf(0) {
t.Errorf("Float16(0).IsInf(0) returned true, wanted false")
}
f16 = float16.Float16(0x7c00)
if !f16.IsInf(0) {
t.Errorf("Float16(0x7c00).IsInf(0) returned false, wanted true")
}
f16 = float16.Float16(0x7c00)
if !f16.IsInf(1) {
t.Errorf("Float16(0x7c00).IsInf(1) returned false, wanted true")
}
f16 = float16.Float16(0x7c00)
if f16.IsInf(-1) {
t.Errorf("Float16(0x7c00).IsInf(-1) returned true, wanted false")
}
f16 = float16.Float16(0xfc00)
if !f16.IsInf(0) {
t.Errorf("Float16(0xfc00).IsInf(0) returned false, wanted true")
}
f16 = float16.Float16(0xfc00)
if f16.IsInf(1) {
t.Errorf("Float16(0xfc00).IsInf(1) returned true, wanted false")
}
f16 = float16.Float16(0xfc00)
if !f16.IsInf(-1) {
t.Errorf("Float16(0xfc00).IsInf(-1) returned false, wanted true")
}
}
func float32parts(f32 float32) (exp int32, coef uint32, dropped uint32) {
const COEFMASK uint32 = 0x7fffff // 23 least significant bits
const EXPSHIFT uint32 = 23
const EXPBIAS uint32 = 127
const EXPMASK uint32 = uint32(0xff) << EXPSHIFT
const DROPMASK uint32 = COEFMASK >> 10
u32 := math.Float32bits(f32)
exp = int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS)
coef = u32 & COEFMASK
dropped = coef & DROPMASK
return exp, coef, dropped
}
func isNaN32(f32 float32) bool {
exp, coef, _ := float32parts(f32)
return (exp == 128) && (coef != 0)
}
func isQuietNaN32(f32 float32) bool {
exp, coef, _ := float32parts(f32)
return (exp == 128) && (coef != 0) && ((coef & 0x00400000) != 0)
}
func checkFromNaN32ps(t *testing.T, f32 float32, f16 float16.Float16) {
if !isNaN32(f32) {
return
}
u32 := math.Float32bits(f32)
nan16, err := float16.FromNaN32ps(f32)
if isQuietNaN32(f32) {
// result should be the same
if err != nil {
t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
}
if uint16(nan16) != uint16(f16) {
t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted nan16 = %v, got nan16 = %v", u32, f32, f16, nan16)
}
} else {
// result should differ only by the signaling/quiet bit unless payload is empty
if err != nil {
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
}
coef := uint16(f16) & uint16(0x03ff)
payload := uint16(f16) & uint16(0x01ff)
diff := uint16(nan16 ^ f16)
if payload == 0 {
// the lowest bit needed to be set to prevent turning sNaN into infinity, so 2 bits differ
if diff != 0x0201 {
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0201, got 0x%04x", u32, f32, diff)
}
} else {
// only the quiet bit was restored, so 1 bit differs
if diff != 0x0200 {
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0200, got 0x%04x. f16=0x%04x n16=0x%04x coef=0x%04x", u32, f32, diff, uint16(f16), uint16(nan16), coef)
}
}
}
}
func checkPrecision(t *testing.T, f32 float32, f16 float16.Float16, i uint64) {
// TODO: rewrite this test when time allows
u32 := math.Float32bits(f32)
u16 := f16.Bits()
f32bis := f16.Float32()
u32bis := math.Float32bits(f32bis)
pre := float16.PrecisionFromfloat32(f32)
roundtripped := u32 == u32bis
exp32, coef32, dropped32 := float32parts(f32)
if roundtripped {
checkRoundTrippedPrecision(t, u32, u16, u32bis, exp32, coef32, dropped32)
return
}
if pre == float16.PrecisionExact {
// this should only happen if both input and output are NaN
if !(f16.IsNaN() && isNaN32(f32)) {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionExact when roundtrip failed with non-special value", i, u32, f32, u16, u32bis, f32bis)
}
} else if pre == float16.PrecisionUnknown {
if exp32 < -24 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionUnderflow", i, u32, f32, u16, u32bis, f32bis)
}
if dropped32 != 0 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionInexact", i, u32, f32, u16, u32bis, f32bis)
}
} else if pre == float16.PrecisionInexact {
checkPrecisionInexact(t, u32, u16, u32bis, exp32, coef32, dropped32)
} else if pre == float16.PrecisionUnderflow {
if exp32 >= -14 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnderflow when exp32 is >= -14", i, u32, f32, u16, u32bis, f32bis)
}
} else if pre == float16.PrecisionOverflow {
if exp32 <= 15 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionOverflow when exp32 is <= 15", i, u32, f32, u16, u32bis, f32bis)
}
}
}
func checkPrecisionInexact(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
f32 := math.Float32frombits(u32)
f32bis := math.Float32frombits(u32bis)
if exp32 < -24 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionUnderflow", u32, f32, u16, u32bis, f32bis)
}
if exp32 > 15 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionOverflow", u32, f32, u16, u32bis, f32bis)
}
if coef32 == 0 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when coef32 is 0", u32, f32, u16, u32bis, f32bis)
}
if dropped32 == 0 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when dropped32 is 0", u32, f32, u16, u32bis, f32bis)
}
}
func checkRoundTrippedPrecision(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
f32 := math.Float32frombits(u32)
f32bis := math.Float32frombits(u32bis)
pre := float16.PrecisionFromfloat32(f32)
f16 := float16.Frombits(u16)
if dropped32 != 0 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), dropped32 != 0 with successful roundtrip", u32, f32, u16, u32bis, f32bis)
}
if pre != float16.PrecisionExact {
// there are 2046 values that are subnormal and can round-trip float32->float16->float32
if pre != float16.PrecisionUnknown {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%032b) (%f), out f16bits=0x%04x (%v), back=0x%08x (%f), got %v, wanted PrecisionExact, exp=%d, coef=%d, drpd=%d", u32, u32, f32, u16, f16, u32bis, f32bis, pre, exp32, coef32, dropped32)
}
}
}

View File

@ -1,6 +1,7 @@
package pickle_test
import (
"fmt"
"log"
"github.com/sugarme/gotch"
@ -18,11 +19,13 @@ func ExampleLoadInfo() {
panic(err)
}
err = pickle.LoadInfo(modelFile)
m, err := pickle.LoadModelInfo(modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Println(m)
// Output:
// classifier.0.bias - [4096]
// classifier.0.weight - [4096 25088]
@ -57,4 +60,25 @@ func ExampleLoadInfo() {
// features.7.bias - [128]
// features.7.weight - [128 128 3 3]
// Num of variables: 32
// Tensor DType: Float
}
func ExampleModelFloat16() {
modelName := "HuggingFaceH4/tiny-random-LlamaForCausalLM"
url := "https://huggingface.co/HuggingFaceH4/tiny-random-LlamaForCausalLM/resolve/main/pytorch_model.bin"
modelFile, err := gotch.CachedPath(url, modelName)
if err != nil {
panic(err)
}
m, err := pickle.LoadModelInfo(modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Model DType: %v\n", m.DType())
// Output:
// Model DType: Half
}

View File

@ -20,6 +20,7 @@ import (
"reflect"
"sort"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/ts"
)
@ -60,83 +61,41 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
dictResult := *result.(*Dict)
for _, item := range dictResult {
name := item.Key
itemTyp := reflect.TypeOf(item.Value).String()
switch itemTyp {
case "*pickle.Dict": // Nested *pickle.Dict case
subResult := *item.Value.(*Dict)
for _, subItem := range subResult {
subName := subItem.Key
x, ok := subItem.Value.(*StorageTensor)
if !ok {
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
continue
}
data := x.Source.GetData()
size := x.Size
dtype := x.Source.DType()
device := x.Source.Device()
stride := x.Stride
storageOffset := x.StorageOffset
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if x.RequiresGrad {
x1.MustRequiresGrad_(x.RequiresGrad)
}
namedTensors[name.(string)] = x1
}
default:
sx, isStorageTensor := item.Value.(*StorageTensor)
// if !isStorageTensor {
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
// return nil, err
// }
if !isStorageTensor {
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
continue
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
sx, isStorageTensor := item.Value.(*StorageTensor)
if !isStorageTensor {
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
return nil, err
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
case "*pickle.OrderedDict":
dictResult := result.(*OrderedDict)
@ -581,35 +540,74 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
return missingVariables, nil
}
// LoadInfo loads pretrained weights and prints out name and shape of weights.
func LoadInfo(modelFile string) error {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadInfo() failed: %w", err)
return err
}
type ModelInfor struct {
weights map[string][]int64
dtype gotch.DType
}
layers := make([]string, 0, len(weights))
for tsName := range weights {
func NewModelInfor(weights map[string][]int64, dtype gotch.DType) *ModelInfor {
return &ModelInfor{
weights: weights,
dtype: dtype,
}
}
func (m *ModelInfor) String() string {
var summary string
layers := make([]string, 0, len(m.weights))
for tsName := range m.weights {
layers = append(layers, tsName)
}
sort.Strings(layers)
for _, l := range layers {
var x *ts.Tensor
for tsName, tsVal := range weights {
var x []int64
for tsName, shape := range m.weights {
if tsName == l {
x = tsVal
x = shape
break
}
}
fmt.Printf("%s - %+v\n", l, x.MustSize())
summary += fmt.Sprintf("%s - %+v\n", l, x)
}
fmt.Printf("Num of variables: %v\n", len(weights))
summary += fmt.Sprintf("Num of variables: %v\n", len(m.weights))
summary += fmt.Sprintf("Tensor DType: %v\n", m.dtype)
for _, x := range weights {
x.MustDrop()
}
return nil
return summary
}
func (m *ModelInfor) DType() gotch.DType {
return m.dtype
}
func (m *ModelInfor) Parameters() int {
return len(m.weights)
}
// LoadInfo loads pretrained weights and prints out name and shape of weights.
func LoadModelInfo(modelFile string) (*ModelInfor, error) {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadInfo() failed: %w", err)
return nil, err
}
w := make(map[string][]int64)
var dtype gotch.DType
isFirst := true
for n, x := range weights {
w[n] = x.MustSize()
if isFirst {
dtype = x.DType()
isFirst = false
}
}
m := NewModelInfor(w, dtype)
ts.CleanUp()
return m, nil
}

View File

@ -7,6 +7,7 @@ import (
"math"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/half"
)
// This file implements Pytorch storage data types.
@ -67,7 +68,8 @@ func (s *HalfStorageClass) New(size int, location string) Storage {
type HalfStorage struct {
BaseStorage
Data []float32
// Data []float32
Data []half.Float16
}
var _ Storage = &HalfStorage{}
@ -77,7 +79,7 @@ func (s *HalfStorage) SetFromFile(r io.Reader) error {
}
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]float32, size)
data := make([]half.Float16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
@ -85,7 +87,7 @@ func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
return err
}
u16 := binary.LittleEndian.Uint16(bytes)
data[i] = math.Float32frombits(FloatBits16to32(u16))
data[i] = half.Float16(u16)
}
s.Data = data
return nil
@ -96,7 +98,7 @@ func (s *HalfStorage) GetData() interface{} {
}
func (s *HalfStorage) DType() gotch.DType {
return gotch.Float
return gotch.Half
}
func (s *HalfStorage) Device() gotch.Device {
@ -108,6 +110,62 @@ func (s *HalfStorage) Device() gotch.Device {
}
}
// BFloat16Storage:
// ================
type BFloat16StorageClass struct{}
var _ StorageClass = &BFloat16StorageClass{}
func (s *BFloat16StorageClass) New(size int, location string) Storage {
return &BFloat16Storage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type BFloat16Storage struct {
BaseStorage
Data []half.BFloat16
}
var _ Storage = &BFloat16Storage{}
func (s *BFloat16Storage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *BFloat16Storage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]half.BFloat16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
u16 := binary.LittleEndian.Uint16(bytes)
data[i] = half.BFloat16(u16)
}
s.Data = data
return nil
}
func (s *BFloat16Storage) GetData() interface{} {
return s.Data
}
func (s *BFloat16Storage) DType() gotch.DType {
return gotch.BFloat16
}
func (s *BFloat16Storage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
// FloatStorage:
// =============

View File

@ -26,7 +26,7 @@ func (it *Iterable) Next() (item interface{}, ok bool) {
}
var err error
switch it.ItemKind.Kind().String() {
switch it.ItemKind.GoKind().String() {
case "int64":
item, err = it.Content.Int64Value([]int64{it.Index})
if err != nil {

View File

@ -101,22 +101,22 @@ func (h *NpyHeader) ToString() (string, error) {
shape := strings.Join(shapeStr, ",")
var descr string
switch h.descr.Kind().String() {
// case "float16": // NOTE. No float16 in Go primary types. TODO. implement
// descr = "f2"
case "float32":
switch h.descr {
case gotch.Half:
descr = "f2"
case gotch.Float:
descr = "f4"
case "float64":
case gotch.Double:
descr = "f8"
case "int":
case gotch.Int:
descr = "i4"
case "int64":
case gotch.Int64:
descr = "i8"
case "int16":
case gotch.Int16:
descr = "i2"
case "int8":
case gotch.Int8:
descr = "i1"
case "uint8":
case gotch.Uint8:
descr = "u1"
default:
err := fmt.Errorf("Unsupported kind: %v\n", h.descr)

View File

@ -11,41 +11,7 @@ import (
)
func (ts *Tensor) ValueGo() interface{} {
dtype := ts.DType()
numel := ts.Numel()
var dst interface{}
switch dtype {
case gotch.Uint8:
dst = make([]uint8, numel)
case gotch.Int8:
dst = make([]int8, numel)
case gotch.Int16:
dst = make([]int16, numel)
case gotch.Int:
dst = make([]int32, numel)
case gotch.Int64:
dst = make([]int64, numel)
case gotch.Float:
dst = make([]float32, numel)
case gotch.Double:
dst = make([]float64, numel)
case gotch.Bool:
dst = make([]bool, numel)
default:
err := fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
log.Fatal(err)
}
err := ts.CopyData(dst, ts.Numel())
if err != nil {
log.Fatal(err)
}
// convert []int32 -> int
if reflect.TypeOf(dst).String() == "[]int32" {
dst = sliceInt32ToInt(dst.([]int32))
}
return dst
return ts.Vals()
}
func shapeToSize(shape []int64) int {
@ -250,8 +216,8 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
// var typ T
typ := ts.DType()
switch typ.String() {
case "float32", "float64":
switch typ {
case gotch.Half, gotch.BFloat16, gotch.Float, gotch.Double:
switch f.verb {
case 'f', 'e', 'E', 'G', 'b':
// accepted. Do nothing
@ -259,7 +225,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
f.verb = 'g'
}
case "uint8", "int8", "int16", "int32", "int64":
case gotch.Uint8, gotch.Int8, gotch.Int16, gotch.Int, gotch.Int64:
switch f.verb {
case 'b':
f.base = 2
@ -273,7 +239,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
f.base = 10
f.verb = 'd'
}
case "bool":
case gotch.Bool:
f.verb = 't'
default:
f.verb = 'v'
@ -318,7 +284,7 @@ func (ts *Tensor) Format(s fmt.State, verb rune) {
shape := toSliceInt(ts.MustSize())
strides := shapeToStrides(shape)
device := ts.MustDevice()
dtype := ts.DType().String()
dtype := ts.DType()
defined := ts.MustDefined()
if verb == 'i' {
fmt.Fprintf(

View File

@ -273,27 +273,13 @@ func (ts *Tensor) nbytes() int64 {
if numel == 0 {
return 0 // ts.None
}
dtype := ts.DType()
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
log.Fatal(err)
}
nbytes := int64(numel) * int64(eltSizeInBytes)
return nbytes
return int64(numel * ts.DType().Size())
}
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
// Decode sz
// 1. Count number of elements in data
elementNum := nsize
// 2. Element size in bytes
eltSizeInBytes, err := gotch.DTypeSize(gotch.Int64)
if err != nil {
log.Fatal(err)
}
nbytes := int(eltSizeInBytes) * int(elementNum)
dtype := gotch.Int64 // tensor size dtype = int64
nbytes := int(nsize) * int(dtype.Size())
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
r := bytes.NewReader(dataSlice)
dataIn := make([]int64, nsize)
@ -304,8 +290,49 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
return dataIn
}
// TensorOptions constructs options to build/rebuild tensor.
type TensorOptions struct {
Name string
DType gotch.DType
Quantized bool
// TODO. can expand as needed
}
type TensorOpt func(*TensorOptions)
func DefaultTensorOptions() *TensorOptions {
return &TensorOptions{
Name: "",
DType: gotch.Float,
Quantized: false,
}
}
func WithName(v string) TensorOpt {
return func(o *TensorOptions) {
o.Name = v
}
}
func WithDType(v gotch.DType) TensorOpt {
return func(o *TensorOptions) {
o.DType = v
}
}
func WithQuantized(v bool) TensorOpt {
return func(o *TensorOptions) {
o.Quantized = v
}
}
// OfSlice creates tensor from a slice data
func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
@ -318,10 +345,10 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
return nil, err
}
typ := reflect.TypeOf(data).Elem()
elementKind := reflect.TypeOf(data).Elem().Kind()
dataLen := v.Len()
dtype, err := gotch.ToDType(typ)
dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
if err != nil {
return nil, err
}
@ -329,12 +356,7 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
nbytes := int(dtype.Size()) * elementNum
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
@ -343,30 +365,25 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
return newTensor(ctensor, o.Name), nil
// return newTensor(ctensor), nil
}
// OfDataSize creates Tensor from input byte data, shape and dtype.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string) (*Tensor, error) {
elementNum := ElementCount(shape)
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
nbytes := int(eltSizeInBytes) * int(elementNum)
elementNum := ElementCount(shape)
nbytes := elementNum * int(dtype.Size())
if nbytes != len(data) {
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
@ -380,24 +397,19 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err := TorchErr(); err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
return newTensor(ctensor, o.Name), nil
// return newTensor(ctensor), nil
}
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
// or panic if error
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
ts, err := OfDataSize(data, size, dtype)
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
ts, err := OfDataSize(data, size, dtype, opts...)
if err != nil {
log.Fatal(err)
}
@ -406,8 +418,8 @@ func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
}
// MustOfSlice create a tensor from slice of data. It will be panic if error.
func MustOfSlice(data interface{}) *Tensor {
ts, err := OfSlice(data)
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
log.Fatal(err)
}
@ -416,8 +428,8 @@ func MustOfSlice(data interface{}) *Tensor {
}
// TensorFrom create a tensor from slice of data. It will be panic if error.
func TensorFrom(data interface{}, nameOpt ...string) *Tensor {
ts, err := OfSlice(data, nameOpt...)
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
log.Fatal(err)
}
@ -436,7 +448,12 @@ func (ts *Tensor) Print() {
}
// NewTensorFromData creates tensor from given data and shape
func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Tensor, error) {
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
@ -463,33 +480,18 @@ func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Ten
return nil, err
}
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
return newTensor(ctensor, o.Name), nil
}
func (ts *Tensor) DType() gotch.DType {
cint := lib.AtScalarType(ts.ctensor)
dtype, err := gotch.CInt2DType(cint)
if err != nil {
log.Fatalf("Tensor DType error: %v\n", err)
}
return dtype
return gotch.CKind2DType(cint)
}
func (ts *Tensor) Device() (gotch.Device, error) {
@ -545,6 +547,7 @@ func (ts *Tensor) MustDevice() gotch.Device {
* return retVal
* }
* */
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
@ -748,7 +751,6 @@ func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGra
//
// NOTE: `dst` located in Go memory. Should it be?
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
// there will be memory leak and or out of range error.
if len(dst) < int(numel) {
@ -757,12 +759,9 @@ func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
}
vs := unsafe.Pointer(&dst[0])
elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8)
if err != nil {
return err
}
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
if err = TorchErr(); err != nil {
dtype := gotch.Uint8
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
if err := TorchErr(); err != nil {
return err
}
@ -784,55 +783,20 @@ func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
// and number of elements to C land. This may break in the future
// if Go policy changes.
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
gotype, dlen, err := DataCheck(dst)
if err != nil {
return err
}
dtype, err := gotch.ToDType(gotype)
if err != nil {
return err
}
dtype, dlen, err := DataCheck(dst)
if dlen < int(numel) {
err = fmt.Errorf("CopyData Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
return err
}
if ts.DType() != dtype {
err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
var vs unsafe.Pointer
switch dtype {
case gotch.Uint8:
vs = unsafe.Pointer(&dst.([]uint8)[0])
case gotch.Int8:
vs = unsafe.Pointer(&dst.([]int8)[0])
case gotch.Int16:
vs = unsafe.Pointer(&dst.([]int16)[0])
case gotch.Int:
vs = unsafe.Pointer(&dst.([]int32)[0])
case gotch.Int64:
vs = unsafe.Pointer(&dst.([]int64)[0])
case gotch.Float:
vs = unsafe.Pointer(&dst.([]float32)[0])
case gotch.Double:
vs = unsafe.Pointer(&dst.([]float64)[0])
case gotch.Bool:
vs = unsafe.Pointer(&dst.([]bool)[0])
default:
err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
// Get data pointer
dataPtr := reflect.ValueOf(dst).UnsafePointer()
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
if err != nil {
return err
}
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
if err = TorchErr(); err != nil {
return err
}
@ -863,7 +827,6 @@ func (ts *Tensor) Numel() uint {
// ShallowClone returns a new tensor that share storage with the input tensor.
func (ts *Tensor) ShallowClone() (*Tensor, error) {
ctensor := lib.AtShallowClone(ts.ctensor)
if err := TorchErr(); err != nil {
@ -1309,33 +1272,18 @@ func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
// E.g. res := xs.Vals().([]int64)
func (ts *Tensor) Vals() interface{} {
dtype := ts.DType()
numel := ts.Numel()
numel := int(ts.Numel())
var retVal interface{}
switch dtype.Name() {
case "uint8":
retVal = make([]uint8, numel)
case "int8":
retVal = make([]int8, numel)
case "int16":
retVal = make([]int16, numel)
case "int32":
retVal = make([]int32, numel)
case "int64":
retVal = make([]int64, numel)
case "float32":
retVal = make([]float32, numel)
case "float64":
retVal = make([]float64, numel)
case "bool":
retVal = make([]bool, numel)
default:
log.Fatalf("Unsupported dtype (%v)", dtype)
typ, err := dtype.GoType()
if err != nil {
log.Fatal(err)
}
ts.CopyData(retVal, numel)
return retVal
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
ts.CopyData(dataSlice, uint(numel))
return dataSlice
}
// FlatView flattens a tensor.

View File

@ -75,7 +75,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if err := w.WriteByte(b); err != nil {
return err
}
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
}
@ -86,14 +86,14 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if v.Kind() == reflect.Slice {
expected := int(shape[0])
if v.Len() != expected {
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
return fmt.Errorf("EncodeTensor() failed: mismatched slice lengths: %d and %d", v.Len(), expected)
}
}
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
if len(shape) == 1 && v.Len() > 0 {
switch v.Index(0).Kind() {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
return binary.Write(w, nativeEndian, v.Interface())
}
}
@ -107,7 +107,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
}
default:
return fmt.Errorf("unsupported type %v", v.Type())
return fmt.Errorf("EncodeTensor() failed: unsupported type %v", v.Type())
}
return nil
}
@ -122,7 +122,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
return err
}
ptr.Elem().SetBool(b == 1)
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}
@ -134,7 +134,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
if len(shape) == 1 && val.Len() > 0 {
switch val.Index(0).Kind() {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
return binary.Read(r, nativeEndian, val.Interface())
}
}
@ -152,10 +152,10 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
}
// ElementCount counts number of element in the tensor given a shape
func ElementCount(shape []int64) int64 {
n := int64(1)
func ElementCount(shape []int64) int {
n := 1
for _, d := range shape {
n *= d
n *= int(d)
}
return n
}
@ -163,54 +163,48 @@ func ElementCount(shape []int64) int64 {
// DataDim returns number of elements in data
// NOTE: only support scalar and (nested) slice/array of scalar type
func DataDim(data interface{}) (retVal int, err error) {
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
return count, err
}
// DataCheck checks the input data for element Go type and number of elements.
// It will return errors if element type is not supported.
func DataCheck(data interface{}) (k reflect.Type, n int, err error) {
// It will return errors if element dtype is not supported.
func DataCheck(data interface{}) (dtype gotch.DType, n int, err error) {
return dataCheck(reflect.ValueOf(data).Interface(), 0)
}
// NOTE: 0 is reflect.Kind() of Invalid
// See: https://golang.org/pkg/reflect/#Kind
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
func dataCheck(data interface{}, count int) (dtype gotch.DType, n int, err error) {
v := reflect.ValueOf(data)
var goType reflect.Type = reflect.TypeOf(data)
var total int = count
var round = 0
switch v.Kind() {
case reflect.Slice, reflect.Array:
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
if round == 0 {
round = v.Len()
}
for i := 0; i < v.Len(); i++ {
round--
goType, total, err = dataCheck(v.Index(i).Interface(), total)
dtype, total, err = dataCheck(v.Index(i).Interface(), total)
if err != nil {
return reflect.TypeOf(reflect.Zero), 0, err
return gotch.Invalid, 0, err
}
}
return goType, total, nil
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
total++
if goType.String() != "invalid" {
goType = v.Type()
}
default:
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
return reflect.TypeOf(reflect.Zero), 0, err
return dtype, total, nil
}
return goType, total, nil
total += 1
dtype, err = gotch.GoKind2DType(v.Kind())
if err != nil {
err = fmt.Errorf("DataCheck() failed: unsupported data structure or type: %v\n", v.Kind())
return gotch.Invalid, 0, err
}
return dtype, total, nil
}
// DataAsPtr write to C memory and returns a C pointer.
@ -219,25 +213,19 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
// Supported data types are scalar, slice/array of scalar type equivalent to
// DType.
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
// 1. Count number of elements in data
elementNum, err := DataDim(data)
if err != nil {
return nil, err
}
// 2. Element size in bytes
// 2. Number of bytes
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return nil, err
}
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
nbytes := int(dtype.Size()) * int(elementNum)
// 3. Get C pointer and prepare C memory buffer for writing
dataPtr, buff := CMalloc(nbytes)