reworked gotch.dtype with more dtypes
This commit is contained in:
parent
640af9d2df
commit
523061eca6
624
dtype.go
624
dtype.go
|
@ -3,8 +3,6 @@ package gotch
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
// "log"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,151 +10,148 @@ import (
|
||||||
type CInt = int32
|
type CInt = int32
|
||||||
|
|
||||||
// DType represents different kind of element that a tensor can hold.
|
// DType represents the different kinds of elements that a tensor can hold.
|
||||||
// It has an embedded `reflect.Type` for type reflection.
|
// Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h
|
||||||
type DType struct {
|
type DType int
|
||||||
reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
const (
|
||||||
* // Custom-made Float16 as not exist in Go
|
Invalid DType = -1
|
||||||
* // Ref: https://github.com/golang/go/issues/32022
|
Uint8 DType = 0
|
||||||
* type GoFloat16 = int16 // not implemented yet
|
Int8 DType = 1
|
||||||
* type GoComplexHalf = interface{} // not implemented yet!
|
Int16 DType = 2
|
||||||
* */
|
Int DType = 3
|
||||||
|
Int64 DType = 4
|
||||||
// TODO: double check these Torch DType to Go type
|
Half DType = 5
|
||||||
var (
|
Float DType = 6
|
||||||
Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0
|
Double DType = 7
|
||||||
Int8 DType = DType{reflect.TypeOf(int8(1))} // 1
|
ComplexHalf DType = 8
|
||||||
Int16 DType = DType{reflect.TypeOf(int16(1))} // 2
|
ComplexFloat DType = 9
|
||||||
Int DType = DType{reflect.TypeOf(int32(1))} // 3
|
ComplexDouble DType = 10
|
||||||
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
|
Bool DType = 11
|
||||||
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
|
QInt8 DType = 12
|
||||||
Half DType = DType{reflect.TypeOf(float32(1))} // 5
|
QUInt8 DType = 13
|
||||||
Float DType = DType{reflect.TypeOf(float32(1))} // 6
|
QInt32 DType = 14
|
||||||
Double DType = DType{reflect.TypeOf(float64(1))} // 7
|
BFloat16 DType = 15
|
||||||
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
|
// ---not implemented ---
|
||||||
// ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9
|
QUInt4x2 DType = 16
|
||||||
// ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10
|
QUInt2x4 DType = 17
|
||||||
Bool DType = DType{reflect.TypeOf(true)} // 11
|
Bits1x8 DType = 18
|
||||||
|
Bits2x4 DType = 19
|
||||||
|
Bits4x2 DType = 20
|
||||||
|
Bits8 DType = 21
|
||||||
|
Bits16 DType = 22
|
||||||
)
|
)
|
||||||
|
|
||||||
var dtypeGoType = map[DType]reflect.Type{
|
var dtype2CKind = map[DType]CInt{
|
||||||
Uint8: reflect.TypeOf(uint8(1)),
|
Uint8: 0,
|
||||||
Int8: reflect.TypeOf(int8(1)),
|
Int8: 1,
|
||||||
Int16: reflect.TypeOf(int16(1)),
|
Int16: 2,
|
||||||
Int: reflect.TypeOf(int32(1)),
|
Int: 3,
|
||||||
Int64: reflect.TypeOf(int64(1)),
|
Int64: 4,
|
||||||
Half: reflect.TypeOf(float32(1)),
|
Half: 5,
|
||||||
Float: reflect.TypeOf(float32(1)),
|
Float: 6,
|
||||||
Double: reflect.TypeOf(float64(1)),
|
Double: 7,
|
||||||
Bool: reflect.TypeOf(true),
|
ComplexHalf: 8,
|
||||||
|
ComplexFloat: 9,
|
||||||
|
ComplexDouble: 10,
|
||||||
|
Bool: 11,
|
||||||
|
QInt8: 12,
|
||||||
|
QUInt8: 13,
|
||||||
|
QInt32: 14,
|
||||||
|
BFloat16: 15,
|
||||||
|
// ---not implemented ---
|
||||||
|
QUInt4x2: 16,
|
||||||
|
QUInt2x4: 17,
|
||||||
|
Bits1x8: 18,
|
||||||
|
Bits2x4: 19,
|
||||||
|
Bits4x2: 20,
|
||||||
|
Bits8: 21,
|
||||||
|
Bits16: 22,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToDType infers and returns supported equivalent DType from given Go type
|
func (dt DType) CKind() CInt {
|
||||||
func ToDType(typ reflect.Type) (retVal DType, err error) {
|
if cint, ok := dtype2CKind[dt]; ok {
|
||||||
var found = false
|
return cint
|
||||||
for key, val := range dtypeGoType {
|
|
||||||
if val == typ {
|
|
||||||
retVal = key
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !found {
|
if Debug {
|
||||||
err = fmt.Errorf("Unsupported Go type: %v", typ)
|
log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt)
|
||||||
return DType{}, err
|
|
||||||
}
|
}
|
||||||
|
return -1 // invalid
|
||||||
return retVal, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToGoType infers and returns supported equivalent Go type from given DType
|
// Back compat
|
||||||
func ToGoType(dtype DType) (retVal reflect.Type, err error) {
|
func (dt DType) CInt() CInt {
|
||||||
if _, ok := dtypeGoType[dtype]; !ok {
|
return dt.CKind()
|
||||||
err = fmt.Errorf("Unsupported DType %v", dtype)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
retVal = dtypeGoType[dtype]
|
|
||||||
|
|
||||||
return retVal, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var dtypeCInt = map[DType]CInt{
|
var ckind2DType map[CInt]DType = map[CInt]DType{
|
||||||
Uint8: 0,
|
0: Uint8,
|
||||||
Int8: 1,
|
1: Int8,
|
||||||
Int16: 2,
|
2: Int16,
|
||||||
Int: 3,
|
3: Int,
|
||||||
Int64: 4,
|
4: Int64,
|
||||||
Half: 5,
|
5: Half,
|
||||||
Float: 6,
|
6: Float,
|
||||||
Double: 7,
|
7: Double,
|
||||||
Bool: 11,
|
8: ComplexHalf,
|
||||||
|
9: ComplexFloat,
|
||||||
|
10: ComplexDouble,
|
||||||
|
11: Bool,
|
||||||
|
12: QInt8,
|
||||||
|
13: QUInt8,
|
||||||
|
14: QInt32,
|
||||||
|
15: BFloat16,
|
||||||
|
// ---not implemented ---
|
||||||
|
16: QUInt4x2,
|
||||||
|
17: QUInt2x4,
|
||||||
|
18: Bits1x8,
|
||||||
|
19: Bits2x4,
|
||||||
|
20: Bits4x2,
|
||||||
|
21: Bits8,
|
||||||
|
22: Bits16,
|
||||||
}
|
}
|
||||||
|
|
||||||
func DType2CInt(dt DType) (retVal CInt, err error) {
|
func CKind2DType(ckind int32) DType {
|
||||||
if _, ok := dtypeCInt[dt]; !ok {
|
if dtype, ok := ckind2DType[ckind]; ok {
|
||||||
err = fmt.Errorf("Unsupported CInt conversion from DType: %v\n", dt)
|
return dtype
|
||||||
}
|
}
|
||||||
|
|
||||||
retVal = dtypeCInt[dt]
|
if Debug {
|
||||||
|
log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind)
|
||||||
return retVal, nil
|
}
|
||||||
|
return -1 // invalid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dt DType) CInt() (retVal CInt) {
|
var dtypeSize map[DType]uint = map[DType]uint{
|
||||||
retVal, err := DType2CInt(dt)
|
Uint8: 1,
|
||||||
if err != nil {
|
Int8: 1,
|
||||||
log.Fatal(err)
|
Int16: 2,
|
||||||
}
|
Int: 4,
|
||||||
|
Int64: 8,
|
||||||
return retVal
|
Half: 2,
|
||||||
|
Float: 4,
|
||||||
|
Double: 8,
|
||||||
|
ComplexHalf: 4,
|
||||||
|
ComplexFloat: 8,
|
||||||
|
ComplexDouble: 16,
|
||||||
|
Bool: 1,
|
||||||
|
QInt8: 1,
|
||||||
|
QUInt8: 1,
|
||||||
|
QInt32: 4,
|
||||||
|
BFloat16: 2,
|
||||||
|
QUInt4x2: 2,
|
||||||
|
QUInt2x4: 1,
|
||||||
|
// ---not implemented ---
|
||||||
|
Bits1x8: 1,
|
||||||
|
Bits2x4: 1,
|
||||||
|
Bits4x2: 1,
|
||||||
|
Bits8: 1,
|
||||||
|
Bits16: 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
func CInt2DType(v CInt) (dtype DType, err error) {
|
// Size returns dtype size in Bytes.
|
||||||
var found = false
|
func (dt DType) Size() uint {
|
||||||
for key, val := range dtypeCInt {
|
return dtypeSize[dt]
|
||||||
if val == v {
|
|
||||||
dtype = key
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !found {
|
|
||||||
err = fmt.Errorf("Unsuported DType for CInt %v\n", v)
|
|
||||||
return DType{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return dtype, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// dtypeSize is a map of DType and its size in Bytes
|
|
||||||
var dtypeSize = map[DType]uint{
|
|
||||||
Uint8: 1,
|
|
||||||
Int8: 1,
|
|
||||||
Int16: 2,
|
|
||||||
Int: 4,
|
|
||||||
Int64: 8,
|
|
||||||
Half: 4, // Should it be?
|
|
||||||
Float: 4,
|
|
||||||
Double: 8,
|
|
||||||
Bool: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
// DTypeSize returns DType size in Bytes
|
|
||||||
func DTypeSize(dt DType) (retVal uint, err error) {
|
|
||||||
if _, ok := dtypeSize[dt]; !ok {
|
|
||||||
err = fmt.Errorf("Unsupported conversion DType size in Byte for DType: %v\n", dt)
|
|
||||||
return 99, err
|
|
||||||
}
|
|
||||||
|
|
||||||
retVal = dtypeSize[dt]
|
|
||||||
|
|
||||||
return retVal, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type DTypeDevice struct {
|
type DTypeDevice struct {
|
||||||
|
@ -174,201 +169,228 @@ var (
|
||||||
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
|
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Type Inferring:
|
var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{
|
||||||
// ===============
|
Uint8: reflect.Uint8,
|
||||||
|
Int8: reflect.Int8,
|
||||||
// DTypeFromData infers returns equavalent DType from given data
|
Int16: reflect.Int16,
|
||||||
func DTypeFromData(data interface{}) (retVal DType, err error) {
|
Int: reflect.Int32,
|
||||||
|
Int64: reflect.Int64,
|
||||||
// NOTE: call `Interface()` to get data type back to interface{} type
|
Half: reflect.Uint16, // <- just uint16
|
||||||
typ, _, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
Float: reflect.Float32,
|
||||||
if err != nil {
|
Double: reflect.Float64,
|
||||||
return retVal, err
|
ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
||||||
}
|
ComplexFloat: reflect.Complex64,
|
||||||
|
ComplexDouble: reflect.Complex128,
|
||||||
if typ.Kind() == reflect.Slice {
|
Bool: reflect.Bool,
|
||||||
return ToDType(typ.Elem())
|
QInt8: reflect.Int8,
|
||||||
}
|
QUInt8: reflect.Uint8,
|
||||||
|
QInt32: reflect.Int32,
|
||||||
return ToDType(typ)
|
BFloat16: reflect.Uint16, // <- just uint16
|
||||||
|
// ---not implemented ---
|
||||||
|
QUInt4x2: reflect.Invalid,
|
||||||
|
QUInt2x4: reflect.Invalid,
|
||||||
|
Bits1x8: reflect.Invalid,
|
||||||
|
Bits2x4: reflect.Invalid,
|
||||||
|
Bits4x2: reflect.Invalid,
|
||||||
|
Bits8: reflect.Invalid,
|
||||||
|
Bits16: reflect.Invalid,
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: 0 is reflect.Kind() of Invalid
|
func (dt DType) GoKind() reflect.Kind {
|
||||||
// See: https://golang.org/pkg/reflect/#Kind
|
if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid {
|
||||||
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
return kind
|
||||||
v := reflect.ValueOf(data)
|
}
|
||||||
var goType reflect.Type = reflect.TypeOf(data)
|
|
||||||
var total int = count
|
|
||||||
var round = 0
|
|
||||||
|
|
||||||
switch v.Kind() {
|
if Debug {
|
||||||
case reflect.Slice, reflect.Array:
|
log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt)
|
||||||
if round == 0 {
|
}
|
||||||
round = v.Len()
|
|
||||||
|
return reflect.Invalid
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// NOTE. reflect.Kind 0-26
|
||||||
|
QUInt8Kind reflect.Kind = 27
|
||||||
|
QInt8Kind reflect.Kind = 28
|
||||||
|
QInt32Kind reflect.Kind = 29
|
||||||
|
Float16Kind reflect.Kind = 30
|
||||||
|
BFloat16Kind reflect.Kind = 31
|
||||||
|
QUInt4x2Kind reflect.Kind = 32
|
||||||
|
QUInt2x4Kind reflect.Kind = 33
|
||||||
|
Bits1x8Kind reflect.Kind = 34
|
||||||
|
Bits2x4Kind reflect.Kind = 35
|
||||||
|
Bits4x2Kind reflect.Kind = 36
|
||||||
|
Bits8Kind reflect.Kind = 37
|
||||||
|
Bits16Kind reflect.Kind = 38
|
||||||
|
ComplexHalfKind reflect.Kind = 39
|
||||||
|
)
|
||||||
|
|
||||||
|
var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{
|
||||||
|
reflect.Uint8: Uint8,
|
||||||
|
reflect.Int8: Int8,
|
||||||
|
reflect.Int16: Int16,
|
||||||
|
reflect.Int32: Int,
|
||||||
|
reflect.Int64: Int64,
|
||||||
|
reflect.Float32: Float,
|
||||||
|
reflect.Float64: Double,
|
||||||
|
reflect.Complex64: ComplexFloat,
|
||||||
|
reflect.Complex128: ComplexDouble,
|
||||||
|
reflect.Bool: Bool,
|
||||||
|
reflect.Uint16: Half,
|
||||||
|
|
||||||
|
// Added Kinds
|
||||||
|
QUInt8Kind: QUInt8,
|
||||||
|
QInt8Kind: QInt8,
|
||||||
|
QInt32Kind: QInt32,
|
||||||
|
// Float16Kind: Half,
|
||||||
|
BFloat16Kind: BFloat16,
|
||||||
|
QUInt4x2Kind: QUInt4x2,
|
||||||
|
QUInt2x4Kind: QUInt2x4,
|
||||||
|
Bits1x8Kind: Bits1x8,
|
||||||
|
Bits2x4Kind: Bits2x4,
|
||||||
|
Bits4x2Kind: Bits4x2,
|
||||||
|
Bits8Kind: Bits8,
|
||||||
|
Bits16Kind: Bits16,
|
||||||
|
ComplexHalfKind: ComplexHalf,
|
||||||
|
}
|
||||||
|
|
||||||
|
type DTypeOptions struct {
|
||||||
|
HalfDTypePref DType
|
||||||
|
Quantized bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type DTypeOpt func(*DTypeOptions)
|
||||||
|
|
||||||
|
func DefaultDTypeOptions() *DTypeOptions {
|
||||||
|
return &DTypeOptions{
|
||||||
|
HalfDTypePref: Half,
|
||||||
|
Quantized: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func HalfDTypePref(v DType) DTypeOpt {
|
||||||
|
if v != Half && v != BFloat16 {
|
||||||
|
if Debug {
|
||||||
|
log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v)
|
||||||
}
|
}
|
||||||
for i := 0; i < v.Len(); i++ {
|
|
||||||
round--
|
|
||||||
goType, total, err = dataCheck(v.Index(i).Interface(), total)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return reflect.TypeOf(reflect.Zero), 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return goType, total, nil
|
|
||||||
|
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
|
||||||
total++
|
|
||||||
if goType.String() != "invalid" {
|
|
||||||
goType = v.Type()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
|
|
||||||
return reflect.TypeOf(reflect.Zero), 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return goType, total, nil
|
return func(o *DTypeOptions) {
|
||||||
}
|
o.HalfDTypePref = v
|
||||||
|
|
||||||
// ElementGoType infers and returns Go type of element in given data
|
|
||||||
func ElementGoType(data interface{}) (retVal reflect.Type, err error) {
|
|
||||||
dataValue := reflect.ValueOf(data)
|
|
||||||
return elementType(dataValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
func elementType(data reflect.Value) (dataType reflect.Type, err error) {
|
|
||||||
dataKind := data.Kind()
|
|
||||||
switch dataKind {
|
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
|
||||||
dataType = data.Type()
|
|
||||||
case reflect.Slice, reflect.Array:
|
|
||||||
data = data.Elem()
|
|
||||||
dataType, err = elementType(data) // recursively type inferring
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
|
|
||||||
return DType{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return dataType, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataDType infers and returns data type of tensor data
|
func WithQuantized(v bool) DTypeOpt {
|
||||||
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
|
return func(o *DTypeOptions) {
|
||||||
// assuming that all elements in data have the same type
|
o.Quantized = v
|
||||||
switch {
|
|
||||||
case len(shape) == 0:
|
|
||||||
retVal, err = ElementDType(v)
|
|
||||||
case len(shape) > 0:
|
|
||||||
return ElementDType(v.([]interface{})[0])
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
|
||||||
return DType{}, err
|
|
||||||
}
|
}
|
||||||
return DType{}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ElementDType infers and returns its own tensor data type
|
func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) {
|
||||||
func ElementDType(v interface{}) (retVal DType, err error) {
|
o := DefaultDTypeOptions()
|
||||||
switch v.(type) {
|
for _, opt := range opts {
|
||||||
case uint8:
|
opt(o)
|
||||||
retVal = Uint8
|
|
||||||
case int8:
|
|
||||||
retVal = Int8
|
|
||||||
case int16:
|
|
||||||
retVal = Int16
|
|
||||||
case int32:
|
|
||||||
retVal = Int
|
|
||||||
case int64:
|
|
||||||
retVal = Int64
|
|
||||||
case float32:
|
|
||||||
retVal = Float
|
|
||||||
case float64:
|
|
||||||
retVal = Double
|
|
||||||
case bool:
|
|
||||||
retVal = Bool
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
return retVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TypeOf infers and returns element Go type from given tensor DType and shape
|
|
||||||
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
|
||||||
var typ reflect.Type
|
|
||||||
if typ, err = ToGoType(dt); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case len(shape) == 0:
|
case kind == reflect.Uint16 && o.HalfDTypePref == Half:
|
||||||
return typ, nil
|
return Half, nil
|
||||||
case len(shape) > 0:
|
case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16:
|
||||||
return reflect.SliceOf(typ), nil
|
return BFloat16, nil
|
||||||
|
case kind == reflect.Int8 && o.Quantized:
|
||||||
|
return QInt8, nil
|
||||||
|
case kind == reflect.Uint8 && o.Quantized:
|
||||||
|
return QUInt8, nil
|
||||||
|
case kind == reflect.Int32 && o.Quantized:
|
||||||
|
return QInt32, nil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("Unsupported data type.")
|
dtype, ok := goKind2DType[kind]
|
||||||
return nil, err
|
if !ok {
|
||||||
|
err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind)
|
||||||
|
return Invalid, err
|
||||||
|
}
|
||||||
|
return dtype, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{
|
||||||
* // TypeCheck checks whether data Go type matching DType
|
Uint8: reflect.TypeOf(uint8(0)),
|
||||||
* func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
Int8: reflect.TypeOf(int8(0)),
|
||||||
* dataValue := reflect.ValueOf(data)
|
Int16: reflect.TypeOf(int16(0)),
|
||||||
* var dataType reflect.Type
|
Int: reflect.TypeOf(int(0)),
|
||||||
* var err error
|
Int64: reflect.TypeOf(int64(0)),
|
||||||
* dataType, err = elementType(dataValue)
|
Half: reflect.TypeOf(uint16(0)), // <- just uint16
|
||||||
* if err != nil {
|
Float: reflect.TypeOf(float32(0)),
|
||||||
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
Double: reflect.TypeOf(float64(0)),
|
||||||
* msg += err.Error()
|
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
||||||
* return false, msg
|
ComplexFloat: reflect.TypeOf(complex64(0)),
|
||||||
* }
|
ComplexDouble: reflect.TypeOf(complex128(0)),
|
||||||
*
|
Bool: reflect.TypeOf(true),
|
||||||
* matched = dataType == dtype.Type
|
QInt8: reflect.TypeOf(int8(0)),
|
||||||
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
QUInt8: reflect.TypeOf(uint8(0)),
|
||||||
*
|
QInt32: reflect.TypeOf(int32(0)),
|
||||||
* return matched, msg
|
BFloat16: reflect.TypeOf(uint16(0)), // <- just uint16
|
||||||
* }
|
// ---not implemented ---
|
||||||
* */
|
QUInt4x2: reflect.TypeOf(int8(0)),
|
||||||
|
QUInt2x4: reflect.TypeOf(uint8(0)),
|
||||||
var supportedTypes = map[reflect.Kind]bool{
|
Bits1x8: reflect.TypeOf(uint8(0)),
|
||||||
reflect.Uint8: true,
|
Bits2x4: reflect.TypeOf(uint8(0)),
|
||||||
reflect.Int8: true,
|
Bits4x2: reflect.TypeOf(uint8(0)),
|
||||||
reflect.Int16: true,
|
Bits8: reflect.TypeOf(uint8(0)),
|
||||||
reflect.Int32: true,
|
Bits16: reflect.TypeOf(uint16(0)),
|
||||||
reflect.Int64: true,
|
|
||||||
reflect.Float32: true,
|
|
||||||
reflect.Float64: true,
|
|
||||||
reflect.Bool: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var scalarTypes = map[reflect.Kind]bool{
|
func (dt DType) GoType() (reflect.Type, error) {
|
||||||
reflect.Bool: true,
|
typ, ok := dtype2GoType[dt]
|
||||||
reflect.Int: true,
|
if !ok {
|
||||||
reflect.Int8: true,
|
err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", typ.String())
|
||||||
reflect.Int16: true,
|
return nil, err
|
||||||
reflect.Int32: true,
|
}
|
||||||
reflect.Int64: true,
|
|
||||||
reflect.Uint: true,
|
return typ, nil
|
||||||
reflect.Uint8: true,
|
|
||||||
reflect.Uint16: true,
|
|
||||||
reflect.Uint32: true,
|
|
||||||
reflect.Uint64: true,
|
|
||||||
reflect.Uintptr: true,
|
|
||||||
reflect.Float32: true,
|
|
||||||
reflect.Float64: true,
|
|
||||||
reflect.Complex64: true,
|
|
||||||
reflect.Complex128: true,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSupportedScalar checks whether given SCALAR type is supported
|
var dtypeNames map[DType]string = map[DType]string{
|
||||||
// TODO: check input is a scalar.
|
Uint8: "Uint8",
|
||||||
func IsSupportedScalar(k reflect.Kind) bool {
|
Int8: "Int8",
|
||||||
// if _, ok := scalarTypes[k]; !ok {
|
Int16: "Int16",
|
||||||
// log.Fatalf("Input type: %v is not a Go scalar type.", k)
|
Int: "Int",
|
||||||
// }
|
Int64: "Int64",
|
||||||
|
Half: "Half", // <- just uint16
|
||||||
_, retVal := supportedTypes[k]
|
Float: "Float",
|
||||||
|
Double: "Double",
|
||||||
return retVal
|
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
||||||
|
ComplexFloat: "ComplexFloat",
|
||||||
|
ComplexDouble: "ComplexDouble",
|
||||||
|
Bool: "Bool",
|
||||||
|
QInt8: "QInt8",
|
||||||
|
QUInt8: "QUInt8",
|
||||||
|
QInt32: "QInt32",
|
||||||
|
BFloat16: "BFloat16", // <- just uint16
|
||||||
|
// ---not implemented ---
|
||||||
|
QUInt4x2: "QUInt4x2",
|
||||||
|
QUInt2x4: "QUInt2x4",
|
||||||
|
Bits1x8: "Bits1x8",
|
||||||
|
Bits2x4: "Bits2x4",
|
||||||
|
Bits4x2: "Bits4x2",
|
||||||
|
Bits8: "Bits8",
|
||||||
|
Bits16: "Bits16",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (dt DType) String() string {
|
||||||
|
return dtypeNames[dt]
|
||||||
|
}
|
||||||
|
|
||||||
|
func DTypeFromData(data interface{}) (DType, error) {
|
||||||
|
dataKind := reflect.TypeOf(data).Kind()
|
||||||
|
|
||||||
|
// Data is a slice/array
|
||||||
|
if dataKind == reflect.Slice || dataKind == reflect.Array {
|
||||||
|
elementKind := reflect.TypeOf(data).Elem().Kind()
|
||||||
|
return GoKind2DType(elementKind)
|
||||||
|
}
|
||||||
|
|
||||||
|
// single element
|
||||||
|
return GoKind2DType(dataKind)
|
||||||
}
|
}
|
||||||
|
|
|
@ -113,10 +113,15 @@ var ModelUrls map[string]string = map[string]string{
|
||||||
// 1. Resolves input string to a fullpath cached filename candidate.
|
// 1. Resolves input string to a fullpath cached filename candidate.
|
||||||
// 2. Check it at `CachedDir`, if exists, then return the candidate. If not
|
// 2. Check it at `CachedDir`, if exists, then return the candidate. If not
|
||||||
// 3. Retrieves and Caches data to `CachedDir` and returns path to cached data
|
// 3. Retrieves and Caches data to `CachedDir` and returns path to cached data
|
||||||
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
|
func CachedPath(filenameOrUrl string, folderOpt ...string) (resolvedPath string, err error) {
|
||||||
filename := path.Base(filenameOrUrl)
|
filename := path.Base(filenameOrUrl)
|
||||||
// Resolves to "candidate" filename at `CachedDir`
|
// Resolves to "candidate" filename at `CachedDir`
|
||||||
cachedFileCandidate := fmt.Sprintf("%s/%s", CachedDir, filename)
|
fullPath := CachedDir
|
||||||
|
if len(folderOpt) > 0 {
|
||||||
|
fullPath = fmt.Sprintf("%v/%v", CachedDir, folderOpt[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
cachedFileCandidate := fmt.Sprintf("%s/%s", fullPath, filename)
|
||||||
|
|
||||||
// 1. Cached candidate file exists
|
// 1. Cached candidate file exists
|
||||||
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
||||||
|
|
24
go.sum
24
go.sum
|
@ -1,7 +1,31 @@
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||||
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||||
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 h1:QelT11PB4FXiDEXucrfNckHoFxwt8USGY1ajP1ZF5lM=
|
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 h1:QelT11PB4FXiDEXucrfNckHoFxwt8USGY1ajP1ZF5lM=
|
||||||
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||||
golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI=
|
golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI=
|
||||||
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
|
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
|
||||||
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||||
|
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||||
|
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||||
|
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
|
||||||
|
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
|
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||||
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|
167
half/bfloat16.go
Normal file
167
half/bfloat16.go
Normal file
|
@ -0,0 +1,167 @@
|
||||||
|
package half
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"math/bits"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A 16-bit floating point type implementing the bfloat16 format.
|
||||||
|
// Ref. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
|
||||||
|
// https://github.com/starkat99/half-rs/tree/main/src/bfloat
|
||||||
|
|
||||||
|
// The bfloat16 - Google 'brain' floating point format is a truncated 16-bit version of the IEEE 754 standard binary32.
|
||||||
|
// bfloat16 has approximately the same dynamic range as float32 (8 bits -> 3.4 × 10^38) by having a lower precision than float16.
|
||||||
|
// While float16 has a precision of 10 bits, bfloat16 has a precision of only 7 bits.
|
||||||
|
//
|
||||||
|
// +------------+------------------------+----------------------------+
|
||||||
|
// | 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) |
|
||||||
|
// +------------+------------------------+----------------------------+
|
||||||
|
type BFloat16 uint16
|
||||||
|
|
||||||
|
// Float32ToBFloat16 converts a float32 into the raw 16-bit pattern of a
// bfloat16 value.
//
// NaN inputs keep the upper part of their mantissa and get the most
// significant mantissa bit forced on. Every other input keeps its upper 16
// bits, incremented by one when the round bit is set and either a lower bit
// or the lowest kept bit is set (round to nearest, ties to even).
//
// Ref. https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1
func Float32ToBFloat16(value float32) uint16 {
	const (
		magnitudeMask uint32 = 0x7FFF_FFFF
		infinityBits  uint32 = 0x7F80_0000
		roundBit      uint32 = 0x0000_8000
	)

	bits := math.Float32bits(value)

	// Anything above the infinity pattern (sign ignored) is a NaN.
	if bits&magnitudeMask > infinityBits {
		return uint16((bits >> 16) | 0x0040)
	}

	upper := uint16(bits >> 16)

	// Truncate unless the round bit is set together with a sticky bit or an
	// odd lowest kept bit.
	if bits&roundBit == 0 || bits&(3*roundBit-1) == 0 {
		return upper
	}

	return upper + 1
}
|
||||||
|
|
||||||
|
// Float64ToBFloat16 converts a float64 into the raw 16-bit pattern of a
// bfloat16 value.
//
// Only the upper 32 bits of the float64 take part in the conversion; the
// lower 32 mantissa bits are consulted solely to tell infinity from NaN.
// Infinities and NaNs are preserved, exponents too large for bfloat16 become
// a signed infinity, small values become bfloat16 subnormals or a signed
// zero, and everything else is rounded to nearest with ties to even.
func Float64ToBFloat16(value float64) uint16 {
	raw := math.Float64bits(value)
	// Work on the upper word only: the dropped 32 mantissa bits are precision
	// a 16-bit float cannot represent anyway.
	upper := uint32(raw >> 32)

	// IEEE 754 binary64 fields as they sit in the upper word.
	signBits := upper & 0x8000_0000
	expBits := upper & 0x7FF0_0000
	frac := upper & 0x000F_FFFF

	outSign := signBits >> 16

	// All exponent bits set: infinity or NaN.
	if expBits == 0x7FF0_0000 {
		// For NaN, force the mantissa MSB on while keeping the shifted mantissa
		// bits. The lower 32 bits are checked too, so a NaN whose payload lives
		// only there is not mistaken for infinity.
		quietBit := uint32(0x0040)
		if frac == 0 && uint32(raw) == 0 {
			quietBit = 0
		}

		return uint16(outSign | 0x7F80 | quietBit | (frac >> 13))
	}

	// Remove the binary64 bias, then apply the bfloat16 bias.
	exp := int64(expBits>>20) - 1023 + 127

	switch {
	case exp >= 0xFF:
		// Exponent overflow: signed infinity.
		return uint16(outSign | 0x7F80)

	case exp <= 0:
		// Underflow. When even rounding cannot reach the smallest subnormal,
		// the result is a signed zero.
		if 7-exp > 21 {
			return uint16(outSign)
		}

		// Restore the hidden leading mantissa bit before shifting.
		frac |= 0x0010_0000
		subMan := frac >> (14 - exp)

		var rb uint32 = 1 << (13 - exp)
		if frac&rb != 0 && frac&(3*rb-1) != 0 {
			subMan++
		}

		// Subnormals carry no exponent bits.
		return uint16(outSign | subMan)
	}

	// Normal number: assemble sign, rebiased exponent and truncated mantissa.
	assembled := outSign | uint32(exp)<<7 | frac>>13

	const roundBit uint32 = 0x0000_1000
	if frac&roundBit != 0 && frac&(3*roundBit-1) != 0 {
		assembled++
	}

	return uint16(assembled)
}
|
||||||
|
|
||||||
|
// BFloat16ToFloat32 widens the bfloat16 bit pattern i to a float32.
//
// The 16 input bits become the upper 16 bits of the float32 representation.
// When the input is a NaN (magnitude bits above 0x7F80) the most significant
// mantissa bit is additionally set, so the result is always a quiet NaN.
func BFloat16ToFloat32(i uint16) float32 {
	raw := uint32(i)
	if raw&0x7FFF > 0x7F80 {
		// NaN: keep the payload and force the quiet bit.
		raw |= 0x0040
	}
	return math.Float32frombits(raw << 16)
}
|
||||||
|
|
||||||
|
// BFloat16ToFloat64 widens the bfloat16 bit pattern i to a float64.
//
// Signed zeros and infinities map to their float64 counterparts, NaNs keep
// their payload (shifted into the top mantissa bits) with the quiet bit set,
// and bfloat16 subnormals are renormalized because every one of them is a
// normal number in float64.
func BFloat16ToFloat64(i uint16) float64 {
	signBit := uint64(i&0x8000) << 48
	expField := uint64(i>>7) & 0xFF
	frac := uint64(i & 0x007F)

	switch {
	case expField == 0 && frac == 0:
		// Signed zero.
		return math.Float64frombits(signBit)

	case expField == 0xFF && frac == 0:
		// Signed infinity.
		return math.Float64frombits(signBit | 0x7FF0_0000_0000_0000)

	case expField == 0xFF:
		// NaN: keep the payload and set the most significant mantissa bit.
		return math.Float64frombits(signBit | 0x7FF8_0000_0000_0000 | frac<<45)

	case expField == 0:
		// Subnormal: count how far the leading one sits below the top of the
		// 7-bit mantissa, then shift it out into the implicit bit and lower
		// the exponent by the same amount.
		adjust := bits.LeadingZeros16(uint16(frac)) - 9
		exp64 := uint64(1023-127-adjust) << 52
		man64 := (frac << (46 + adjust)) & 0xF_FFFF_FFFF_FFFF
		return math.Float64frombits(signBit | exp64 | man64)
	}

	// Normal number: swap the bfloat16 bias (127) for the float64 bias (1023).
	exp64 := (expField + (1023 - 127)) << 52
	return math.Float64frombits(signBit | exp64 | frac<<45)
}
|
1
half/bfloat16_test.go
Normal file
1
half/bfloat16_test.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package half
|
303
half/float16.go
Normal file
303
half/float16.go
Normal file
|
@ -0,0 +1,303 @@
|
||||||
|
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
|
||||||
|
//
|
||||||
|
// Special thanks to Kathryn Long for her Rust implementation
|
||||||
|
// of float16 at github.com/starkat99/half-rs (MIT license)
|
||||||
|
|
||||||
|
// Package half defines support for half-precision floating-point numbers.
|
||||||
|
package half
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Float16 represents IEEE 754 half-precision floating-point numbers (binary16).
|
||||||
|
// The uint16 holds the raw bit pattern: sign in bit 15 (0x8000), exponent in
// bits 14-10 (0x7c00) and significand in bits 9-0 (0x03ff).
type Float16 uint16
|
||||||
|
|
||||||
|
// Precision classifies what happens to a float32 value when it is converted
// to Float16: the conversion may be exact, may hit a subnormal whose
// round-trip is undetermined, may drop bits, or may underflow or overflow.
type Precision int

const (
	// PrecisionExact: a non-subnormal result with no dropped bits.
	// Every such value round-trips and should always be converted.
	PrecisionExact Precision = iota

	// PrecisionUnknown: a subnormal result with no dropped bits. Only 2046
	// of these round-trip, so the outcome is unknown without further work.
	PrecisionUnknown

	// PrecisionInexact: significand bits are dropped (subnormals included),
	// so float32->float16->float32 cannot round-trip.
	PrecisionInexact

	// PrecisionUnderflow: the value underflows and cannot round-trip
	// float32->float16->float32.
	PrecisionUnderflow

	// PrecisionOverflow: the value overflows and cannot round-trip
	// float32->float16->float32.
	PrecisionOverflow
)
|
||||||
|
|
||||||
|
// PrecisionFromfloat32 returns Precision without performing
|
||||||
|
// the conversion. Conversions from both Infinity and NaN
|
||||||
|
// values will always report PrecisionExact even if NaN payload
|
||||||
|
// or NaN-Quiet-Bit is lost. This function is kept simple to
|
||||||
|
// allow inlining and run < 0.5 ns/op, to serve as a fast filter.
|
||||||
|
func PrecisionFromfloat32(f32 float32) Precision {
|
||||||
|
u32 := math.Float32bits(f32)
|
||||||
|
|
||||||
|
if u32 == 0 || u32 == 0x80000000 {
|
||||||
|
// +- zero will always be exact conversion
|
||||||
|
return PrecisionExact
|
||||||
|
}
|
||||||
|
|
||||||
|
const COEFMASK uint32 = 0x7fffff // 23 least significant bits
|
||||||
|
const EXPSHIFT uint32 = 23
|
||||||
|
const EXPBIAS uint32 = 127
|
||||||
|
const EXPMASK uint32 = uint32(0xff) << EXPSHIFT
|
||||||
|
const DROPMASK uint32 = COEFMASK >> 10
|
||||||
|
|
||||||
|
exp := int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS)
|
||||||
|
coef := u32 & COEFMASK
|
||||||
|
|
||||||
|
if exp == 128 {
|
||||||
|
// +- infinity or NaN
|
||||||
|
// apps may want to do extra checks for NaN separately
|
||||||
|
return PrecisionExact
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format says,
|
||||||
|
// "Decimals between 2^−24 (minimum positive subnormal) and 2^−14 (maximum subnormal): fixed interval 2^−24"
|
||||||
|
if exp < -24 {
|
||||||
|
return PrecisionUnderflow
|
||||||
|
}
|
||||||
|
if exp > 15 {
|
||||||
|
return PrecisionOverflow
|
||||||
|
}
|
||||||
|
if (coef & DROPMASK) != uint32(0) {
|
||||||
|
// these include subnormals and non-subnormals that dropped bits
|
||||||
|
return PrecisionInexact
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp < -14 {
|
||||||
|
// Subnormals. Caller may want to test these further.
|
||||||
|
// There are 2046 subnormals that can successfully round-trip f32->f16->f32
|
||||||
|
// and 20 of those 2046 have 32-bit input coef == 0.
|
||||||
|
// RFC 7049 and 7049bis Draft 12 don't precisely define "preserves value"
|
||||||
|
// so some protocols and libraries will choose to handle subnormals differently
|
||||||
|
// when deciding to encode them to CBOR float32 vs float16.
|
||||||
|
return PrecisionUnknown
|
||||||
|
}
|
||||||
|
|
||||||
|
return PrecisionExact
|
||||||
|
}
|
||||||
|
|
||||||
|
// Frombits returns the float16 number corresponding to the IEEE 754 binary16
|
||||||
|
// representation u16, with the sign bit of u16 and the result in the same bit
|
||||||
|
// position. Frombits(Bits(x)) == x.
|
||||||
|
func Frombits(u16 uint16) Float16 {
|
||||||
|
return Float16(u16)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fromfloat32 returns a Float16 value converted from f32. Conversion uses
|
||||||
|
// IEEE default rounding (nearest int, with ties to even).
|
||||||
|
func Fromfloat32(f32 float32) Float16 {
|
||||||
|
return Float16(f32bitsToF16bits(math.Float32bits(f32)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// float16Error is a string-based error type, which lets the package's
// sentinel errors be declared as constants.
type float16Error string

// Error returns the error text, satisfying the error interface.
func (e float16Error) Error() string {
	return string(e)
}

// ErrInvalidNaNValue indicates a NaN was not received.
const ErrInvalidNaNValue = float16Error("float16: invalid NaN value, expected IEEE 754 NaN")
|
||||||
|
|
||||||
|
// FromNaN32ps converts nan to IEEE binary16 NaN while preserving both
|
||||||
|
// signaling and payload. Unlike Fromfloat32(), which can only return
|
||||||
|
// qNaN because it sets quiet bit = 1, this can return both sNaN and qNaN.
|
||||||
|
// If the result is infinity (sNaN with empty payload), then the
|
||||||
|
// lowest bit of payload is set to make the result a NaN.
|
||||||
|
// Returns ErrInvalidNaNValue and 0x7c01 (sNaN) if nan isn't IEEE 754 NaN.
|
||||||
|
// This function was kept simple to be able to inline.
|
||||||
|
func FromNaN32ps(nan float32) (Float16, error) {
|
||||||
|
const SNAN = Float16(uint16(0x7c01)) // signaling NaN
|
||||||
|
|
||||||
|
u32 := math.Float32bits(nan)
|
||||||
|
sign := u32 & 0x80000000
|
||||||
|
exp := u32 & 0x7f800000
|
||||||
|
coef := u32 & 0x007fffff
|
||||||
|
|
||||||
|
if (exp != 0x7f800000) || (coef == 0) {
|
||||||
|
return SNAN, ErrInvalidNaNValue
|
||||||
|
}
|
||||||
|
|
||||||
|
u16 := uint16((sign >> 16) | uint32(0x7c00) | (coef >> 13))
|
||||||
|
|
||||||
|
if (u16 & 0x03ff) == 0 {
|
||||||
|
// result became infinity, make it NaN by setting lowest bit in payload
|
||||||
|
u16 |= 0x0001
|
||||||
|
}
|
||||||
|
|
||||||
|
return Float16(u16), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NaN returns a Float16 of IEEE 754 binary16 not-a-number (NaN).
|
||||||
|
// Returned NaN value 0x7e01 has all exponent bits = 1 with the
|
||||||
|
// first and last bits = 1 in the significand. This is consistent
|
||||||
|
// with Go's 64-bit math.NaN(). Canonical CBOR in RFC 7049 uses 0x7e00.
|
||||||
|
func NaN() Float16 {
|
||||||
|
return Float16(0x7e01)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inf returns a Float16 with an infinity value with the specified sign.
|
||||||
|
// A sign >= returns positive infinity.
|
||||||
|
// A sign < 0 returns negative infinity.
|
||||||
|
func Inf(sign int) Float16 {
|
||||||
|
if sign >= 0 {
|
||||||
|
return Float16(0x7c00)
|
||||||
|
}
|
||||||
|
return Float16(0x8000 | 0x7c00)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float32 returns a float32 converted from f (Float16).
|
||||||
|
// This is a lossless conversion.
|
||||||
|
func (f Float16) Float32() float32 {
|
||||||
|
u32 := f16bitsToF32bits(uint16(f))
|
||||||
|
return math.Float32frombits(u32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bits returns the IEEE 754 binary16 representation of f, with the sign bit
|
||||||
|
// of f and the result in the same bit position. Bits(Frombits(x)) == x.
|
||||||
|
func (f Float16) Bits() uint16 {
|
||||||
|
return uint16(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNaN reports whether f is an IEEE 754 binary16 “not-a-number” value.
|
||||||
|
func (f Float16) IsNaN() bool {
|
||||||
|
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsQuietNaN reports whether f is a quiet (non-signaling) IEEE 754 binary16
|
||||||
|
// “not-a-number” value.
|
||||||
|
func (f Float16) IsQuietNaN() bool {
|
||||||
|
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0) && (f&0x0200 != 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInf reports whether f is an infinity (inf).
|
||||||
|
// A sign > 0 reports whether f is positive inf.
|
||||||
|
// A sign < 0 reports whether f is negative inf.
|
||||||
|
// A sign == 0 reports whether f is either inf.
|
||||||
|
func (f Float16) IsInf(sign int) bool {
|
||||||
|
return ((f == 0x7c00) && sign >= 0) ||
|
||||||
|
(f == 0xfc00 && sign <= 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsFinite returns true if f is neither infinite nor NaN.
|
||||||
|
func (f Float16) IsFinite() bool {
|
||||||
|
return (uint16(f) & uint16(0x7c00)) != uint16(0x7c00)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN.
|
||||||
|
func (f Float16) IsNormal() bool {
|
||||||
|
exp := uint16(f) & uint16(0x7c00)
|
||||||
|
return (exp != uint16(0x7c00)) && (exp != 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signbit reports whether f is negative or negative zero.
|
||||||
|
func (f Float16) Signbit() bool {
|
||||||
|
return (uint16(f) & uint16(0x8000)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// String satisfies the fmt.Stringer interface.
|
||||||
|
func (f Float16) String() string {
|
||||||
|
return strconv.FormatFloat(float64(f.Float32()), 'f', -1, 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
// f16bitsToF32bits returns uint32 (float32 bits) converted from specified uint16.
//
// All 65536 conversions with the original form of this routine were
// confirmed to be correct by Montgomery Edwards⁴⁴⁸ (github.com/x448).
func f16bitsToF32bits(in uint16) uint32 {
	const rebias = 0x7f - 0xf // float32 bias minus float16 bias

	signBit := uint32(in&0x8000) << 16 // sign for 32-bit
	expField := uint32(in>>10) & 0x1f  // exponent for 16-bit
	frac := uint32(in&0x03ff) << 13    // significand for 32-bit

	switch {
	case expField == 0x1f && frac == 0:
		// infinity
		return signBit | 0x7f800000

	case expField == 0x1f:
		// NaN: keep the payload and set the quiet bit
		return signBit | 0x7fc00000 | frac

	case expField == 0 && frac == 0:
		// zero
		return signBit

	case expField == 0:
		// Subnormal: shift the significand left until its leading one lands
		// in the implicit-bit position, lowering the exponent for each step.
		biased := uint32(1 + rebias)
		for frac&0x7f800000 == 0 {
			frac <<= 1
			biased--
		}
		return signBit | biased<<23 | frac&0x007fffff
	}

	return signBit | (expField+rebias)<<23 | frac
}
|
||||||
|
|
||||||
|
// f32bitsToF16bits returns uint16 (Float16 bits) converted from the specified float32.
// Conversion rounds to nearest integer with ties to even.
//
// Translated from Rust to Go by Montgomery Edwards⁴⁴⁸ (github.com/x448).
// All 4294967296 conversions with the original form of this routine were
// confirmed to be correct by x448. Original Rust implementation is by
// Kathryn Long (github.com/starkat99) with MIT license.
func f32bitsToF16bits(u32 uint32) uint16 {
	signBit := (u32 >> 16) & 0x8000
	expField := (u32 >> 23) & 0xff
	frac := u32 & 0x007fffff

	if expField == 0xff {
		// NaN or Infinity: keep the top payload bits and force the quiet
		// bit for NaN so a truncated payload cannot turn into infinity.
		out := signBit | 0x7c00 | frac>>13
		if frac != 0 {
			out |= 0x0200
		}
		return uint16(out)
	}

	// Swap the float32 bias (127) for the float16 bias (15).
	halfExp := int32(expField) - 127 + 15

	switch {
	case halfExp >= 0x1f:
		// Exponent overflow: signed infinity.
		return uint16(signBit | 0x7c00)

	case halfExp < -10:
		// Too small even to round up to a subnormal: signed zero.
		return uint16(signBit)

	case halfExp <= 0:
		// Subnormal result: restore the implicit leading bit, then shift
		// the significand right by the exponent deficit.
		full := frac | 0x00800000
		shift := uint32(14 - halfExp)
		out := signBit | full>>shift
		half := uint32(1) << (shift - 1)
		if full&half != 0 && full&(3*half-1) != 0 {
			out++
		}
		return uint16(out)
	}

	// Normal result.
	const half = 0x00001000
	out := signBit | uint32(halfExp)<<10 | frac>>13
	if frac&half != 0 && frac&(3*half-1) != 0 {
		// A carry out of the significand correctly bumps the exponent.
		out++
	}
	return uint16(out)
}
|
88
half/float16_bench_test.go
Normal file
88
half/float16_bench_test.go
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
|
||||||
|
|
||||||
|
package half_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
float16 "github.com/sugarme/gotch/half"
|
||||||
|
)
|
||||||
|
|
||||||
|
// prevent compiler optimizing out code by assigning to these
|
||||||
|
var resultF16 float16.Float16
|
||||||
|
var resultF32 float32
|
||||||
|
var resultStr string
|
||||||
|
var pcn float16.Precision
|
||||||
|
|
||||||
|
func BenchmarkFloat32pi(b *testing.B) {
|
||||||
|
result := float32(0)
|
||||||
|
pi32 := float32(math.Pi)
|
||||||
|
pi16 := float16.Fromfloat32(pi32)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
f16 := float16.Frombits(uint16(pi16))
|
||||||
|
result = f16.Float32()
|
||||||
|
}
|
||||||
|
resultF32 = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFrombits(b *testing.B) {
|
||||||
|
result := float16.Float16(0)
|
||||||
|
pi32 := float32(math.Pi)
|
||||||
|
pi16 := float16.Fromfloat32(pi32)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result = float16.Frombits(uint16(pi16))
|
||||||
|
}
|
||||||
|
resultF16 = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFromFloat32pi(b *testing.B) {
|
||||||
|
result := float16.Float16(0)
|
||||||
|
|
||||||
|
pi := float32(math.Pi)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result = float16.Fromfloat32(pi)
|
||||||
|
}
|
||||||
|
resultF16 = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFromFloat32nan(b *testing.B) {
|
||||||
|
result := float16.Float16(0)
|
||||||
|
|
||||||
|
nan := float32(math.NaN())
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result = float16.Fromfloat32(nan)
|
||||||
|
}
|
||||||
|
resultF16 = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFromFloat32subnorm(b *testing.B) {
|
||||||
|
result := float16.Float16(0)
|
||||||
|
|
||||||
|
subnorm := math.Float32frombits(0x007fffff)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result = float16.Fromfloat32(subnorm)
|
||||||
|
}
|
||||||
|
resultF16 = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPrecisionFromFloat32(b *testing.B) {
|
||||||
|
var result float16.Precision
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
f32 := float32(0.00001) + float32(0.00001)
|
||||||
|
result = float16.PrecisionFromfloat32(f32)
|
||||||
|
}
|
||||||
|
pcn = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkString(b *testing.B) {
|
||||||
|
var result string
|
||||||
|
|
||||||
|
pi32 := float32(math.Pi)
|
||||||
|
pi16 := float16.Fromfloat32(pi32)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
result = pi16.String()
|
||||||
|
}
|
||||||
|
resultStr = result
|
||||||
|
}
|
798
half/float16_test.go
Normal file
798
half/float16_test.go
Normal file
|
@ -0,0 +1,798 @@
|
||||||
|
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
|
||||||
|
|
||||||
|
package half_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha512"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
float16 "github.com/sugarme/gotch/half"
|
||||||
|
)
|
||||||
|
|
||||||
|
// wantF32toF16bits is a tiny subset of expected values
|
||||||
|
var wantF32toF16bits = []struct {
|
||||||
|
in float32
|
||||||
|
out uint16
|
||||||
|
}{
|
||||||
|
// generated to provide 100% code coverage plus additional tests for rounding, etc.
|
||||||
|
{in: math.Float32frombits(0x00000000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x00000001), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x00001fff), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x00002000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x00003fff), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x00004000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x007fffff), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x00800000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x33000000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||||
|
{in: math.Float32frombits(0x33000001), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
|
||||||
|
{in: math.Float32frombits(0x33000002), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
|
||||||
|
{in: math.Float32frombits(0x387fc000), out: 0x03ff}, // in f32=0.000061, out f16=0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
|
||||||
|
{in: math.Float32frombits(0x387fffff), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
|
||||||
|
{in: math.Float32frombits(0x38800000), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
|
||||||
|
{in: math.Float32frombits(0x38801fff), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
|
||||||
|
{in: math.Float32frombits(0x38802000), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
|
||||||
|
{in: math.Float32frombits(0x38803fff), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
|
||||||
|
{in: math.Float32frombits(0x38804000), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
|
||||||
|
{in: math.Float32frombits(0x33bfffff), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
|
||||||
|
{in: math.Float32frombits(0x33c00000), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
|
||||||
|
{in: math.Float32frombits(0x33c00001), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
|
||||||
|
{in: math.Float32frombits(0x477fffff), out: 0x7c00}, // in f32=65535.996094, out f16=+Inf
|
||||||
|
{in: math.Float32frombits(0x47800000), out: 0x7c00}, // in f32=65536.000000, out f16=+Inf
|
||||||
|
{in: math.Float32frombits(0x7f7fffff), out: 0x7c00}, // in f32=340282346638528859811704183484516925440.000000, out f16=+Inf
|
||||||
|
{in: math.Float32frombits(0x7f800000), out: 0x7c00}, // in f32=+Inf, out f16=+Inf
|
||||||
|
{in: math.Float32frombits(0x7f801fff), out: 0x7e00}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0x7f802000), out: 0x7e01}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0x7f803fff), out: 0x7e01}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0x7f804000), out: 0x7e02}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0x7fffffff), out: 0x7fff}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0x80000000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0x80001fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0x80002000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0x80003fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0x80004000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0x807fffff), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0x80800000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||||
|
{in: math.Float32frombits(0xb87fc000), out: 0x83ff}, // in f32=-0.000061, out f16=-0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
|
||||||
|
{in: math.Float32frombits(0xb87fffff), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
|
||||||
|
{in: math.Float32frombits(0xb8800000), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
|
||||||
|
{in: math.Float32frombits(0xb8801fff), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
|
||||||
|
{in: math.Float32frombits(0xb8802000), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
|
||||||
|
{in: math.Float32frombits(0xb8803fff), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
|
||||||
|
{in: math.Float32frombits(0xb8804000), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
|
||||||
|
{in: math.Float32frombits(0xc77fffff), out: 0xfc00}, // in f32=-65535.996094, out f16=-Inf
|
||||||
|
{in: math.Float32frombits(0xc7800000), out: 0xfc00}, // in f32=-65536.000000, out f16=-Inf
|
||||||
|
{in: math.Float32frombits(0xff7fffff), out: 0xfc00}, // in f32=-340282346638528859811704183484516925440.000000, out f16=-Inf
|
||||||
|
{in: math.Float32frombits(0xff800000), out: 0xfc00}, // in f32=-Inf, out f16=-Inf
|
||||||
|
{in: math.Float32frombits(0xff801fff), out: 0xfe00}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0xff802000), out: 0xfe01}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0xff803fff), out: 0xfe01}, // in f32=NaN, out f16=NaN
|
||||||
|
{in: math.Float32frombits(0xff804000), out: 0xfe02}, // in f32=NaN, out f16=NaN
|
||||||
|
// additional tests
|
||||||
|
{in: math.Float32frombits(0xc77ff000), out: 0xfc00}, // in f32=-65520.000000, out f16=-Inf
|
||||||
|
{in: math.Float32frombits(0xc77fef00), out: 0xfbff}, // in f32=-65519.000000, out f16=-65504
|
||||||
|
{in: math.Float32frombits(0xc77fee00), out: 0xfbff}, // in f32=-65518.000000, out f16=-65504
|
||||||
|
{in: math.Float32frombits(0xc5802000), out: 0xec01}, // in f32=-4100.000000, out f16=-4100
|
||||||
|
{in: math.Float32frombits(0xc5801800), out: 0xec01}, // in f32=-4099.000000, out f16=-4100
|
||||||
|
{in: math.Float32frombits(0xc5801000), out: 0xec00}, // in f32=-4098.000000, out f16=-4096
|
||||||
|
{in: math.Float32frombits(0xc5800800), out: 0xec00}, // in f32=-4097.000000, out f16=-4096
|
||||||
|
{in: math.Float32frombits(0xc5800000), out: 0xec00}, // in f32=-4096.000000, out f16=-4096
|
||||||
|
{in: math.Float32frombits(0xc57ff000), out: 0xec00}, // in f32=-4095.000000, out f16=-4096
|
||||||
|
{in: math.Float32frombits(0xc57fe000), out: 0xebff}, // in f32=-4094.000000, out f16=-4094
|
||||||
|
{in: math.Float32frombits(0xc57fd000), out: 0xebfe}, // in f32=-4093.000000, out f16=-4092
|
||||||
|
{in: math.Float32frombits(0xc5002000), out: 0xe801}, // in f32=-2050.000000, out f16=-2050
|
||||||
|
{in: math.Float32frombits(0xc5001000), out: 0xe800}, // in f32=-2049.000000, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc5000829), out: 0xe800}, // in f32=-2048.510010, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc5000800), out: 0xe800}, // in f32=-2048.500000, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc50007d7), out: 0xe800}, // in f32=-2048.489990, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc5000000), out: 0xe800}, // in f32=-2048.000000, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc4fff052), out: 0xe800}, // in f32=-2047.510010, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc4fff000), out: 0xe800}, // in f32=-2047.500000, out f16=-2048
|
||||||
|
{in: math.Float32frombits(0xc4ffefae), out: 0xe7ff}, // in f32=-2047.489990, out f16=-2047
|
||||||
|
{in: math.Float32frombits(0xc4ffe000), out: 0xe7ff}, // in f32=-2047.000000, out f16=-2047
|
||||||
|
{in: math.Float32frombits(0xc4ffc000), out: 0xe7fe}, // in f32=-2046.000000, out f16=-2046
|
||||||
|
{in: math.Float32frombits(0xc4ffa000), out: 0xe7fd}, // in f32=-2045.000000, out f16=-2045
|
||||||
|
{in: math.Float32frombits(0xbf800000), out: 0xbc00}, // in f32=-1.000000, out f16=-1
|
||||||
|
{in: math.Float32frombits(0xbf028f5c), out: 0xb814}, // in f32=-0.510000, out f16=-0.5097656
|
||||||
|
{in: math.Float32frombits(0xbf000000), out: 0xb800}, // in f32=-0.500000, out f16=-0.5
|
||||||
|
{in: math.Float32frombits(0xbefae148), out: 0xb7d7}, // in f32=-0.490000, out f16=-0.48999023
|
||||||
|
{in: math.Float32frombits(0x3efae148), out: 0x37d7}, // in f32=0.490000, out f16=0.48999023
|
||||||
|
{in: math.Float32frombits(0x3f000000), out: 0x3800}, // in f32=0.500000, out f16=0.5
|
||||||
|
{in: math.Float32frombits(0x3f028f5c), out: 0x3814}, // in f32=0.510000, out f16=0.5097656
|
||||||
|
{in: math.Float32frombits(0x3f800000), out: 0x3c00}, // in f32=1.000000, out f16=1
|
||||||
|
{in: math.Float32frombits(0x3fbeb852), out: 0x3df6}, // in f32=1.490000, out f16=1.4902344
|
||||||
|
{in: math.Float32frombits(0x3fc00000), out: 0x3e00}, // in f32=1.500000, out f16=1.5
|
||||||
|
{in: math.Float32frombits(0x3fc147ae), out: 0x3e0a}, // in f32=1.510000, out f16=1.5097656
|
||||||
|
{in: math.Float32frombits(0x3fcf1bbd), out: 0x3e79}, // in f32=1.618034, out f16=1.6181641
|
||||||
|
{in: math.Float32frombits(0x401f5c29), out: 0x40fb}, // in f32=2.490000, out f16=2.4902344
|
||||||
|
{in: math.Float32frombits(0x40200000), out: 0x4100}, // in f32=2.500000, out f16=2.5
|
||||||
|
{in: math.Float32frombits(0x4020a3d7), out: 0x4105}, // in f32=2.510000, out f16=2.5097656
|
||||||
|
{in: math.Float32frombits(0x402df854), out: 0x4170}, // in f32=2.718282, out f16=2.71875
|
||||||
|
{in: math.Float32frombits(0x40490fdb), out: 0x4248}, // in f32=3.141593, out f16=3.140625
|
||||||
|
{in: math.Float32frombits(0x40b00000), out: 0x4580}, // in f32=5.500000, out f16=5.5
|
||||||
|
{in: math.Float32frombits(0x44ffa000), out: 0x67fd}, // in f32=2045.000000, out f16=2045
|
||||||
|
{in: math.Float32frombits(0x44ffc000), out: 0x67fe}, // in f32=2046.000000, out f16=2046
|
||||||
|
{in: math.Float32frombits(0x44ffe000), out: 0x67ff}, // in f32=2047.000000, out f16=2047
|
||||||
|
{in: math.Float32frombits(0x44ffefae), out: 0x67ff}, // in f32=2047.489990, out f16=2047
|
||||||
|
{in: math.Float32frombits(0x44fff000), out: 0x6800}, // in f32=2047.500000, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x44fff052), out: 0x6800}, // in f32=2047.510010, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x45000000), out: 0x6800}, // in f32=2048.000000, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x450007d7), out: 0x6800}, // in f32=2048.489990, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x45000800), out: 0x6800}, // in f32=2048.500000, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x45000829), out: 0x6800}, // in f32=2048.510010, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x45001000), out: 0x6800}, // in f32=2049.000000, out f16=2048
|
||||||
|
{in: math.Float32frombits(0x450017d7), out: 0x6801}, // in f32=2049.489990, out f16=2050
|
||||||
|
{in: math.Float32frombits(0x45001800), out: 0x6801}, // in f32=2049.500000, out f16=2050
|
||||||
|
{in: math.Float32frombits(0x45001829), out: 0x6801}, // in f32=2049.510010, out f16=2050
|
||||||
|
{in: math.Float32frombits(0x45002000), out: 0x6801}, // in f32=2050.000000, out f16=2050
|
||||||
|
{in: math.Float32frombits(0x45003000), out: 0x6802}, // in f32=2051.000000, out f16=2052
|
||||||
|
{in: math.Float32frombits(0x457fd000), out: 0x6bfe}, // in f32=4093.000000, out f16=4092
|
||||||
|
{in: math.Float32frombits(0x457fe000), out: 0x6bff}, // in f32=4094.000000, out f16=4094
|
||||||
|
{in: math.Float32frombits(0x457ff000), out: 0x6c00}, // in f32=4095.000000, out f16=4096
|
||||||
|
{in: math.Float32frombits(0x45800000), out: 0x6c00}, // in f32=4096.000000, out f16=4096
|
||||||
|
{in: math.Float32frombits(0x45800800), out: 0x6c00}, // in f32=4097.000000, out f16=4096
|
||||||
|
{in: math.Float32frombits(0x45801000), out: 0x6c00}, // in f32=4098.000000, out f16=4096
|
||||||
|
{in: math.Float32frombits(0x45801800), out: 0x6c01}, // in f32=4099.000000, out f16=4100
|
||||||
|
{in: math.Float32frombits(0x45802000), out: 0x6c01}, // in f32=4100.000000, out f16=4100
|
||||||
|
{in: math.Float32frombits(0x45ad9c00), out: 0x6d6d}, // in f32=5555.500000, out f16=5556
|
||||||
|
{in: math.Float32frombits(0x45ffe800), out: 0x6fff}, // in f32=8189.000000, out f16=8188
|
||||||
|
{in: math.Float32frombits(0x45fff000), out: 0x7000}, // in f32=8190.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x45fff800), out: 0x7000}, // in f32=8191.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x46000000), out: 0x7000}, // in f32=8192.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x46000400), out: 0x7000}, // in f32=8193.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x46000800), out: 0x7000}, // in f32=8194.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x46000c00), out: 0x7000}, // in f32=8195.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x46001000), out: 0x7000}, // in f32=8196.000000, out f16=8192
|
||||||
|
{in: math.Float32frombits(0x46001400), out: 0x7001}, // in f32=8197.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46001800), out: 0x7001}, // in f32=8198.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46001c00), out: 0x7001}, // in f32=8199.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46002000), out: 0x7001}, // in f32=8200.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46002400), out: 0x7001}, // in f32=8201.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46002800), out: 0x7001}, // in f32=8202.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46002c00), out: 0x7001}, // in f32=8203.000000, out f16=8200
|
||||||
|
{in: math.Float32frombits(0x46003000), out: 0x7002}, // in f32=8204.000000, out f16=8208
|
||||||
|
{in: math.Float32frombits(0x467fec00), out: 0x73ff}, // in f32=16379.000000, out f16=16376
|
||||||
|
{in: math.Float32frombits(0x467ff000), out: 0x7400}, // in f32=16380.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x467ff400), out: 0x7400}, // in f32=16381.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x467ff800), out: 0x7400}, // in f32=16382.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x467ffc00), out: 0x7400}, // in f32=16383.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800000), out: 0x7400}, // in f32=16384.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800200), out: 0x7400}, // in f32=16385.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800400), out: 0x7400}, // in f32=16386.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800600), out: 0x7400}, // in f32=16387.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800800), out: 0x7400}, // in f32=16388.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800a00), out: 0x7400}, // in f32=16389.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800c00), out: 0x7400}, // in f32=16390.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46800e00), out: 0x7400}, // in f32=16391.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46801000), out: 0x7400}, // in f32=16392.000000, out f16=16384
|
||||||
|
{in: math.Float32frombits(0x46801200), out: 0x7401}, // in f32=16393.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46801400), out: 0x7401}, // in f32=16394.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46801600), out: 0x7401}, // in f32=16395.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46801800), out: 0x7401}, // in f32=16396.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46801a00), out: 0x7401}, // in f32=16397.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46801c00), out: 0x7401}, // in f32=16398.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46801e00), out: 0x7401}, // in f32=16399.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802000), out: 0x7401}, // in f32=16400.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802200), out: 0x7401}, // in f32=16401.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802400), out: 0x7401}, // in f32=16402.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802600), out: 0x7401}, // in f32=16403.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802800), out: 0x7401}, // in f32=16404.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802a00), out: 0x7401}, // in f32=16405.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802c00), out: 0x7401}, // in f32=16406.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46802e00), out: 0x7401}, // in f32=16407.000000, out f16=16400
|
||||||
|
{in: math.Float32frombits(0x46803000), out: 0x7402}, // in f32=16408.000000, out f16=16416
|
||||||
|
{in: math.Float32frombits(0x46ffee00), out: 0x77ff}, // in f32=32759.000000, out f16=32752
|
||||||
|
{in: math.Float32frombits(0x46fff000), out: 0x7800}, // in f32=32760.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fff200), out: 0x7800}, // in f32=32761.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fff400), out: 0x7800}, // in f32=32762.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fff600), out: 0x7800}, // in f32=32763.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fff800), out: 0x7800}, // in f32=32764.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fffa00), out: 0x7800}, // in f32=32765.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fffc00), out: 0x7800}, // in f32=32766.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x46fffe00), out: 0x7800}, // in f32=32767.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000000), out: 0x7800}, // in f32=32768.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000100), out: 0x7800}, // in f32=32769.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000200), out: 0x7800}, // in f32=32770.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000300), out: 0x7800}, // in f32=32771.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000400), out: 0x7800}, // in f32=32772.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000500), out: 0x7800}, // in f32=32773.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000600), out: 0x7800}, // in f32=32774.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000700), out: 0x7800}, // in f32=32775.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000800), out: 0x7800}, // in f32=32776.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000900), out: 0x7800}, // in f32=32777.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000a00), out: 0x7800}, // in f32=32778.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000b00), out: 0x7800}, // in f32=32779.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000c00), out: 0x7800}, // in f32=32780.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000d00), out: 0x7800}, // in f32=32781.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000e00), out: 0x7800}, // in f32=32782.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47000f00), out: 0x7800}, // in f32=32783.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47001000), out: 0x7800}, // in f32=32784.000000, out f16=32768
|
||||||
|
{in: math.Float32frombits(0x47001100), out: 0x7801}, // in f32=32785.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001200), out: 0x7801}, // in f32=32786.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001300), out: 0x7801}, // in f32=32787.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001400), out: 0x7801}, // in f32=32788.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001500), out: 0x7801}, // in f32=32789.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001600), out: 0x7801}, // in f32=32790.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001700), out: 0x7801}, // in f32=32791.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001800), out: 0x7801}, // in f32=32792.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001900), out: 0x7801}, // in f32=32793.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001a00), out: 0x7801}, // in f32=32794.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001b00), out: 0x7801}, // in f32=32795.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001c00), out: 0x7801}, // in f32=32796.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001d00), out: 0x7801}, // in f32=32797.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001e00), out: 0x7801}, // in f32=32798.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47001f00), out: 0x7801}, // in f32=32799.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002000), out: 0x7801}, // in f32=32800.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002100), out: 0x7801}, // in f32=32801.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002200), out: 0x7801}, // in f32=32802.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002300), out: 0x7801}, // in f32=32803.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002400), out: 0x7801}, // in f32=32804.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002500), out: 0x7801}, // in f32=32805.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002600), out: 0x7801}, // in f32=32806.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002700), out: 0x7801}, // in f32=32807.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002800), out: 0x7801}, // in f32=32808.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002900), out: 0x7801}, // in f32=32809.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002a00), out: 0x7801}, // in f32=32810.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002b00), out: 0x7801}, // in f32=32811.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002c00), out: 0x7801}, // in f32=32812.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002d00), out: 0x7801}, // in f32=32813.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002e00), out: 0x7801}, // in f32=32814.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47002f00), out: 0x7801}, // in f32=32815.000000, out f16=32800
|
||||||
|
{in: math.Float32frombits(0x47003000), out: 0x7802}, // in f32=32816.000000, out f16=32832
|
||||||
|
{in: math.Float32frombits(0x477fe500), out: 0x7bff}, // in f32=65509.000000, out f16=65504
|
||||||
|
{in: math.Float32frombits(0x477fe100), out: 0x7bff}, // in f32=65505.000000, out f16=65504
|
||||||
|
{in: math.Float32frombits(0x477fee00), out: 0x7bff}, // in f32=65518.000000, out f16=65504
|
||||||
|
{in: math.Float32frombits(0x477fef00), out: 0x7bff}, // in f32=65519.000000, out f16=65504
|
||||||
|
{in: math.Float32frombits(0x477feffd), out: 0x7bff}, // in f32=65519.988281, out f16=65504
|
||||||
|
{in: math.Float32frombits(0x477ff000), out: 0x7c00}, // in f32=65520.000000, out f16=+Inf
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrecisionFromfloat32(t *testing.T) {
|
||||||
|
for i, v := range wantF32toF16bits {
|
||||||
|
f16 := float16.Fromfloat32(v.in)
|
||||||
|
u16 := uint16(f16)
|
||||||
|
|
||||||
|
if u16 != v.out {
|
||||||
|
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkPrecision(t, v.in, f16, uint64(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 := float32(5.5) // value that doesn't drop any bits in the significand, is within normal exponent range
|
||||||
|
pre := float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionExact {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionExact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionExact, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x38000000) // subnormal value with coef = 0 that can round-trip float32->float16->float32
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionUnknown {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x387fc000) // subnormal value with coef !=0 that can round-trip float32->float16->float32
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionUnknown {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x33c00000) // subnormal value with no dropped bits that cannot round-trip float32->float16->float32
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionUnknown {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x38000001) // subnormal value with dropped non-zero bits > 0
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionInexact {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = float32(math.Pi) // value that cannot "preserve value" because it drops bits in the significand
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionInexact {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x1) // value that will underflow
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionUnderflow {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x33000000) // value that will underflow
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionUnderflow {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32 = math.Float32frombits(0x47800000) // value that will overflow
|
||||||
|
pre = float16.PrecisionFromfloat32(f32)
|
||||||
|
if pre != float16.PrecisionOverflow {
|
||||||
|
t.Errorf("f32bits=0x%08x, wanted=PrecisionOverflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionOverflow, pre)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFromNaN32ps(t *testing.T) {
|
||||||
|
for i, v := range wantF32toF16bits {
|
||||||
|
f16 := float16.Fromfloat32(v.in)
|
||||||
|
u16 := uint16(f16)
|
||||||
|
|
||||||
|
if u16 != v.out {
|
||||||
|
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
|
||||||
|
}
|
||||||
|
|
||||||
|
checkFromNaN32ps(t, v.in, f16)
|
||||||
|
}
|
||||||
|
|
||||||
|
// since checkFromNaN32ps rejects non-NaN input, try one here
|
||||||
|
nan, err := float16.FromNaN32ps(float32(math.Pi))
|
||||||
|
if err != float16.ErrInvalidNaNValue {
|
||||||
|
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted err float16.ErrInvalidNaNValue, got err = %q", err)
|
||||||
|
}
|
||||||
|
if err.Error() != "float16: invalid NaN value, expected IEEE 754 NaN" {
|
||||||
|
t.Errorf("unexpected string value returned by err.Error() for ErrInvalidNaNValue: %s", err.Error())
|
||||||
|
}
|
||||||
|
if uint16(nan) != 0x7c01 { // signaling NaN
|
||||||
|
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted nan = 0x7c01, got nan = 0x%04x", uint16(nan))
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test a small subset of possible conversions from float32 to Float16.
|
||||||
|
// TestSomeFromFloat32 runs in under 1 second while TestAllFromFloat32 takes about 45 seconds.
|
||||||
|
func TestSomeFromFloat32(t *testing.T) {
|
||||||
|
|
||||||
|
for i, v := range wantF32toF16bits {
|
||||||
|
f16 := float16.Fromfloat32(v.in)
|
||||||
|
u16 := uint16(f16)
|
||||||
|
|
||||||
|
if u16 != v.out {
|
||||||
|
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test all possible 4294967296 float32 input values and results for
|
||||||
|
// Fromfloat32(), FromNaN32ps(), and PrecisionFromfloat32().
|
||||||
|
func TestAllFromFloat32(t *testing.T) {
|
||||||
|
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping TestAllFromFloat32 in short mode.")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("WARNING: TestAllFromFloat32 should take about 1-2 minutes to run on amd64, other platforms may take longer...\n")
|
||||||
|
|
||||||
|
// Blake2b is "3f310bc5608a087462d361644fe66feeb4c68145f6f18eb6f1439cd7914888b6df9e30ae5350dce0635162cc6a2f23b31b3e4353ca132a3c552bdbd58baa54e6"
|
||||||
|
const wantSHA512 = "08670429a475164d6c4a080969e35231c77ef7069b430b5f38af22e013796b7818bbe8f5942a6ddf26de0e1dfc67d02243f483d85729ebc3762fc2948a5ca1f8"
|
||||||
|
|
||||||
|
const batchSize uint32 = 16384
|
||||||
|
results := make([]uint16, batchSize)
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
h := sha512.New()
|
||||||
|
|
||||||
|
for i := uint64(0); i < uint64(0xFFFFFFFF); i += uint64(batchSize) {
|
||||||
|
// fill results
|
||||||
|
for j := uint32(0); j < batchSize; j++ {
|
||||||
|
inF32 := math.Float32frombits(uint32(i) + j)
|
||||||
|
f16 := float16.Fromfloat32(inF32)
|
||||||
|
results[j] = uint16(f16)
|
||||||
|
checkPrecision(t, inF32, f16, i)
|
||||||
|
checkFromNaN32ps(t, inF32, f16)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert results to []byte
|
||||||
|
err := binary.Write(buf, binary.LittleEndian, results)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// update hash with []byte of results
|
||||||
|
_, err = h.Write(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// display hash digest in hex
|
||||||
|
digest := h.Sum(nil)
|
||||||
|
gotSHA512hex := hex.EncodeToString(digest)
|
||||||
|
if gotSHA512hex != wantSHA512 {
|
||||||
|
t.Errorf("gotSHA512hex = %s", gotSHA512hex)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test all 65536 conversions from float16 to float32.
|
||||||
|
// TestAllToFloat32 runs in under 1 second.
|
||||||
|
func TestAllToFloat32(t *testing.T) {
|
||||||
|
// Blake2b is "078d8e3fac9480de1493f22c8f9bfc1eb2051537c536f00f621557d70eed1af057a487c3e252f6d593769f5288d5ab66d8e9cd1adba359838802944bdb731f4d"
|
||||||
|
const wantSHA512 = "1a4ccec9fd7b6e83310c6b4958a25778cd95f8d4f88b19950e4b8d6932a955f7fbd96b1c9bd9b2a79c3a9d34d653f55e671f8f86e6a5a876660cd38479001aa6"
|
||||||
|
const batchSize uint32 = 16384
|
||||||
|
results := make([]float32, batchSize)
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
h := sha512.New()
|
||||||
|
|
||||||
|
for i := uint64(0); i < uint64(0xFFFF); i += uint64(batchSize) {
|
||||||
|
// fill results
|
||||||
|
for j := uint32(0); j < batchSize; j++ {
|
||||||
|
inU16 := uint16(i) + uint16(j)
|
||||||
|
f16 := float16.Float16(inU16)
|
||||||
|
results[j] = f16.Float32()
|
||||||
|
}
|
||||||
|
|
||||||
|
// convert results to []byte
|
||||||
|
err := binary.Write(buf, binary.LittleEndian, results)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// update hash with []byte of results
|
||||||
|
_, err = h.Write(buf.Bytes())
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
// display hash digest in hex
|
||||||
|
digest := h.Sum(nil)
|
||||||
|
gotSHA512hex := hex.EncodeToString(digest)
|
||||||
|
if gotSHA512hex != wantSHA512 {
|
||||||
|
t.Errorf("Float16toFloat32: gotSHA512hex = %s", gotSHA512hex)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFrombits(t *testing.T) {
|
||||||
|
x := uint16(0x1234)
|
||||||
|
f16 := float16.Frombits(x)
|
||||||
|
if uint16(f16) != f16.Bits() || uint16(f16) != x {
|
||||||
|
t.Errorf("float16.Frombits(0x7fff) returned %04x, wanted %04x", uint16(f16), x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNaN(t *testing.T) {
|
||||||
|
nan := float16.NaN()
|
||||||
|
if !nan.IsNaN() {
|
||||||
|
t.Errorf("nan.IsNaN() returned false, wanted true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInf(t *testing.T) {
|
||||||
|
posInf := float16.Inf(0)
|
||||||
|
if uint16(posInf) != 0x7c00 {
|
||||||
|
t.Errorf("float16.Inf(0) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
|
||||||
|
}
|
||||||
|
|
||||||
|
posInf = float16.Inf(1)
|
||||||
|
if uint16(posInf) != 0x7c00 {
|
||||||
|
t.Errorf("float16.Inf(1) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
|
||||||
|
}
|
||||||
|
|
||||||
|
negInf := float16.Inf(-1)
|
||||||
|
if uint16(negInf) != 0xfc00 {
|
||||||
|
t.Errorf("float16.Inf(-1) returned %04x, wanted %04x", uint16(negInf), 0xfc00)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBits(t *testing.T) {
|
||||||
|
x := uint16(0x1234)
|
||||||
|
f16 := float16.Frombits(x)
|
||||||
|
if uint16(f16) != f16.Bits() || f16.Bits() != x {
|
||||||
|
t.Errorf("Bits() returned %04x, wanted %04x", uint16(f16), x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsFinite(t *testing.T) {
|
||||||
|
// IsFinite returns true if f is neither infinite nor NaN.
|
||||||
|
|
||||||
|
finite := float16.Fromfloat32(float32(1.5))
|
||||||
|
if !finite.IsFinite() {
|
||||||
|
t.Errorf("finite.Infinite() returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
posInf := float16.Inf(0)
|
||||||
|
if posInf.IsFinite() {
|
||||||
|
t.Errorf("posInf.Infinite() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
negInf := float16.Inf(-1)
|
||||||
|
if negInf.IsFinite() {
|
||||||
|
t.Errorf("negInf.Infinite() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
nan := float16.NaN()
|
||||||
|
if nan.IsFinite() {
|
||||||
|
t.Errorf("nan.Infinite() returned true, wanted false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsNaN(t *testing.T) {
|
||||||
|
|
||||||
|
f16 := float16.Float16(0)
|
||||||
|
if f16.IsNaN() {
|
||||||
|
t.Errorf("Float16(0).IsNaN() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0x7e00)
|
||||||
|
if !f16.IsNaN() {
|
||||||
|
t.Errorf("Float16(0x7e00).IsNaN() returned false, wanted true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsQuietNaN(t *testing.T) {
|
||||||
|
|
||||||
|
f16 := float16.Float16(0)
|
||||||
|
if f16.IsQuietNaN() {
|
||||||
|
t.Errorf("Float16(0).IsQuietNaN() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0x7e00)
|
||||||
|
if !f16.IsQuietNaN() {
|
||||||
|
t.Errorf("Float16(0x7e00).IsQuietNaN() returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0x7e00 ^ 0x0200)
|
||||||
|
if f16.IsQuietNaN() {
|
||||||
|
t.Errorf("Float16(0x7e00 ^ 0x0200).IsQuietNaN() returned true, wanted false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsNormal(t *testing.T) {
|
||||||
|
// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN.
|
||||||
|
|
||||||
|
zero := float16.Frombits(0)
|
||||||
|
if zero.IsNormal() {
|
||||||
|
t.Errorf("zero.IsNormal() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
posInf := float16.Inf(0)
|
||||||
|
if posInf.IsNormal() {
|
||||||
|
t.Errorf("posInf.IsNormal() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
negInf := float16.Inf(-1)
|
||||||
|
if negInf.IsNormal() {
|
||||||
|
t.Errorf("negInf.IsNormal() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
nan := float16.NaN()
|
||||||
|
if nan.IsNormal() {
|
||||||
|
t.Errorf("nan.IsNormal() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
subnormal := float16.Frombits(0x0001)
|
||||||
|
if subnormal.IsNormal() {
|
||||||
|
t.Errorf("subnormal.IsNormal() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
normal := float16.Fromfloat32(float32(1.5))
|
||||||
|
if !normal.IsNormal() {
|
||||||
|
t.Errorf("normal.IsNormal() returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignbit(t *testing.T) {
|
||||||
|
|
||||||
|
f16 := float16.Fromfloat32(float32(0.0))
|
||||||
|
if f16.Signbit() {
|
||||||
|
t.Errorf("float16.Fromfloat32(float32(0)).Signbit() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Fromfloat32(float32(2.0))
|
||||||
|
if f16.Signbit() {
|
||||||
|
t.Errorf("float16.Fromfloat32(float32(2)).Signbit() returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Fromfloat32(float32(-2.0))
|
||||||
|
if !f16.Signbit() {
|
||||||
|
t.Errorf("float16.Fromfloat32(float32(-2)).Signbit() returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestString(t *testing.T) {
|
||||||
|
f16 := float16.Fromfloat32(1.5)
|
||||||
|
s := f16.String()
|
||||||
|
if s != "1.5" {
|
||||||
|
t.Errorf("Float16(1.5).String() returned %s, wanted 1.5", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Fromfloat32(3.141593)
|
||||||
|
s = f16.String()
|
||||||
|
if s != "3.140625" {
|
||||||
|
t.Errorf("Float16(3.141593).String() returned %s, wanted 3.140625", s)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsInf(t *testing.T) {
|
||||||
|
|
||||||
|
f16 := float16.Float16(0)
|
||||||
|
if f16.IsInf(0) {
|
||||||
|
t.Errorf("Float16(0).IsInf(0) returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0x7c00)
|
||||||
|
if !f16.IsInf(0) {
|
||||||
|
t.Errorf("Float16(0x7c00).IsInf(0) returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0x7c00)
|
||||||
|
if !f16.IsInf(1) {
|
||||||
|
t.Errorf("Float16(0x7c00).IsInf(1) returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0x7c00)
|
||||||
|
if f16.IsInf(-1) {
|
||||||
|
t.Errorf("Float16(0x7c00).IsInf(-1) returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0xfc00)
|
||||||
|
if !f16.IsInf(0) {
|
||||||
|
t.Errorf("Float16(0xfc00).IsInf(0) returned false, wanted true")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0xfc00)
|
||||||
|
if f16.IsInf(1) {
|
||||||
|
t.Errorf("Float16(0xfc00).IsInf(1) returned true, wanted false")
|
||||||
|
}
|
||||||
|
|
||||||
|
f16 = float16.Float16(0xfc00)
|
||||||
|
if !f16.IsInf(-1) {
|
||||||
|
t.Errorf("Float16(0xfc00).IsInf(-1) returned false, wanted true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// float32parts splits the IEEE 754 bit pattern of f32 into the pieces the
// precision checks need.
//
// Returns:
//   - exp: the 8-bit exponent field minus the bias of 127 (so zero and
//     subnormals yield -127, infinities and NaNs yield 128).
//   - coef: the 23-bit fraction field.
//   - dropped: the low 13 bits of coef, i.e. the fraction bits beyond the top
//     10 (presumably the ones a 10-bit float16 fraction cannot hold).
//
// The sign bit is not reported.
func float32parts(f32 float32) (exp int32, coef uint32, dropped uint32) {
	const (
		coefMask uint32 = 0x7fffff // 23 least significant bits
		expShift uint32 = 23
		expBias  uint32 = 127
		expMask  uint32 = uint32(0xff) << expShift
		dropMask uint32 = coefMask >> 10
	)

	u32 := math.Float32bits(f32)
	// The subtraction wraps in uint32 for small exponents; converting the
	// wrapped value to int32 yields the intended negative number.
	exp = int32(((u32 & expMask) >> expShift) - expBias)
	coef = u32 & coefMask
	dropped = coef & dropMask
	return exp, coef, dropped
}
|
||||||
|
|
||||||
|
// isNaN32 reports whether f32 is a NaN: all eight exponent bits set and a
// non-zero 23-bit fraction. The sign bit is ignored.
func isNaN32(f32 float32) bool {
	const (
		expBits  uint32 = 0x7f800000
		fracBits uint32 = 0x007fffff
	)
	bits := math.Float32bits(f32)
	return bits&expBits == expBits && bits&fracBits != 0
}
|
||||||
|
|
||||||
|
// isQuietNaN32 reports whether f32 is a NaN whose most significant fraction
// bit (0x00400000) is set. With that bit set the fraction is necessarily
// non-zero, so testing the exponent bits together with that bit is enough.
func isQuietNaN32(f32 float32) bool {
	const quietNaNBits uint32 = 0x7f800000 | 0x00400000
	return math.Float32bits(f32)&quietNaNBits == quietNaNBits
}
|
||||||
|
|
||||||
|
func checkFromNaN32ps(t *testing.T, f32 float32, f16 float16.Float16) {
|
||||||
|
|
||||||
|
if !isNaN32(f32) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u32 := math.Float32bits(f32)
|
||||||
|
nan16, err := float16.FromNaN32ps(f32)
|
||||||
|
|
||||||
|
if isQuietNaN32(f32) {
|
||||||
|
// result should be the same
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
|
||||||
|
}
|
||||||
|
if uint16(nan16) != uint16(f16) {
|
||||||
|
t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted nan16 = %v, got nan16 = %v", u32, f32, f16, nan16)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// result should differ only by the signaling/quiet bit unless payload is empty
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
coef := uint16(f16) & uint16(0x03ff)
|
||||||
|
payload := uint16(f16) & uint16(0x01ff)
|
||||||
|
diff := uint16(nan16 ^ f16)
|
||||||
|
|
||||||
|
if payload == 0 {
|
||||||
|
// the lowest bit needed to be set to prevent turning sNaN into infinity, so 2 bits differ
|
||||||
|
if diff != 0x0201 {
|
||||||
|
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0201, got 0x%04x", u32, f32, diff)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// only the quiet bit was restored, so 1 bit differs
|
||||||
|
if diff != 0x0200 {
|
||||||
|
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0200, got 0x%04x. f16=0x%04x n16=0x%04x coef=0x%04x", u32, f32, diff, uint16(f16), uint16(nan16), coef)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkPrecision(t *testing.T, f32 float32, f16 float16.Float16, i uint64) {
|
||||||
|
// TODO: rewrite this test when time allows
|
||||||
|
|
||||||
|
u32 := math.Float32bits(f32)
|
||||||
|
u16 := f16.Bits()
|
||||||
|
f32bis := f16.Float32()
|
||||||
|
u32bis := math.Float32bits(f32bis)
|
||||||
|
pre := float16.PrecisionFromfloat32(f32)
|
||||||
|
roundtripped := u32 == u32bis
|
||||||
|
exp32, coef32, dropped32 := float32parts(f32)
|
||||||
|
|
||||||
|
if roundtripped {
|
||||||
|
checkRoundTrippedPrecision(t, u32, u16, u32bis, exp32, coef32, dropped32)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if pre == float16.PrecisionExact {
|
||||||
|
// this should only happen if both input and output are NaN
|
||||||
|
if !(f16.IsNaN() && isNaN32(f32)) {
|
||||||
|
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionExact when roundtrip failed with non-special value", i, u32, f32, u16, u32bis, f32bis)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if pre == float16.PrecisionUnknown {
|
||||||
|
if exp32 < -24 {
|
||||||
|
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionUnderflow", i, u32, f32, u16, u32bis, f32bis)
|
||||||
|
}
|
||||||
|
if dropped32 != 0 {
|
||||||
|
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionInexact", i, u32, f32, u16, u32bis, f32bis)
|
||||||
|
}
|
||||||
|
} else if pre == float16.PrecisionInexact {
|
||||||
|
checkPrecisionInexact(t, u32, u16, u32bis, exp32, coef32, dropped32)
|
||||||
|
} else if pre == float16.PrecisionUnderflow {
|
||||||
|
if exp32 >= -14 {
|
||||||
|
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnderflow when exp32 is >= -14", i, u32, f32, u16, u32bis, f32bis)
|
||||||
|
}
|
||||||
|
} else if pre == float16.PrecisionOverflow {
|
||||||
|
if exp32 <= 15 {
|
||||||
|
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionOverflow when exp32 is <= 15", i, u32, f32, u16, u32bis, f32bis)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPrecisionInexact validates a PrecisionInexact verdict for the float32
// with bits u32. It reports an error for every property that contradicts the
// verdict: an exponent below -24 (underflow expected), an exponent above 15
// (overflow expected), a zero fraction, or no dropped fraction bits. u16 and
// u32bis are only used in the failure messages.
func checkPrecisionInexact(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
	in := math.Float32frombits(u32)
	back := math.Float32frombits(u32bis)

	violations := []struct {
		hit    bool
		format string
	}{
		{exp32 < -24, "PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionUnderflow"},
		{exp32 > 15, "PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionOverflow"},
		{coef32 == 0, "PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when coef32 is 0"},
		{dropped32 == 0, "PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when dropped32 is 0"},
	}

	for _, v := range violations {
		if v.hit {
			t.Errorf(v.format, u32, in, u16, u32bis, back)
		}
	}
}
|
||||||
|
|
||||||
|
func checkRoundTrippedPrecision(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
|
||||||
|
f32 := math.Float32frombits(u32)
|
||||||
|
f32bis := math.Float32frombits(u32bis)
|
||||||
|
pre := float16.PrecisionFromfloat32(f32)
|
||||||
|
f16 := float16.Frombits(u16)
|
||||||
|
|
||||||
|
if dropped32 != 0 {
|
||||||
|
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), dropped32 != 0 with successful roundtrip", u32, f32, u16, u32bis, f32bis)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pre != float16.PrecisionExact {
|
||||||
|
// there are 2046 values that are subnormal and can round-trip float32->float16->float32
|
||||||
|
if pre != float16.PrecisionUnknown {
|
||||||
|
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%032b) (%f), out f16bits=0x%04x (%v), back=0x%08x (%f), got %v, wanted PrecisionExact, exp=%d, coef=%d, drpd=%d", u32, u32, f32, u16, f16, u32bis, f32bis, pre, exp32, coef32, dropped32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -1,6 +1,7 @@
|
||||||
package pickle_test
|
package pickle_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
|
@ -18,11 +19,13 @@ func ExampleLoadInfo() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pickle.LoadInfo(modelFile)
|
m, err := pickle.LoadModelInfo(modelFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println(m)
|
||||||
|
|
||||||
// Output:
|
// Output:
|
||||||
// classifier.0.bias - [4096]
|
// classifier.0.bias - [4096]
|
||||||
// classifier.0.weight - [4096 25088]
|
// classifier.0.weight - [4096 25088]
|
||||||
|
@ -57,4 +60,25 @@ func ExampleLoadInfo() {
|
||||||
// features.7.bias - [128]
|
// features.7.bias - [128]
|
||||||
// features.7.weight - [128 128 3 3]
|
// features.7.weight - [128 128 3 3]
|
||||||
// Num of variables: 32
|
// Num of variables: 32
|
||||||
|
// Tensor DType: Float
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExampleModelFloat16() {
|
||||||
|
modelName := "HuggingFaceH4/tiny-random-LlamaForCausalLM"
|
||||||
|
url := "https://huggingface.co/HuggingFaceH4/tiny-random-LlamaForCausalLM/resolve/main/pytorch_model.bin"
|
||||||
|
|
||||||
|
modelFile, err := gotch.CachedPath(url, modelName)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := pickle.LoadModelInfo(modelFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Model DType: %v\n", m.DType())
|
||||||
|
|
||||||
|
// Output:
|
||||||
|
// Model DType: Half
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,7 @@ import (
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
"github.com/sugarme/gotch/nn"
|
"github.com/sugarme/gotch/nn"
|
||||||
"github.com/sugarme/gotch/ts"
|
"github.com/sugarme/gotch/ts"
|
||||||
)
|
)
|
||||||
|
@ -60,83 +61,41 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
|
||||||
dictResult := *result.(*Dict)
|
dictResult := *result.(*Dict)
|
||||||
for _, item := range dictResult {
|
for _, item := range dictResult {
|
||||||
name := item.Key
|
name := item.Key
|
||||||
itemTyp := reflect.TypeOf(item.Value).String()
|
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||||
switch itemTyp {
|
if !isStorageTensor {
|
||||||
case "*pickle.Dict": // Nested *pickle.Dict case
|
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||||
subResult := *item.Value.(*Dict)
|
return nil, err
|
||||||
for _, subItem := range subResult {
|
|
||||||
subName := subItem.Key
|
|
||||||
x, ok := subItem.Value.(*StorageTensor)
|
|
||||||
if !ok {
|
|
||||||
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data := x.Source.GetData()
|
|
||||||
size := x.Size
|
|
||||||
dtype := x.Source.DType()
|
|
||||||
device := x.Source.Device()
|
|
||||||
stride := x.Stride
|
|
||||||
storageOffset := x.StorageOffset
|
|
||||||
if reflect.ValueOf(data).Len() == 0 {
|
|
||||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO. should we just skip them?
|
|
||||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
|
||||||
size = []int64{1}
|
|
||||||
stride = []int64{1}
|
|
||||||
}
|
|
||||||
|
|
||||||
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
|
||||||
if x.RequiresGrad {
|
|
||||||
x1.MustRequiresGrad_(x.RequiresGrad)
|
|
||||||
}
|
|
||||||
|
|
||||||
namedTensors[name.(string)] = x1
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
|
||||||
|
|
||||||
// if !isStorageTensor {
|
|
||||||
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
|
||||||
// return nil, err
|
|
||||||
// }
|
|
||||||
if !isStorageTensor {
|
|
||||||
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
data := sx.Source.GetData()
|
|
||||||
size := sx.Size
|
|
||||||
dtype := sx.Source.DType()
|
|
||||||
device := sx.Source.Device()
|
|
||||||
stride := sx.Stride
|
|
||||||
storageOffset := sx.StorageOffset
|
|
||||||
|
|
||||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
|
||||||
// log.Printf("data: %v\n", data)
|
|
||||||
|
|
||||||
// Dealing with Pytorch `..._tracked` variables.
|
|
||||||
if reflect.ValueOf(data).Len() == 0 {
|
|
||||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO. should we just skip them?
|
|
||||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
|
||||||
size = []int64{1}
|
|
||||||
stride = []int64{1}
|
|
||||||
}
|
|
||||||
|
|
||||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
|
||||||
if sx.RequiresGrad {
|
|
||||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
|
||||||
}
|
|
||||||
|
|
||||||
namedTensors[name.(string)] = x
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
data := sx.Source.GetData()
|
||||||
|
size := sx.Size
|
||||||
|
dtype := sx.Source.DType()
|
||||||
|
|
||||||
|
device := sx.Source.Device()
|
||||||
|
stride := sx.Stride
|
||||||
|
storageOffset := sx.StorageOffset
|
||||||
|
|
||||||
|
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||||
|
// log.Printf("data: %v\n", data)
|
||||||
|
|
||||||
|
// Dealing with Pytorch `..._tracked` variables.
|
||||||
|
if reflect.ValueOf(data).Len() == 0 {
|
||||||
|
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO. should we just skip them?
|
||||||
|
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||||
|
size = []int64{1}
|
||||||
|
stride = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
|
x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||||
|
if sx.RequiresGrad {
|
||||||
|
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||||
|
}
|
||||||
|
|
||||||
|
namedTensors[name.(string)] = x
|
||||||
}
|
}
|
||||||
case "*pickle.OrderedDict":
|
case "*pickle.OrderedDict":
|
||||||
dictResult := result.(*OrderedDict)
|
dictResult := result.(*OrderedDict)
|
||||||
|
@ -581,35 +540,74 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
|
||||||
return missingVariables, nil
|
return missingVariables, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadInfo loads pretrained weights and prints out name and shape of weights.
|
type ModelInfor struct {
|
||||||
func LoadInfo(modelFile string) error {
|
weights map[string][]int64
|
||||||
weights, err := Decode(modelFile)
|
dtype gotch.DType
|
||||||
if err != nil {
|
}
|
||||||
err = fmt.Errorf("LoadInfo() failed: %w", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
layers := make([]string, 0, len(weights))
|
func NewModelInfor(weights map[string][]int64, dtype gotch.DType) *ModelInfor {
|
||||||
for tsName := range weights {
|
return &ModelInfor{
|
||||||
|
weights: weights,
|
||||||
|
dtype: dtype,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelInfor) String() string {
|
||||||
|
var summary string
|
||||||
|
layers := make([]string, 0, len(m.weights))
|
||||||
|
for tsName := range m.weights {
|
||||||
layers = append(layers, tsName)
|
layers = append(layers, tsName)
|
||||||
}
|
}
|
||||||
sort.Strings(layers)
|
sort.Strings(layers)
|
||||||
for _, l := range layers {
|
for _, l := range layers {
|
||||||
var x *ts.Tensor
|
var x []int64
|
||||||
for tsName, tsVal := range weights {
|
for tsName, shape := range m.weights {
|
||||||
if tsName == l {
|
if tsName == l {
|
||||||
x = tsVal
|
x = shape
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
fmt.Printf("%s - %+v\n", l, x.MustSize())
|
|
||||||
|
summary += fmt.Sprintf("%s - %+v\n", l, x)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Num of variables: %v\n", len(weights))
|
summary += fmt.Sprintf("Num of variables: %v\n", len(m.weights))
|
||||||
|
summary += fmt.Sprintf("Tensor DType: %v\n", m.dtype)
|
||||||
|
|
||||||
for _, x := range weights {
|
return summary
|
||||||
x.MustDrop()
|
}
|
||||||
}
|
|
||||||
|
func (m *ModelInfor) DType() gotch.DType {
|
||||||
return nil
|
return m.dtype
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ModelInfor) Parameters() int {
|
||||||
|
return len(m.weights)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadInfo loads pretrained weights and prints out name and shape of weights.
|
||||||
|
func LoadModelInfo(modelFile string) (*ModelInfor, error) {
|
||||||
|
weights, err := Decode(modelFile)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("LoadInfo() failed: %w", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
w := make(map[string][]int64)
|
||||||
|
var dtype gotch.DType
|
||||||
|
isFirst := true
|
||||||
|
for n, x := range weights {
|
||||||
|
w[n] = x.MustSize()
|
||||||
|
|
||||||
|
if isFirst {
|
||||||
|
dtype = x.DType()
|
||||||
|
isFirst = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewModelInfor(w, dtype)
|
||||||
|
|
||||||
|
ts.CleanUp()
|
||||||
|
|
||||||
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/half"
|
||||||
)
|
)
|
||||||
|
|
||||||
// This file implements Pytorch storage data types.
|
// This file implements Pytorch storage data types.
|
||||||
|
@ -67,7 +68,8 @@ func (s *HalfStorageClass) New(size int, location string) Storage {
|
||||||
|
|
||||||
type HalfStorage struct {
|
type HalfStorage struct {
|
||||||
BaseStorage
|
BaseStorage
|
||||||
Data []float32
|
// Data []float32
|
||||||
|
Data []half.Float16
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Storage = &HalfStorage{}
|
var _ Storage = &HalfStorage{}
|
||||||
|
@ -77,7 +79,7 @@ func (s *HalfStorage) SetFromFile(r io.Reader) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
data := make([]float32, size)
|
data := make([]half.Float16, size)
|
||||||
br := NewLimitedBufferReader(r, size, 2, 512)
|
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
bytes, err := br.ReadNext()
|
bytes, err := br.ReadNext()
|
||||||
|
@ -85,7 +87,7 @@ func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
u16 := binary.LittleEndian.Uint16(bytes)
|
u16 := binary.LittleEndian.Uint16(bytes)
|
||||||
data[i] = math.Float32frombits(FloatBits16to32(u16))
|
data[i] = half.Float16(u16)
|
||||||
}
|
}
|
||||||
s.Data = data
|
s.Data = data
|
||||||
return nil
|
return nil
|
||||||
|
@ -96,7 +98,7 @@ func (s *HalfStorage) GetData() interface{} {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HalfStorage) DType() gotch.DType {
|
func (s *HalfStorage) DType() gotch.DType {
|
||||||
return gotch.Float
|
return gotch.Half
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *HalfStorage) Device() gotch.Device {
|
func (s *HalfStorage) Device() gotch.Device {
|
||||||
|
@ -108,6 +110,62 @@ func (s *HalfStorage) Device() gotch.Device {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BFloat16Storage:
|
||||||
|
// ================
|
||||||
|
type BFloat16StorageClass struct{}
|
||||||
|
|
||||||
|
var _ StorageClass = &BFloat16StorageClass{}
|
||||||
|
|
||||||
|
func (s *BFloat16StorageClass) New(size int, location string) Storage {
|
||||||
|
return &BFloat16Storage{
|
||||||
|
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||||
|
Data: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type BFloat16Storage struct {
|
||||||
|
BaseStorage
|
||||||
|
Data []half.BFloat16
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Storage = &BFloat16Storage{}
|
||||||
|
|
||||||
|
func (s *BFloat16Storage) SetFromFile(r io.Reader) error {
|
||||||
|
return setFromFile(s, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BFloat16Storage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||||
|
data := make([]half.BFloat16, size)
|
||||||
|
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
bytes, err := br.ReadNext()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
u16 := binary.LittleEndian.Uint16(bytes)
|
||||||
|
data[i] = half.BFloat16(u16)
|
||||||
|
}
|
||||||
|
s.Data = data
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BFloat16Storage) GetData() interface{} {
|
||||||
|
return s.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BFloat16Storage) DType() gotch.DType {
|
||||||
|
return gotch.BFloat16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *BFloat16Storage) Device() gotch.Device {
|
||||||
|
switch s.Location {
|
||||||
|
case "cuda":
|
||||||
|
return gotch.CudaIfAvailable()
|
||||||
|
default:
|
||||||
|
return gotch.CPU
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// FloatStorage:
|
// FloatStorage:
|
||||||
// =============
|
// =============
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ func (it *Iterable) Next() (item interface{}, ok bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
switch it.ItemKind.Kind().String() {
|
switch it.ItemKind.GoKind().String() {
|
||||||
case "int64":
|
case "int64":
|
||||||
item, err = it.Content.Int64Value([]int64{it.Index})
|
item, err = it.Content.Int64Value([]int64{it.Index})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
20
ts/npy.go
20
ts/npy.go
|
@ -101,22 +101,22 @@ func (h *NpyHeader) ToString() (string, error) {
|
||||||
shape := strings.Join(shapeStr, ",")
|
shape := strings.Join(shapeStr, ",")
|
||||||
|
|
||||||
var descr string
|
var descr string
|
||||||
switch h.descr.Kind().String() {
|
switch h.descr {
|
||||||
// case "float16": // NOTE. No float16 in Go primary types. TODO. implement
|
case gotch.Half:
|
||||||
// descr = "f2"
|
descr = "f2"
|
||||||
case "float32":
|
case gotch.Float:
|
||||||
descr = "f4"
|
descr = "f4"
|
||||||
case "float64":
|
case gotch.Double:
|
||||||
descr = "f8"
|
descr = "f8"
|
||||||
case "int":
|
case gotch.Int:
|
||||||
descr = "i4"
|
descr = "i4"
|
||||||
case "int64":
|
case gotch.Int64:
|
||||||
descr = "i8"
|
descr = "i8"
|
||||||
case "int16":
|
case gotch.Int16:
|
||||||
descr = "i2"
|
descr = "i2"
|
||||||
case "int8":
|
case gotch.Int8:
|
||||||
descr = "i1"
|
descr = "i1"
|
||||||
case "uint8":
|
case gotch.Uint8:
|
||||||
descr = "u1"
|
descr = "u1"
|
||||||
default:
|
default:
|
||||||
err := fmt.Errorf("Unsupported kind: %v\n", h.descr)
|
err := fmt.Errorf("Unsupported kind: %v\n", h.descr)
|
||||||
|
|
46
ts/print.go
46
ts/print.go
|
@ -11,41 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (ts *Tensor) ValueGo() interface{} {
|
func (ts *Tensor) ValueGo() interface{} {
|
||||||
dtype := ts.DType()
|
return ts.Vals()
|
||||||
numel := ts.Numel()
|
|
||||||
var dst interface{}
|
|
||||||
switch dtype {
|
|
||||||
case gotch.Uint8:
|
|
||||||
dst = make([]uint8, numel)
|
|
||||||
case gotch.Int8:
|
|
||||||
dst = make([]int8, numel)
|
|
||||||
case gotch.Int16:
|
|
||||||
dst = make([]int16, numel)
|
|
||||||
case gotch.Int:
|
|
||||||
dst = make([]int32, numel)
|
|
||||||
case gotch.Int64:
|
|
||||||
dst = make([]int64, numel)
|
|
||||||
case gotch.Float:
|
|
||||||
dst = make([]float32, numel)
|
|
||||||
case gotch.Double:
|
|
||||||
dst = make([]float64, numel)
|
|
||||||
case gotch.Bool:
|
|
||||||
dst = make([]bool, numel)
|
|
||||||
default:
|
|
||||||
err := fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
err := ts.CopyData(dst, ts.Numel())
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// convert []int32 -> int
|
|
||||||
if reflect.TypeOf(dst).String() == "[]int32" {
|
|
||||||
dst = sliceInt32ToInt(dst.([]int32))
|
|
||||||
}
|
|
||||||
|
|
||||||
return dst
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func shapeToSize(shape []int64) int {
|
func shapeToSize(shape []int64) int {
|
||||||
|
@ -250,8 +216,8 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
|
||||||
// var typ T
|
// var typ T
|
||||||
typ := ts.DType()
|
typ := ts.DType()
|
||||||
|
|
||||||
switch typ.String() {
|
switch typ {
|
||||||
case "float32", "float64":
|
case gotch.Half, gotch.BFloat16, gotch.Float, gotch.Double:
|
||||||
switch f.verb {
|
switch f.verb {
|
||||||
case 'f', 'e', 'E', 'G', 'b':
|
case 'f', 'e', 'E', 'G', 'b':
|
||||||
// accepted. Do nothing
|
// accepted. Do nothing
|
||||||
|
@ -259,7 +225,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
|
||||||
f.verb = 'g'
|
f.verb = 'g'
|
||||||
}
|
}
|
||||||
|
|
||||||
case "uint8", "int8", "int16", "int32", "int64":
|
case gotch.Uint8, gotch.Int8, gotch.Int16, gotch.Int, gotch.Int64:
|
||||||
switch f.verb {
|
switch f.verb {
|
||||||
case 'b':
|
case 'b':
|
||||||
f.base = 2
|
f.base = 2
|
||||||
|
@ -273,7 +239,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
|
||||||
f.base = 10
|
f.base = 10
|
||||||
f.verb = 'd'
|
f.verb = 'd'
|
||||||
}
|
}
|
||||||
case "bool":
|
case gotch.Bool:
|
||||||
f.verb = 't'
|
f.verb = 't'
|
||||||
default:
|
default:
|
||||||
f.verb = 'v'
|
f.verb = 'v'
|
||||||
|
@ -318,7 +284,7 @@ func (ts *Tensor) Format(s fmt.State, verb rune) {
|
||||||
shape := toSliceInt(ts.MustSize())
|
shape := toSliceInt(ts.MustSize())
|
||||||
strides := shapeToStrides(shape)
|
strides := shapeToStrides(shape)
|
||||||
device := ts.MustDevice()
|
device := ts.MustDevice()
|
||||||
dtype := ts.DType().String()
|
dtype := ts.DType()
|
||||||
defined := ts.MustDefined()
|
defined := ts.MustDefined()
|
||||||
if verb == 'i' {
|
if verb == 'i' {
|
||||||
fmt.Fprintf(
|
fmt.Fprintf(
|
||||||
|
|
240
ts/tensor.go
240
ts/tensor.go
|
@ -273,27 +273,13 @@ func (ts *Tensor) nbytes() int64 {
|
||||||
if numel == 0 {
|
if numel == 0 {
|
||||||
return 0 // ts.None
|
return 0 // ts.None
|
||||||
}
|
}
|
||||||
dtype := ts.DType()
|
|
||||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
nbytes := int64(numel) * int64(eltSizeInBytes)
|
return int64(numel * ts.DType().Size())
|
||||||
|
|
||||||
return nbytes
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
||||||
// Decode sz
|
dtype := gotch.Int64 // tensor size dtype = int64
|
||||||
// 1. Count number of elements in data
|
nbytes := int(nsize) * int(dtype.Size())
|
||||||
elementNum := nsize
|
|
||||||
// 2. Element size in bytes
|
|
||||||
eltSizeInBytes, err := gotch.DTypeSize(gotch.Int64)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
|
||||||
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
|
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
|
||||||
r := bytes.NewReader(dataSlice)
|
r := bytes.NewReader(dataSlice)
|
||||||
dataIn := make([]int64, nsize)
|
dataIn := make([]int64, nsize)
|
||||||
|
@ -304,8 +290,49 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
||||||
return dataIn
|
return dataIn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TensorOptions constructs options to build/rebuild tensor.
|
||||||
|
type TensorOptions struct {
|
||||||
|
Name string
|
||||||
|
DType gotch.DType
|
||||||
|
Quantized bool
|
||||||
|
// TODO. can expand as needed
|
||||||
|
}
|
||||||
|
|
||||||
|
type TensorOpt func(*TensorOptions)
|
||||||
|
|
||||||
|
func DefaultTensorOptions() *TensorOptions {
|
||||||
|
return &TensorOptions{
|
||||||
|
Name: "",
|
||||||
|
DType: gotch.Float,
|
||||||
|
Quantized: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithName(v string) TensorOpt {
|
||||||
|
return func(o *TensorOptions) {
|
||||||
|
o.Name = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithDType(v gotch.DType) TensorOpt {
|
||||||
|
return func(o *TensorOptions) {
|
||||||
|
o.DType = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithQuantized(v bool) TensorOpt {
|
||||||
|
return func(o *TensorOptions) {
|
||||||
|
o.Quantized = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OfSlice creates tensor from a slice data
|
// OfSlice creates tensor from a slice data
|
||||||
func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
|
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
|
||||||
|
o := DefaultTensorOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
|
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
|
||||||
if reflect.TypeOf(data).String() == "[]int" {
|
if reflect.TypeOf(data).String() == "[]int" {
|
||||||
data = sliceIntToInt32(data.([]int))
|
data = sliceIntToInt32(data.([]int))
|
||||||
|
@ -318,10 +345,10 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
typ := reflect.TypeOf(data).Elem()
|
elementKind := reflect.TypeOf(data).Elem().Kind()
|
||||||
dataLen := v.Len()
|
dataLen := v.Len()
|
||||||
|
|
||||||
dtype, err := gotch.ToDType(typ)
|
dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -329,12 +356,7 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
|
||||||
shape := []int64{int64(dataLen)}
|
shape := []int64{int64(dataLen)}
|
||||||
elementNum := ElementCount(shape)
|
elementNum := ElementCount(shape)
|
||||||
|
|
||||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
nbytes := int(dtype.Size()) * elementNum
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
|
||||||
|
|
||||||
dataPtr, buff := CMalloc(nbytes)
|
dataPtr, buff := CMalloc(nbytes)
|
||||||
defer C.free(unsafe.Pointer(dataPtr))
|
defer C.free(unsafe.Pointer(dataPtr))
|
||||||
|
@ -343,30 +365,25 @@ func OfSlice(data interface{}, nameOpt ...string) (*Tensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cint, err := gotch.DType2CInt(dtype)
|
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return newTensor(ctensor, nameOpt...), nil
|
return newTensor(ctensor, o.Name), nil
|
||||||
// return newTensor(ctensor), nil
|
// return newTensor(ctensor), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OfDataSize creates Tensor from input byte data, shape and dtype.
|
// OfDataSize creates Tensor from input byte data, shape and dtype.
|
||||||
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string) (*Tensor, error) {
|
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
|
||||||
|
o := DefaultTensorOptions()
|
||||||
elementNum := ElementCount(shape)
|
for _, opt := range opts {
|
||||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
opt(o)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
elementNum := ElementCount(shape)
|
||||||
|
|
||||||
|
nbytes := elementNum * int(dtype.Size())
|
||||||
|
|
||||||
if nbytes != len(data) {
|
if nbytes != len(data) {
|
||||||
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
|
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
|
||||||
|
@ -380,24 +397,19 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType, nameOpt ...string
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
cint, err := gotch.DType2CInt(dtype)
|
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
|
||||||
if err != nil {
|
if err := TorchErr(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
return newTensor(ctensor, o.Name), nil
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return newTensor(ctensor, nameOpt...), nil
|
|
||||||
// return newTensor(ctensor), nil
|
// return newTensor(ctensor), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
||||||
// or panic if error
|
// or panic if error
|
||||||
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
|
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
|
||||||
ts, err := OfDataSize(data, size, dtype)
|
ts, err := OfDataSize(data, size, dtype, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -406,8 +418,8 @@ func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustOfSlice create a tensor from slice of data. It will be panic if error.
|
// MustOfSlice create a tensor from slice of data. It will be panic if error.
|
||||||
func MustOfSlice(data interface{}) *Tensor {
|
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
|
||||||
ts, err := OfSlice(data)
|
ts, err := OfSlice(data, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -416,8 +428,8 @@ func MustOfSlice(data interface{}) *Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TensorFrom create a tensor from slice of data. It will be panic if error.
|
// TensorFrom create a tensor from slice of data. It will be panic if error.
|
||||||
func TensorFrom(data interface{}, nameOpt ...string) *Tensor {
|
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
|
||||||
ts, err := OfSlice(data, nameOpt...)
|
ts, err := OfSlice(data, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -436,7 +448,12 @@ func (ts *Tensor) Print() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTensorFromData creates tensor from given data and shape
|
// NewTensorFromData creates tensor from given data and shape
|
||||||
func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Tensor, error) {
|
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
|
||||||
|
o := DefaultTensorOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(o)
|
||||||
|
}
|
||||||
|
|
||||||
// 1. Check whether data and shape match
|
// 1. Check whether data and shape match
|
||||||
elementNum, err := DataDim(data)
|
elementNum, err := DataDim(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -463,33 +480,18 @@ func NewTensorFromData(data interface{}, shape []int64, nameOpt ...string) (*Ten
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
cint, err := gotch.DType2CInt(dtype)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return newTensor(ctensor, nameOpt...), nil
|
return newTensor(ctensor, o.Name), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *Tensor) DType() gotch.DType {
|
func (ts *Tensor) DType() gotch.DType {
|
||||||
cint := lib.AtScalarType(ts.ctensor)
|
cint := lib.AtScalarType(ts.ctensor)
|
||||||
|
|
||||||
dtype, err := gotch.CInt2DType(cint)
|
return gotch.CKind2DType(cint)
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Tensor DType error: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return dtype
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *Tensor) Device() (gotch.Device, error) {
|
func (ts *Tensor) Device() (gotch.Device, error) {
|
||||||
|
@ -545,6 +547,7 @@ func (ts *Tensor) MustDevice() gotch.Device {
|
||||||
* return retVal
|
* return retVal
|
||||||
* }
|
* }
|
||||||
* */
|
* */
|
||||||
|
|
||||||
// Float64Value returns a float value on tensors holding a single element.
|
// Float64Value returns a float value on tensors holding a single element.
|
||||||
// An error is returned otherwise.
|
// An error is returned otherwise.
|
||||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||||
|
@ -748,7 +751,6 @@ func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGra
|
||||||
//
|
//
|
||||||
// NOTE: `dst` located in Go memory. Should it be?
|
// NOTE: `dst` located in Go memory. Should it be?
|
||||||
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
|
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
|
||||||
|
|
||||||
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
|
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
|
||||||
// there will be memory leak and or out of range error.
|
// there will be memory leak and or out of range error.
|
||||||
if len(dst) < int(numel) {
|
if len(dst) < int(numel) {
|
||||||
|
@ -757,12 +759,9 @@ func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
vs := unsafe.Pointer(&dst[0])
|
vs := unsafe.Pointer(&dst[0])
|
||||||
elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8)
|
dtype := gotch.Uint8
|
||||||
if err != nil {
|
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
|
||||||
return err
|
if err := TorchErr(); err != nil {
|
||||||
}
|
|
||||||
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
|
|
||||||
if err = TorchErr(); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -784,55 +783,20 @@ func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
|
||||||
// and number of elements to C land. This may break in the future
|
// and number of elements to C land. This may break in the future
|
||||||
// if Go policy changes.
|
// if Go policy changes.
|
||||||
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
|
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
|
||||||
|
dtype, dlen, err := DataCheck(dst)
|
||||||
gotype, dlen, err := DataCheck(dst)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
dtype, err := gotch.ToDType(gotype)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if dlen < int(numel) {
|
if dlen < int(numel) {
|
||||||
err = fmt.Errorf("CopyData Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
|
err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if ts.DType() != dtype {
|
if ts.DType() != dtype {
|
||||||
err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var vs unsafe.Pointer
|
// Get data pointer
|
||||||
switch dtype {
|
dataPtr := reflect.ValueOf(dst).UnsafePointer()
|
||||||
case gotch.Uint8:
|
|
||||||
vs = unsafe.Pointer(&dst.([]uint8)[0])
|
|
||||||
case gotch.Int8:
|
|
||||||
vs = unsafe.Pointer(&dst.([]int8)[0])
|
|
||||||
case gotch.Int16:
|
|
||||||
vs = unsafe.Pointer(&dst.([]int16)[0])
|
|
||||||
case gotch.Int:
|
|
||||||
vs = unsafe.Pointer(&dst.([]int32)[0])
|
|
||||||
case gotch.Int64:
|
|
||||||
vs = unsafe.Pointer(&dst.([]int64)[0])
|
|
||||||
case gotch.Float:
|
|
||||||
vs = unsafe.Pointer(&dst.([]float32)[0])
|
|
||||||
case gotch.Double:
|
|
||||||
vs = unsafe.Pointer(&dst.([]float64)[0])
|
|
||||||
case gotch.Bool:
|
|
||||||
vs = unsafe.Pointer(&dst.([]bool)[0])
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
|
lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
|
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -863,7 +827,6 @@ func (ts *Tensor) Numel() uint {
|
||||||
|
|
||||||
// ShallowClone returns a new tensor that share storage with the input tensor.
|
// ShallowClone returns a new tensor that share storage with the input tensor.
|
||||||
func (ts *Tensor) ShallowClone() (*Tensor, error) {
|
func (ts *Tensor) ShallowClone() (*Tensor, error) {
|
||||||
|
|
||||||
ctensor := lib.AtShallowClone(ts.ctensor)
|
ctensor := lib.AtShallowClone(ts.ctensor)
|
||||||
|
|
||||||
if err := TorchErr(); err != nil {
|
if err := TorchErr(); err != nil {
|
||||||
|
@ -1309,33 +1272,18 @@ func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
|
||||||
// E.g. res := xs.Vals().([]int64)
|
// E.g. res := xs.Vals().([]int64)
|
||||||
func (ts *Tensor) Vals() interface{} {
|
func (ts *Tensor) Vals() interface{} {
|
||||||
dtype := ts.DType()
|
dtype := ts.DType()
|
||||||
numel := ts.Numel()
|
numel := int(ts.Numel())
|
||||||
|
|
||||||
var retVal interface{}
|
typ, err := dtype.GoType()
|
||||||
|
if err != nil {
|
||||||
switch dtype.Name() {
|
log.Fatal(err)
|
||||||
case "uint8":
|
|
||||||
retVal = make([]uint8, numel)
|
|
||||||
case "int8":
|
|
||||||
retVal = make([]int8, numel)
|
|
||||||
case "int16":
|
|
||||||
retVal = make([]int16, numel)
|
|
||||||
case "int32":
|
|
||||||
retVal = make([]int32, numel)
|
|
||||||
case "int64":
|
|
||||||
retVal = make([]int64, numel)
|
|
||||||
case "float32":
|
|
||||||
retVal = make([]float32, numel)
|
|
||||||
case "float64":
|
|
||||||
retVal = make([]float64, numel)
|
|
||||||
case "bool":
|
|
||||||
retVal = make([]bool, numel)
|
|
||||||
default:
|
|
||||||
log.Fatalf("Unsupported dtype (%v)", dtype)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ts.CopyData(retVal, numel)
|
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
|
||||||
return retVal
|
|
||||||
|
ts.CopyData(dataSlice, uint(numel))
|
||||||
|
|
||||||
|
return dataSlice
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlatView flattens a tensor.
|
// FlatView flattens a tensor.
|
||||||
|
|
64
ts/util.go
64
ts/util.go
|
@ -75,7 +75,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
if err := w.WriteByte(b); err != nil {
|
if err := w.WriteByte(b); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -86,14 +86,14 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
if v.Kind() == reflect.Slice {
|
if v.Kind() == reflect.Slice {
|
||||||
expected := int(shape[0])
|
expected := int(shape[0])
|
||||||
if v.Len() != expected {
|
if v.Len() != expected {
|
||||||
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
|
return fmt.Errorf("EncodeTensor() failed: mismatched slice lengths: %d and %d", v.Len(), expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
||||||
if len(shape) == 1 && v.Len() > 0 {
|
if len(shape) == 1 && v.Len() > 0 {
|
||||||
switch v.Index(0).Kind() {
|
switch v.Index(0).Kind() {
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
return binary.Write(w, nativeEndian, v.Interface())
|
return binary.Write(w, nativeEndian, v.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unsupported type %v", v.Type())
|
return fmt.Errorf("EncodeTensor() failed: unsupported type %v", v.Type())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -122,7 +122,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
ptr.Elem().SetBool(b == 1)
|
ptr.Elem().SetBool(b == 1)
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
||||||
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
||||||
if len(shape) == 1 && val.Len() > 0 {
|
if len(shape) == 1 && val.Len() > 0 {
|
||||||
switch val.Index(0).Kind() {
|
switch val.Index(0).Kind() {
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||||
return binary.Read(r, nativeEndian, val.Interface())
|
return binary.Read(r, nativeEndian, val.Interface())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -152,10 +152,10 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
||||||
}
|
}
|
||||||
|
|
||||||
// ElementCount counts number of element in the tensor given a shape
|
// ElementCount counts number of element in the tensor given a shape
|
||||||
func ElementCount(shape []int64) int64 {
|
func ElementCount(shape []int64) int {
|
||||||
n := int64(1)
|
n := 1
|
||||||
for _, d := range shape {
|
for _, d := range shape {
|
||||||
n *= d
|
n *= int(d)
|
||||||
}
|
}
|
||||||
return n
|
return n
|
||||||
}
|
}
|
||||||
|
@ -163,54 +163,48 @@ func ElementCount(shape []int64) int64 {
|
||||||
// DataDim returns number of elements in data
|
// DataDim returns number of elements in data
|
||||||
// NOTE: only support scalar and (nested) slice/array of scalar type
|
// NOTE: only support scalar and (nested) slice/array of scalar type
|
||||||
func DataDim(data interface{}) (retVal int, err error) {
|
func DataDim(data interface{}) (retVal int, err error) {
|
||||||
|
|
||||||
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||||
|
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataCheck checks the input data for element Go type and number of elements.
|
// DataCheck checks the input data for element Go type and number of elements.
|
||||||
// It will return errors if element type is not supported.
|
// It will return errors if element dtype is not supported.
|
||||||
func DataCheck(data interface{}) (k reflect.Type, n int, err error) {
|
func DataCheck(data interface{}) (dtype gotch.DType, n int, err error) {
|
||||||
|
|
||||||
return dataCheck(reflect.ValueOf(data).Interface(), 0)
|
return dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: 0 is reflect.Kind() of Invalid
|
// NOTE: 0 is reflect.Kind() of Invalid
|
||||||
// See: https://golang.org/pkg/reflect/#Kind
|
// See: https://golang.org/pkg/reflect/#Kind
|
||||||
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
func dataCheck(data interface{}, count int) (dtype gotch.DType, n int, err error) {
|
||||||
v := reflect.ValueOf(data)
|
v := reflect.ValueOf(data)
|
||||||
var goType reflect.Type = reflect.TypeOf(data)
|
|
||||||
var total int = count
|
var total int = count
|
||||||
var round = 0
|
var round = 0
|
||||||
|
|
||||||
switch v.Kind() {
|
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
|
||||||
case reflect.Slice, reflect.Array:
|
|
||||||
if round == 0 {
|
if round == 0 {
|
||||||
round = v.Len()
|
round = v.Len()
|
||||||
}
|
}
|
||||||
for i := 0; i < v.Len(); i++ {
|
for i := 0; i < v.Len(); i++ {
|
||||||
round--
|
round--
|
||||||
goType, total, err = dataCheck(v.Index(i).Interface(), total)
|
dtype, total, err = dataCheck(v.Index(i).Interface(), total)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reflect.TypeOf(reflect.Zero), 0, err
|
return gotch.Invalid, 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return goType, total, nil
|
return dtype, total, nil
|
||||||
|
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
|
||||||
total++
|
|
||||||
if goType.String() != "invalid" {
|
|
||||||
goType = v.Type()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
|
|
||||||
return reflect.TypeOf(reflect.Zero), 0, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return goType, total, nil
|
total += 1
|
||||||
|
dtype, err = gotch.GoKind2DType(v.Kind())
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("DataCheck() failed: unsupported data structure or type: %v\n", v.Kind())
|
||||||
|
return gotch.Invalid, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return dtype, total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataAsPtr write to C memory and returns a C pointer.
|
// DataAsPtr write to C memory and returns a C pointer.
|
||||||
|
@ -219,25 +213,19 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||||
// Supported data types are scalar, slice/array of scalar type equivalent to
|
// Supported data types are scalar, slice/array of scalar type equivalent to
|
||||||
// DType.
|
// DType.
|
||||||
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
|
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
|
||||||
|
|
||||||
// 1. Count number of elements in data
|
// 1. Count number of elements in data
|
||||||
elementNum, err := DataDim(data)
|
elementNum, err := DataDim(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Element size in bytes
|
// 2. Number of bytes
|
||||||
dtype, err := gotch.DTypeFromData(data)
|
dtype, err := gotch.DTypeFromData(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
nbytes := int(dtype.Size()) * int(elementNum)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
|
||||||
|
|
||||||
// 3. Get C pointer and prepare C memory buffer for writing
|
// 3. Get C pointer and prepare C memory buffer for writing
|
||||||
dataPtr, buff := CMalloc(nbytes)
|
dataPtr, buff := CMalloc(nbytes)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user