gotch/half/bfloat16.go

168 lines
4.9 KiB
Go
Raw Permalink Normal View History

2023-07-06 15:01:23 +01:00
package half
import (
"math"
"math/bits"
)
// BFloat16 is a 16-bit floating point type implementing the bfloat16 format.
//
// Ref. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
// https://github.com/starkat99/half-rs/tree/main/src/bfloat
//
// The bfloat16 (Google "brain" floating point) format is a truncated 16-bit
// version of the IEEE 754 standard binary32 format. It keeps an 8-bit
// exponent, so it has approximately the same dynamic range as float32
// (up to about 3.4 × 10^38), at the cost of lower precision than float16:
// float16 has a precision of 10 bits, bfloat16 has a precision of only 7 bits.
//
// Bit layout, most significant bit first:
//
//	+------------+------------------------+----------------------------+
//	| 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) |
//	+------------+------------------------+----------------------------+
type BFloat16 uint16
// Float32ToBFloat16 converts a float32 to the raw bits of the nearest bfloat16
// value, rounding to nearest with ties to even.
//
// A NaN input yields a quiet NaN: the upper 16 bits of the input are kept and
// the most significant mantissa bit is forced on. Every other input is
// narrowed by dropping the low 16 bits and, when needed, rounding up by one
// unit in the last place (which may carry into the exponent, e.g. the largest
// float32 becomes infinity).
//
// Ref. https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1
func Float32ToBFloat16(value float32) uint16 {
	const (
		absMask  = 0x7FFF_FFFF // everything except the sign bit
		infBits  = 0x7F80_0000 // exponent all ones, mantissa zero
		quietBit = 0x0040      // most significant bfloat16 mantissa bit
		halfULP  = 0x0000_8000 // highest of the 16 discarded bits
	)

	raw := math.Float32bits(value)
	top := uint16(raw >> 16)

	// Magnitude strictly above infinity means NaN.
	if raw&absMask > infBits {
		return top | quietBit
	}

	// Round up when the discarded part is more than half a unit, or exactly
	// half a unit and the kept value is odd. The mask 3*halfULP-1 covers the
	// lowest kept bit plus every discarded bit below halfULP.
	if raw&halfULP != 0 && raw&(3*halfULP-1) != 0 {
		top++
	}
	return top
}
// Float64ToBFloat16 converts a float64 to the raw bits of the nearest bfloat16
// value, rounding to nearest with ties to even.
//
// Special cases:
//   - NaN maps to a quiet NaN (most significant mantissa bit set) that keeps
//     the sign and the top seven mantissa bits of the input.
//   - ±Inf, and finite values whose magnitude is too large for bfloat16, map
//     to the infinity of the same sign.
//   - Values too small to round to a bfloat16 subnormal map to a zero of the
//     same sign.
//
// The arithmetic is done on the upper 32 bits of the float64. The lower 32
// mantissa bits are far below bfloat16 precision, but they still decide
// whether a value sits exactly on a rounding midpoint or just above it, so
// they are folded into a single sticky bit before rounding.
func Float64ToBFloat16(value float64) uint16 {
	// Convert to raw bits and split into the upper word (sign, exponent and
	// the top 20 mantissa bits) and the lower word (remaining mantissa bits).
	val := math.Float64bits(value)
	x := uint32(val >> 32)
	low := uint32(val)

	// Extract IEEE 754 components from the upper word.
	sign := x & 0x8000_0000
	exp := x & 0x7FF0_0000
	man := x & 0x000F_FFFF

	// All exponent bits set: infinity or NaN.
	if exp == 0x7FF0_0000 {
		// It is an infinity only if the whole 52-bit mantissa is zero, so the
		// lower word has to be checked as well. For NaN, set the mantissa MSB
		// and keep the shifted mantissa bits.
		var nanBit uint32 = 0x0040
		if man == 0 && low == 0 {
			nanBit = 0
		}
		return uint16((sign >> 16) | 0x7F80 | nanBit | (man >> 13))
	}

	// Sticky bit: record that something non-zero lies below the upper word.
	// Bit 0 of man is always shifted out of the result (every shift below is
	// at least 13) and is only ever read by the "anything below the round
	// bit?" masks, so this cannot change the truncated mantissa.
	if low != 0 {
		man |= 1
	}

	// The number is finite; start assembling the bfloat16 version.
	halfSign := sign >> 16
	// Unbias the float64 exponent (bias 1023), then bias for bfloat16 (127).
	halfExp := int64(exp>>20) - 1023 + 127

	// Exponent overflow: return infinity with the sign of the input.
	if halfExp >= 0xFF {
		return uint16(halfSign | 0x7F80)
	}

	// Underflow: the result is a subnormal or zero.
	if halfExp <= 0 {
		if 7-halfExp > 21 {
			// Even the hidden bit is below the rounding position, so this is
			// a full underflow: return signed zero.
			return uint16(halfSign)
		}
		// Subnormals have no implicit leading one, so make it explicit.
		man |= 0x0010_0000
		halfMan := man >> (14 - halfExp)
		// Round to nearest, ties to even.
		var roundBit uint32 = 1 << (13 - halfExp)
		if man&roundBit != 0 && man&(3*roundBit-1) != 0 {
			halfMan++
		}
		// No exponent bits for subnormals; a carry out of the mantissa
		// correctly produces the smallest normal number.
		return uint16(halfSign | halfMan)
	}

	// Normal number: rebias the exponent and keep the top 7 mantissa bits.
	half := halfSign | uint32(halfExp)<<7 | man>>13
	// Round to nearest, ties to even. The mask 3*roundBit-1 covers the lowest
	// kept bit plus everything below the round bit (including the sticky bit).
	// A carry may ripple into the exponent, which correctly yields the next
	// binade or infinity.
	const roundBit uint32 = 0x0000_1000
	if man&roundBit != 0 && man&(3*roundBit-1) != 0 {
		half++
	}
	return uint16(half)
}
// BFloat16ToFloat32 widens raw bfloat16 bits to a float32.
//
// The 16 input bits become the upper half of the float32 bit pattern and the
// lower half is zero. For NaN inputs the most significant mantissa bit is
// additionally set, so the result is always a quiet NaN.
func BFloat16ToFloat32(i uint16) float32 {
	wide := uint32(i)
	// Magnitude strictly above the infinity pattern means NaN.
	if i&0x7FFF > 0x7F80 {
		wide |= 0x0040
	}
	return math.Float32frombits(wide << 16)
}
// BFloat16ToFloat64 widens raw bfloat16 bits to a float64.
//
// The conversion is exact for every finite input. Signed zeros and infinities
// keep their sign, NaN inputs become quiet NaNs that keep the sign and the
// seven mantissa bits of the input, and bfloat16 subnormals are renormalized
// because they are ordinary normal numbers in float64.
func BFloat16ToFloat64(i uint16) float64 {
	// Signed zero: only the sign bit needs to move to the top of the word.
	if i&0x7FFF == 0 {
		return math.Float64frombits(uint64(i) << 48)
	}

	signBits := uint64(i&0x8000) << 48
	expField := uint64(i&0x7F80) >> 7
	frac := uint64(i & 0x007F)

	switch expField {
	case 0xFF:
		// All exponent bits set: infinity when the mantissa is zero.
		if frac == 0 {
			return math.Float64frombits(signBits | 0x7FF0_0000_0000_0000)
		}
		// NaN: keep the mantissa and also set the most significant mantissa bit.
		return math.Float64frombits(signBits | 0x7FF8_0000_0000_0000 | frac<<45)

	case 0:
		// Subnormal (frac is non-zero here, zero was handled above). The
		// leading one sits 'adjust' places below the top of the 7-bit
		// fraction; lower the exponent by that much and shift the leading one
		// out of the 52-bit mantissa field, where it becomes implicit.
		adjust := bits.LeadingZeros16(uint16(frac)) - 9
		expBits := uint64(1023-127-adjust) << 52
		manBits := (frac << (46 + adjust)) & 0xF_FFFF_FFFF_FFFF
		return math.Float64frombits(signBits | expBits | manBits)
	}

	// Normal number: swap the bfloat16 bias (127) for the float64 bias (1023)
	// and left-align the 7 fraction bits in the 52-bit mantissa field.
	expBits := (expField + (1023 - 127)) << 52
	return math.Float64frombits(signBits | expBits | frac<<45)
}