gotch/half/bfloat16.go

168 lines
4.9 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package half
import (
"math"
"math/bits"
)
// BFloat16 is a 16-bit floating point type implementing the bfloat16 format.
//
// The bfloat16 (Google "brain" floating point) format is a truncated 16-bit
// version of the IEEE 754 standard binary32 format. It has approximately the
// same dynamic range as float32 (8 exponent bits -> about 3.4 x 10^38) at the
// cost of a lower precision than float16: while float16 has a precision of
// 10 bits, bfloat16 has a precision of only 7 bits.
//
// Bit layout, most significant bit first:
//
//	+------------+------------------------+----------------------------+
//	| 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) |
//	+------------+------------------------+----------------------------+
//
// References:
//   - https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
//   - https://github.com/starkat99/half-rs/tree/main/src/bfloat
type BFloat16 uint16
// Float32ToBFloat16 converts a float32 into the bit pattern of the nearest
// bfloat16 value, rounding to nearest with ties going to the even mantissa.
//
// NaN inputs stay NaN: the upper 16 bits are kept and the most significant
// mantissa bit is forced on, so truncation can never turn a NaN into infinity.
//
// Ref. https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1
func Float32ToBFloat16(value float32) uint16 {
	const (
		magnitudeMask = 0x7FFF_FFFF // everything except the sign bit
		infinityBits  = 0x7F80_0000 // exponent all ones, mantissa zero
		quietNaNBit   = 0x0040      // top mantissa bit of a bfloat16
		halfULP       = 0x0000_8000 // first bit that gets dropped
		// Bits below halfULP (sticky) plus the lowest kept bit (odd/even);
		// numerically this is 3*halfULP - 1.
		stickyOrOdd = 0x0001_7FFF
	)

	raw := math.Float32bits(value)
	upper := uint16(raw >> 16)

	// A magnitude strictly above the infinity pattern is a NaN.
	if raw&magnitudeMask > infinityBits {
		return upper | quietNaNBit
	}

	// Round up when the dropped part is more than half an ULP, or exactly
	// half an ULP with an odd kept mantissa.
	if raw&halfULP != 0 && raw&stickyOrOdd != 0 {
		upper++
	}
	return upper
}
// Float64ToBFloat16 converts a float64 into the bit pattern of the nearest
// bfloat16 value, rounding to nearest with ties going to the even mantissa.
//
// Special cases:
//   - NaN stays NaN (the most significant mantissa bit is forced on).
//   - Infinities and values too large for bfloat16 become signed infinity.
//   - Values too small for a bfloat16 subnormal become signed zero.
func Float64ToBFloat16(value float64) uint16 {
	// Convert to raw bytes. Only the upper 32 bits (sign, exponent and the top
	// 20 mantissa bits) can contribute digits to the result.
	val := math.Float64bits(value)
	x := uint32(val >> 32)
	// The lower 32 mantissa bits never appear in the result, but they still
	// matter for rounding: if any of them is set, a value that looks like an
	// exact tie in the upper bits is really above the tie and must round up.
	lowSticky := uint32(val) != 0

	// Extract IEEE 754 components.
	sign := x & 0x8000_0000
	exp := x & 0x7FF0_0000
	man := x & 0x000F_FFFF

	// All exponent bits set means Infinity or NaN.
	if exp == 0x7FF0_0000 {
		// Set the mantissa MSB for NaN and also keep the shifted mantissa
		// bits. A NaN payload may live entirely in the low 32 bits.
		var nanBit uint32 = 0x0040
		if man == 0 && !lowSticky {
			nanBit = 0
		}
		return uint16((sign >> 16) | 0x7F80 | nanBit | (man >> 13))
	}

	// The number is finite; start assembling the bfloat16 version.
	halfSign := sign >> 16
	// Unbias the float64 exponent, then bias it for bfloat16.
	halfExp := int64(exp>>20) - 1023 + 127

	// Exponent overflow: return signed infinity.
	if halfExp >= 0xFF {
		return uint16(halfSign | 0x7F80)
	}

	// Exponent underflow: the result is a subnormal or zero.
	if halfExp <= 0 {
		if 7-halfExp > 21 {
			// No rounding possibility, so this is a full underflow.
			return uint16(halfSign)
		}
		// Make the hidden leading mantissa bit explicit before shifting.
		man |= 0x0010_0000
		halfMan := man >> uint(14-halfExp)
		// Round to nearest even, treating the discarded low word as sticky.
		var roundBit uint32 = 1 << uint(13-halfExp)
		if man&roundBit != 0 && (man&(3*roundBit-1) != 0 || lowSticky) {
			halfMan++
		}
		// Subnormals carry no exponent bits.
		return uint16(halfSign | halfMan)
	}

	// Normal number: rebias the exponent and truncate the mantissa.
	result := halfSign | uint32(halfExp)<<7 | man>>13
	// Round to nearest even, treating the discarded low word as sticky. A
	// carry out of the mantissa correctly bumps the exponent (up to infinity).
	const roundBit uint32 = 0x0000_1000
	if man&roundBit != 0 && (man&(3*roundBit-1) != 0 || lowSticky) {
		result++
	}
	return uint16(result)
}
// BFloat16ToFloat32 widens a bfloat16 bit pattern to the float32 it denotes.
//
// The conversion is exact: the 16 bits become the upper half of the float32
// and the lower half is zero. For NaN inputs the most significant mantissa
// bit is additionally set while the rest of the mantissa is kept.
func BFloat16ToFloat32(i uint16) float32 {
	wide := uint32(i)
	// A magnitude strictly above the infinity pattern (0x7F80) is a NaN.
	if i&0x7FFF > 0x7F80 {
		wide |= 0x0040
	}
	return math.Float32frombits(wide << 16)
}
// BFloat16ToFloat64 widens a bfloat16 bit pattern to the float64 it denotes.
//
// Signed zeros and infinities map to their float64 counterparts. NaN inputs
// keep their mantissa bits and additionally get the most significant mantissa
// bit set. bfloat16 subnormals are renormalized into float64 normals.
func BFloat16ToFloat64(i uint16) float64 {
	const (
		signMask = 0x8000
		expMask  = 0x7F80
		fracMask = 0x007F
	)

	// Signed zero: only the sign bit (if any) needs to move to the top.
	if i&^signMask == 0 {
		return math.Float64frombits(uint64(i) << 48)
	}

	signBits := uint64(i&signMask) << 48
	expField := uint64(i & expMask)
	frac := uint64(i & fracMask)

	switch expField {
	case expMask:
		// All exponent bits set: infinity when the mantissa is empty.
		if frac == 0 {
			return math.Float64frombits(signBits | 0x7FF0_0000_0000_0000)
		}
		// Otherwise NaN: keep the mantissa and set its most significant bit.
		return math.Float64frombits(signBits | 0x7FF8_0000_0000_0000 | frac<<45)

	case 0:
		// Subnormal: shift the leading one out into the hidden-bit position
		// and lower the exponent by the same amount. frac fits in 7 bits, so
		// it has at least 9 leading zeros when viewed as a uint16.
		adjust := bits.LeadingZeros16(uint16(frac)) - 9
		expBits := uint64(1023-127-adjust) << 52
		manBits := (frac << (46 + adjust)) & 0xF_FFFF_FFFF_FFFF
		return math.Float64frombits(signBits | expBits | manBits)
	}

	// Normal number: swap the bfloat16 bias (127) for the float64 bias (1023)
	// and left-align the 7 mantissa bits in the 52-bit float64 mantissa.
	rebased := (int64(expField) >> 7) - 127 + 1023
	return math.Float64frombits(signBits | uint64(rebased)<<52 | frac<<45)
}