upgrade libtorch v2.0

This commit is contained in:
sugarme 2023-10-11 12:00:02 +11:00
commit 4f03dec060
79 changed files with 439197 additions and 5817 deletions


@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Reworked `ts.Format()`
- Added conv2d benchmark
- Fixed #88 memory leak at `example/char-rnn`
- Added missing tensor methods `Stride()`, `MustDataPtr()`, `IsMkldnn`, `MustIsMkldnn`, `IsContiguous`, `MustIsContiguous`
- Added ts `New()`
## [0.7.0]
- Added `WsName` and `BsName` fields to `nn.LayerNorm.Config`


@ -24,8 +24,13 @@
## Installation
- Default CUDA version is `11.8` if CUDA is available otherwise using CPU version.
- Default Pytorch C++ API version is `2.0.1`
**NOTE**: `libtorch` will be installed at **`/usr/local/lib`**
@ -69,8 +74,12 @@
```bash
wget https://github.com/sugarme/gotch/releases/download/v0.8.0/setup-libtorch.sh
chmod +x setup-libtorch.sh
export CUDA_VER=11.8 && bash setup-libtorch.sh
```
**Update Environment**: in Debian/Ubuntu, add/update the following lines to `.bashrc` file
@ -87,7 +96,12 @@
```bash
wget https://github.com/sugarme/gotch/releases/download/v0.8.0/setup-gotch.sh
chmod +x setup-gotch.sh
# CUDA 11.8
export CUDA_VER=11.8 && export GOTCH_VER=v0.8.0 && bash setup-gotch.sh
```
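Once both scripts have run, a quick smoke test confirms the toolchain links against `libtorch` (a minimal sketch; it assumes the environment variables above are loaded in the current shell):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	// Creating a tensor fails at load time if libtorch is not
	// installed where the setup scripts put it.
	x := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
	fmt.Println(x.Float64Values())       // [1 1 1 1 1 1]
	fmt.Println(gotch.CudaIfAvailable()) // CPU or CUDA device
}
```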
## Examples

651 dtype.go

@ -3,8 +3,6 @@ package gotch
import (
"fmt"
"log"
"reflect"
)
@ -12,151 +10,148 @@ import (
type CInt = int32
// DType represents the kind of element a tensor can hold.
// Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h
type DType int
const (
Invalid DType = -1
Uint8 DType = 0
Int8 DType = 1
Int16 DType = 2
Int DType = 3
Int64 DType = 4
Half DType = 5
Float DType = 6
Double DType = 7
ComplexHalf DType = 8
ComplexFloat DType = 9
ComplexDouble DType = 10
Bool DType = 11
QInt8 DType = 12
QUInt8 DType = 13
QInt32 DType = 14
BFloat16 DType = 15
// ---not implemented ---
QUInt4x2 DType = 16
QUInt2x4 DType = 17
Bits1x8 DType = 18
Bits2x4 DType = 19
Bits4x2 DType = 20
Bits8 DType = 21
Bits16 DType = 22
)
var dtype2CKind = map[DType]CInt{
Uint8: 0,
Int8: 1,
Int16: 2,
Int: 3,
Int64: 4,
Half: 5,
Float: 6,
Double: 7,
ComplexHalf: 8,
ComplexFloat: 9,
ComplexDouble: 10,
Bool: 11,
QInt8: 12,
QUInt8: 13,
QInt32: 14,
BFloat16: 15,
// ---not implemented ---
QUInt4x2: 16,
QUInt2x4: 17,
Bits1x8: 18,
Bits2x4: 19,
Bits4x2: 20,
Bits8: 21,
Bits16: 22,
}
func (dt DType) CKind() CInt {
if cint, ok := dtype2CKind[dt]; ok {
return cint
}
if Debug {
log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt)
}
return -1 // invalid
}
// Back compat
func (dt DType) CInt() CInt {
return dt.CKind()
}
var ckind2DType map[CInt]DType = map[CInt]DType{
0: Uint8,
1: Int8,
2: Int16,
3: Int,
4: Int64,
5: Half,
6: Float,
7: Double,
8: ComplexHalf,
9: ComplexFloat,
10: ComplexDouble,
11: Bool,
12: QInt8,
13: QUInt8,
14: QInt32,
15: BFloat16,
// ---not implemented ---
16: QUInt4x2,
17: QUInt2x4,
18: Bits1x8,
19: Bits2x4,
20: Bits4x2,
21: Bits8,
22: Bits16,
}
func CKind2DType(ckind int32) DType {
if dtype, ok := ckind2DType[ckind]; ok {
return dtype
}
if Debug {
log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind)
}
return -1 // invalid
}
// dtypeSize is a map of DType and its size in Bytes.
var dtypeSize map[DType]uint = map[DType]uint{
Uint8: 1,
Int8: 1,
Int16: 2,
Int: 4,
Int64: 8,
Half: 2,
Float: 4,
Double: 8,
ComplexHalf: 4,
ComplexFloat: 8,
ComplexDouble: 16,
Bool: 1,
QInt8: 1,
QUInt8: 1,
QInt32: 4,
BFloat16: 2,
QUInt4x2: 2,
QUInt2x4: 1,
// ---not implemented ---
Bits1x8: 1,
Bits2x4: 1,
Bits4x2: 1,
Bits8: 1,
Bits16: 2,
}
// Size returns dtype size in Bytes.
func (dt DType) Size() uint {
return dtypeSize[dt]
}
type DTypeDevice struct {
@ -174,201 +169,255 @@ var (
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
)
var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{
Uint8: reflect.Uint8,
Int8: reflect.Int8,
Int16: reflect.Int16,
Int: reflect.Int32,
Int64: reflect.Int64,
Half: reflect.Uint16, // <- just uint16
Float: reflect.Float32,
Double: reflect.Float64,
ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
ComplexFloat: reflect.Complex64,
ComplexDouble: reflect.Complex128,
Bool: reflect.Bool,
QInt8: reflect.Int8,
QUInt8: reflect.Uint8,
QInt32: reflect.Int32,
BFloat16: reflect.Uint16, // <- just uint16
// ---not implemented ---
QUInt4x2: reflect.Invalid,
QUInt2x4: reflect.Invalid,
Bits1x8: reflect.Invalid,
Bits2x4: reflect.Invalid,
Bits4x2: reflect.Invalid,
Bits8: reflect.Invalid,
Bits16: reflect.Invalid,
}
func (dt DType) GoKind() reflect.Kind {
if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid {
return kind
}
if Debug {
log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt)
}
return reflect.Invalid
}
const (
// NOTE. reflect.Kind 0-26
QUInt8Kind reflect.Kind = 27
QInt8Kind reflect.Kind = 28
QInt32Kind reflect.Kind = 29
Float16Kind reflect.Kind = 30
BFloat16Kind reflect.Kind = 31
QUInt4x2Kind reflect.Kind = 32
QUInt2x4Kind reflect.Kind = 33
Bits1x8Kind reflect.Kind = 34
Bits2x4Kind reflect.Kind = 35
Bits4x2Kind reflect.Kind = 36
Bits8Kind reflect.Kind = 37
Bits16Kind reflect.Kind = 38
ComplexHalfKind reflect.Kind = 39
)
var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{
reflect.Uint8: Uint8,
reflect.Int8: Int8,
reflect.Int16: Int16,
reflect.Int32: Int,
reflect.Int64: Int64,
reflect.Float32: Float,
reflect.Float64: Double,
reflect.Complex64: ComplexFloat,
reflect.Complex128: ComplexDouble,
reflect.Bool: Bool,
reflect.Uint16: Half,
// Added Kinds
QUInt8Kind: QUInt8,
QInt8Kind: QInt8,
QInt32Kind: QInt32,
// Float16Kind: Half,
BFloat16Kind: BFloat16,
QUInt4x2Kind: QUInt4x2,
QUInt2x4Kind: QUInt2x4,
Bits1x8Kind: Bits1x8,
Bits2x4Kind: Bits2x4,
Bits4x2Kind: Bits4x2,
Bits8Kind: Bits8,
Bits16Kind: Bits16,
ComplexHalfKind: ComplexHalf,
}
type DTypeOptions struct {
HalfDTypePref DType
Quantized bool
}
type DTypeOpt func(*DTypeOptions)
func DefaultDTypeOptions() *DTypeOptions {
return &DTypeOptions{
HalfDTypePref: Half,
Quantized: false,
}
}
func HalfDTypePref(v DType) DTypeOpt {
if v != Half && v != BFloat16 {
if Debug {
log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v)
}
}
return func(o *DTypeOptions) {
o.HalfDTypePref = v
}
}
func WithQuantized(v bool) DTypeOpt {
return func(o *DTypeOptions) {
o.Quantized = v
}
}
func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) {
o := DefaultDTypeOptions()
for _, opt := range opts {
opt(o)
}
switch {
case kind == reflect.Uint16 && o.HalfDTypePref == Half:
return Half, nil
case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16:
return BFloat16, nil
case kind == reflect.Int8 && o.Quantized:
return QInt8, nil
case kind == reflect.Uint8 && o.Quantized:
return QUInt8, nil
case kind == reflect.Int32 && o.Quantized:
return QInt32, nil
default:
dtype, ok := goKind2DType[kind]
if !ok {
err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind)
return Invalid, err
}
return dtype, nil
}
}
var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{
Uint8: reflect.TypeOf(uint8(0)),
Int8: reflect.TypeOf(int8(0)),
Int16: reflect.TypeOf(int16(0)),
Int: reflect.TypeOf(int(0)),
Int64: reflect.TypeOf(int64(0)),
Half: reflect.TypeOf(uint16(0)), // <- just uint16
Float: reflect.TypeOf(float32(0)),
Double: reflect.TypeOf(float64(0)),
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
ComplexFloat: reflect.TypeOf(complex64(0)),
ComplexDouble: reflect.TypeOf(complex128(0)),
Bool: reflect.TypeOf(true),
QInt8: reflect.TypeOf(int8(0)),
QUInt8: reflect.TypeOf(uint8(0)),
QInt32: reflect.TypeOf(int32(0)),
BFloat16: reflect.TypeOf(uint16(0)), // <- just uint16
// ---not implemented ---
QUInt4x2: reflect.TypeOf(int8(0)),
QUInt2x4: reflect.TypeOf(uint8(0)),
Bits1x8: reflect.TypeOf(uint8(0)),
Bits2x4: reflect.TypeOf(uint8(0)),
Bits4x2: reflect.TypeOf(uint8(0)),
Bits8: reflect.TypeOf(uint8(0)),
Bits16: reflect.TypeOf(uint16(0)),
}
func (dt DType) GoType() (reflect.Type, error) {
typ, ok := dtype2GoType[dt]
if !ok {
err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", typ.String())
return nil, err
}
return typ, nil
}
var dtypeNames map[DType]string = map[DType]string{
Uint8: "Uint8",
Int8: "Int8",
Int16: "Int16",
Int: "Int",
Int64: "Int64",
Half: "Half", // <- just uint16
Float: "Float",
Double: "Double",
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
ComplexFloat: "ComplexFloat",
ComplexDouble: "ComplexDouble",
Bool: "Bool",
QInt8: "QInt8",
QUInt8: "QUInt8",
QInt32: "QInt32",
BFloat16: "BFloat16", // <- just uint16
// ---not implemented ---
QUInt4x2: "QUInt4x2",
QUInt2x4: "QUInt2x4",
Bits1x8: "Bits1x8",
Bits2x4: "Bits2x4",
Bits4x2: "Bits4x2",
Bits8: "Bits8",
Bits16: "Bits16",
}
func (dt DType) String() string {
return dtypeNames[dt]
}
func DTypeFromData(data interface{}) (DType, error) {
dataKind := reflect.TypeOf(data).Kind()
// Data is a slice/array
if dataKind == reflect.Slice || dataKind == reflect.Array {
elementKind := reflect.TypeOf(data).Elem().Kind()
return GoKind2DType(elementKind)
}
// single element
return GoKind2DType(dataKind)
}
// IsFloatDType returns whether dtype is floating point data type.
func IsFloatDType(dtype DType) bool {
switch dtype {
case Double, Float, Half, BFloat16:
return true
default:
return false
}
}
// Default DType:
// ==============
var DefaultDType DType = Float
// SetDefaultDType set DefaultDType to new value and return the previous one.
func SetDefaultDType(dtype DType) DType {
odtype := DefaultDType
DefaultDType = dtype
if Debug {
log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change back to previous default to avoid unexpected outcome.\n", dtype)
}
return odtype
}
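Taken together, the new `DType` replaces the old reflection-based helpers (`ToDType`, `ToGoType`, `DType2CInt`, `CInt2DType`, `DTypeSize`) with a plain enum plus methods and the option-style `GoKind2DType`. A short sketch of the new surface, using only identifiers defined above:

```go
package main

import (
	"fmt"
	"reflect"

	"github.com/sugarme/gotch"
)

func main() {
	// DType is now an int enum mirroring c10/core/ScalarType.h.
	fmt.Println(gotch.Float, gotch.Float.CKind(), gotch.Float.Size()) // Float 6 4

	// uint16 is ambiguous (Half or BFloat16); HalfDTypePref resolves it.
	dt, err := gotch.GoKind2DType(reflect.Uint16, gotch.HalfDTypePref(gotch.BFloat16))
	if err != nil {
		panic(err)
	}
	fmt.Println(dt) // BFloat16

	// Quantized variants are selected with WithQuantized.
	qdt, _ := gotch.GoKind2DType(reflect.Int8, gotch.WithQuantized(true))
	fmt.Println(qdt) // QInt8
}
```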


@ -32,27 +32,18 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De
state := lstm.Step(input, inState)
// Update with current state
inState = state
forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueezeDim(0, true).MustSoftmax(-1, gotch.Float, true)
sampledY := forwardTs.MustMultinomial(1, false, true)
lastLabel = sampledY.Int64Values()[0]
sampledY.MustDrop()
char := data.LabelForChar(lastLabel)
runes = append(runes, char)
}
ts.CleanUp(100)
}
return string(runes)
}
@ -93,42 +84,31 @@ func main() {
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
batchNarrow.MustDrop()
ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)
lstmOut, _ := lstm.Seq(xsOnehot)
logits := linear.Forward(lstmOut)
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
loss := lossView.CrossEntropyForLogits(ys)
opt.BackwardStepClip(loss, 0.5)
sumLoss += loss.Float64Values()[0]
cntLoss += 1.0
batchCount++
if batchCount%500 == 0 {
fmt.Printf("\nEpoch %v - Batch %v \n", epoch, batchCount)
}
// fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
fmt.Print(".")
ts.CleanUp(100)
} // infinite for-loop
sampleStr := sample(data, lstm, linear, device)
fmt.Printf("\nEpoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
fmt.Println(sampleStr)
}
}
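The edit above is the pattern this release applies across the examples: instead of freeing every intermediate tensor with `MustDrop()`, the loop calls `ts.CleanUp(100)` once per iteration to sweep C-allocated tensors that Go no longer references. A minimal sketch of the loop shape (hypothetical tensors, not the char-rnn model; `device` as in the example):

```go
for step := 0; step < 100; step++ {
	// Intermediate tensors are created freely...
	x := ts.MustRandn([]int64{64, 128}, gotch.Float, device)
	y := x.MustRelu(false)
	_ = y.MustSum(gotch.Float, false)

	// ...and swept in bulk at the end of the iteration instead of
	// calling MustDrop() on x, y, and the sum individually.
	ts.CleanUp(100)
}
```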

79 example/mem/main.go Normal file

@ -0,0 +1,79 @@
package main
import (
"fmt"
"math/rand"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/ts"
)
const (
ImageDimNN int64 = 784
HiddenNodesNN int64 = 128
LabelNN int64 = 10
BatchSize int64 = 3000
epochsNN = 200
LrNN = 1e-3
)
type model struct {
fc *nn.Linear
act nn.Func
}
func newModel(vs *nn.VarStore) *model {
fc := nn.NewLinear(vs.Root(), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
act := nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustRelu(false)
})
return &model{
fc: fc,
act: act,
}
}
func (m *model) Forward(x *ts.Tensor) *ts.Tensor {
fc := m.fc.Forward(x)
act := m.act.Forward(fc)
return act
}
func newData() []float32 {
n := int(BatchSize * ImageDimNN)
data := make([]float32, n)
for i := 0; i < n; i++ {
data[i] = rand.Float32()
}
return data
}
func main() {
epochs := 4000
// device := gotch.CPU
device := gotch.CudaIfAvailable()
vs := nn.NewVarStore(device)
m := newModel(vs)
for i := 0; i < epochs; i++ {
// input := ts.MustOfSlice(newData()).MustView([]int64{BatchSize, ImageDimNN}, true).MustTo(device, true)
input := ts.MustRandn([]int64{BatchSize, ImageDimNN}, gotch.Float, device)
ts.NoGrad(func() {
_ = m.Forward(input)
})
if i%10 == 0 {
fmt.Printf("=================== Epoch %03d completed========================\n", i)
}
}
ts.CleanUp()
}

BIN example/mem/mem.prof Normal file

Binary file not shown.
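The committed `mem.prof` is a heap profile of this example. The commit doesn't show how it was generated; a plausible sketch is the standard `runtime/pprof` snapshot at the end of `main` (hypothetical helper, not part of the file above):

```go
package main

import (
	"log"
	"os"
	"runtime"
	"runtime/pprof"
)

// writeHeapProfile snapshots the heap so allocation behavior under
// ts.CleanUp() can be inspected with `go tool pprof mem.prof`.
func writeHeapProfile(path string) {
	f, err := os.Create(path)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	runtime.GC() // flush recently freed objects so stats are current
	if err := pprof.WriteHeapProfile(f); err != nil {
		log.Fatal(err)
	}
}

func main() {
	writeHeapProfile("mem.prof")
}
```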


@ -0,0 +1,80 @@
# CNN MNIST training Float vs BFloat16
## BFloat16 - 16bit floating point
```bash
testImages: [10000 784]
testLabels: [10000]
Start eval...Epoch: 0 Loss: 0.05 Test accuracy: 98.05%
Start eval...Epoch: 1 Loss: 0.03 Test accuracy: 98.36%
Start eval...Epoch: 2 Loss: 0.03 Test accuracy: 98.44%
Start eval...Epoch: 3 Loss: 0.18 Test accuracy: 98.44%
Start eval...Epoch: 4 Loss: 0.01 Test accuracy: 98.52%
Start eval...Epoch: 5 Loss: 0.06 Test accuracy: 98.52%
Start eval...Epoch: 6 Loss: 0.21 Test accuracy: 98.52%
Start eval...Epoch: 7 Loss: 0.05 Test accuracy: 98.59%
Start eval...Epoch: 8 Loss: 0.12 Test accuracy: 98.52%
Start eval...Epoch: 9 Loss: 0.12 Test accuracy: 98.48%
Start eval...Epoch: 10 Loss: 0.04 Test accuracy: 98.52%
Start eval...Epoch: 11 Loss: 0.03 Test accuracy: 98.52%
Start eval...Epoch: 12 Loss: 0.04 Test accuracy: 98.48%
Start eval...Epoch: 13 Loss: 0.32 Test accuracy: 98.48%
Start eval...Epoch: 14 Loss: 0.06 Test accuracy: 98.52%
Start eval...Epoch: 15 Loss: 0.10 Test accuracy: 98.55%
Start eval...Epoch: 16 Loss: 0.02 Test accuracy: 98.52%
Start eval...Epoch: 17 Loss: 0.01 Test accuracy: 98.48%
Start eval...Epoch: 18 Loss: 0.01 Test accuracy: 98.67%
Start eval...Epoch: 19 Loss: 0.10 Test accuracy: 98.63%
Start eval...Epoch: 20 Loss: 0.05 Test accuracy: 98.71%
Start eval...Epoch: 21 Loss: 0.01 Test accuracy: 98.79%
Start eval...Epoch: 22 Loss: 0.05 Test accuracy: 98.71%
Start eval...Epoch: 23 Loss: 0.03 Test accuracy: 98.67%
Start eval...Epoch: 24 Loss: 0.03 Test accuracy: 98.67%
Start eval...Epoch: 25 Loss: 0.16 Test accuracy: 98.75%
Start eval...Epoch: 26 Loss: 0.07 Test accuracy: 98.75%
Start eval...Epoch: 27 Loss: 0.01 Test accuracy: 98.75%
Start eval...Epoch: 28 Loss: 0.15 Test accuracy: 98.63%
Start eval...Epoch: 29 Loss: 0.01 Test accuracy: 98.59%
Best test accuracy: 98.79%
Taken time: 8.67 mins
```
## Float - 32bit floating point
```bash
testImages: [10000 784]
testLabels: [10000]
Start eval...Epoch: 0 Loss: 0.27 Test accuracy: 98.42%
Start eval...Epoch: 1 Loss: 0.06 Test accuracy: 98.60%
Start eval...Epoch: 2 Loss: 0.01 Test accuracy: 98.68%
Start eval...Epoch: 3 Loss: 0.01 Test accuracy: 98.63%
Start eval...Epoch: 4 Loss: 0.11 Test accuracy: 98.82%
Start eval...Epoch: 5 Loss: 0.11 Test accuracy: 99.00%
Start eval...Epoch: 6 Loss: 0.00 Test accuracy: 98.93%
Start eval...Epoch: 7 Loss: 0.00 Test accuracy: 98.96%
Start eval...Epoch: 8 Loss: 0.01 Test accuracy: 99.02%
Start eval...Epoch: 9 Loss: 0.04 Test accuracy: 99.04%
Start eval...Epoch: 10 Loss: 0.06 Test accuracy: 99.07%
Start eval...Epoch: 11 Loss: 0.01 Test accuracy: 99.12%
Start eval...Epoch: 12 Loss: 0.00 Test accuracy: 99.12%
Start eval...Epoch: 13 Loss: 0.00 Test accuracy: 99.12%
Start eval...Epoch: 14 Loss: 0.04 Test accuracy: 99.14%
Start eval...Epoch: 15 Loss: 0.07 Test accuracy: 99.12%
Start eval...Epoch: 16 Loss: 0.00 Test accuracy: 99.08%
Start eval...Epoch: 17 Loss: 0.00 Test accuracy: 99.10%
Start eval...Epoch: 18 Loss: 0.08 Test accuracy: 99.16%
Start eval...Epoch: 19 Loss: 0.07 Test accuracy: 99.20%
Start eval...Epoch: 20 Loss: 0.00 Test accuracy: 99.06%
Start eval...Epoch: 21 Loss: 0.05 Test accuracy: 98.97%
Start eval...Epoch: 22 Loss: 0.01 Test accuracy: 99.13%
Start eval...Epoch: 23 Loss: 0.00 Test accuracy: 99.13%
Start eval...Epoch: 24 Loss: 0.01 Test accuracy: 99.16%
Start eval...Epoch: 25 Loss: 0.00 Test accuracy: 99.11%
Start eval...Epoch: 26 Loss: 0.09 Test accuracy: 99.13%
Start eval...Epoch: 27 Loss: 0.00 Test accuracy: 99.14%
Start eval...Epoch: 28 Loss: 0.00 Test accuracy: 99.13%
Start eval...Epoch: 29 Loss: 0.01 Test accuracy: 99.20%
Best test accuracy: 99.20%
Taken time: 3.06 mins
```

149 example/mnist-fp16/main.go Normal file

@ -0,0 +1,149 @@
package main
import (
"fmt"
"log"
"runtime"
"sync"
"time"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/ts"
"github.com/sugarme/gotch/vision"
)
func main() {
runCNN()
}
const (
MnistDir string = "/mnt/projects/numbat/data/mnist"
epochsCNN = 30
batchCNN = 256
// batchSize = 256
batchSize = 32
LrCNN = 3 * 1e-4
)
var mu sync.Mutex
// var device gotch.Device = gotch.CPU
var device gotch.Device = gotch.CudaIfAvailable()
// var dtype gotch.DType = gotch.BFloat16
// var dtype gotch.DType = gotch.Half
var dtype gotch.DType = gotch.Float
type Net struct {
conv1 *nn.Conv2D
conv2 *nn.Conv2D
fc1 *nn.Linear
fc2 *nn.Linear
}
func newNet(vs *nn.Path) *Net {
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())
fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())
return &Net{
conv1,
conv2,
fc1,
fc2}
}
func (n *Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
outC1 := outView1.Apply(n.conv1)
outMP1 := outC1.MaxPool2DDefault(2, true)
outC2 := outMP1.Apply(n.conv2)
outMP2 := outC2.MaxPool2DDefault(2, true)
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
outFC1 := outView2.Apply(n.fc1)
outRelu := outFC1.MustRelu(false)
outDropout := ts.MustDropout(outRelu, 0.5, train)
return outDropout.Apply(n.fc2)
}
func runCNN() {
ds := vision.LoadMNISTDir(MnistDir)
trainImages := ds.TrainImages.MustTo(device, false) // [60000, 784]
trainLabels := ds.TrainLabels.MustTo(device, false) // [60000]
testImages := ds.TestImages.MustTo(device, false).MustTotype(dtype, true) // [10000, 784]
testLabels := ds.TestLabels.MustTo(device, false).MustTotype(dtype, true) // [10000]
fmt.Printf("testImages: %v\n", testImages.MustSize())
fmt.Printf("testLabels: %v\n", testLabels.MustSize())
odtype := gotch.SetDefaultDType(dtype)
vs := nn.NewVarStore(device)
net := newNet(vs.Root())
gotch.SetDefaultDType(odtype)
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
// opt, err := nn.DefaultSGDConfig().Build(vs, LrCNN)
if err != nil {
log.Fatal(err)
}
var bestAccuracy float64 = 0.0
startTime := time.Now()
for epoch := 0; epoch < epochsCNN; epoch++ {
totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize)
// Shuffling
index := ts.MustRandperm(int64(totalSize), gotch.Int64, device)
imagesTs := trainImages.MustIndexSelect(0, index, false).MustTotype(dtype, true)
labelsTs := trainLabels.MustIndexSelect(0, index, false)
batches := samples / batchSize
batchIndex := 0
var epocLoss float64
for i := 0; i < batches; i++ {
start := batchIndex * batchSize
size := batchSize
if samples-start < batchSize {
break
}
batchIndex += 1
// Indexing
bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
logits := net.ForwardT(bImages, true)
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
loss := logits.CrossEntropyForLogits(bLabels)
loss = loss.MustSetRequiresGrad(true, true)
opt.BackwardStep(loss)
epocLoss = loss.Float64Values()[0]
runtime.GC()
}
ts.NoGrad(func() {
fmt.Printf("Start eval...")
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
if testAccuracy > bestAccuracy {
bestAccuracy = testAccuracy
}
})
}
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
ts.CleanUp()
}
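The dtype handling in `runCNN` hinges on one pattern: `gotch.SetDefaultDType` is swapped in around model construction so the `VarStore` parameters are created in the chosen precision, then restored. A condensed, stand-alone sketch of just that pattern (single hypothetical Linear layer instead of the CNN):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	// Build the model under BFloat16, then restore the previous default
	// so later tensor creation is unaffected.
	odtype := gotch.SetDefaultDType(gotch.BFloat16)
	vs := nn.NewVarStore(gotch.CPU)
	fc := nn.NewLinear(vs.Root(), 784, 10, nn.DefaultLinearConfig())
	gotch.SetDefaultDType(odtype)

	_ = fc                          // parameters were created with the BFloat16 default
	fmt.Println(gotch.DefaultDType) // restored: Float
}
```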


@ -264,3 +264,754 @@ Best test accuracy: 99.11%
Taken time: 2.81 mins
```
## New Implementation with Go `Garbage Collection`
### Linear
```bash
go run -race . -model=linear -device=cuda
Epoch: 0 - Loss: 2.303 - Test accuracy: 68.08%
Epoch: 1 - Loss: 1.508 - Test accuracy: 60.77%
Epoch: 2 - Loss: 1.388 - Test accuracy: 52.54%
Epoch: 3 - Loss: 1.579 - Test accuracy: 64.46%
Epoch: 4 - Loss: 1.707 - Test accuracy: 60.47%
Epoch: 5 - Loss: 1.194 - Test accuracy: 61.55%
Epoch: 6 - Loss: 1.395 - Test accuracy: 70.66%
Epoch: 7 - Loss: 1.290 - Test accuracy: 70.43%
Epoch: 8 - Loss: 0.892 - Test accuracy: 66.76%
Epoch: 9 - Loss: 0.935 - Test accuracy: 71.84%
Epoch: 10 - Loss: 1.108 - Test accuracy: 73.77%
Epoch: 11 - Loss: 0.906 - Test accuracy: 78.11%
Epoch: 12 - Loss: 0.786 - Test accuracy: 79.07%
Epoch: 13 - Loss: 0.686 - Test accuracy: 80.85%
Epoch: 14 - Loss: 0.666 - Test accuracy: 79.68%
Epoch: 15 - Loss: 0.685 - Test accuracy: 82.69%
Epoch: 16 - Loss: 0.623 - Test accuracy: 82.13%
Epoch: 17 - Loss: 0.592 - Test accuracy: 82.36%
Epoch: 18 - Loss: 0.604 - Test accuracy: 80.09%
Epoch: 19 - Loss: 0.614 - Test accuracy: 80.69%
Epoch: 20 - Loss: 0.613 - Test accuracy: 78.71%
Epoch: 21 - Loss: 0.645 - Test accuracy: 81.83%
Epoch: 22 - Loss: 0.582 - Test accuracy: 81.11%
Epoch: 23 - Loss: 0.591 - Test accuracy: 83.56%
Epoch: 24 - Loss: 0.542 - Test accuracy: 83.58%
Epoch: 25 - Loss: 0.539 - Test accuracy: 85.27%
Epoch: 26 - Loss: 0.507 - Test accuracy: 85.70%
Epoch: 27 - Loss: 0.498 - Test accuracy: 86.79%
Epoch: 28 - Loss: 0.477 - Test accuracy: 87.10%
Epoch: 29 - Loss: 0.467 - Test accuracy: 87.83%
Epoch: 30 - Loss: 0.453 - Test accuracy: 88.09%
Epoch: 31 - Loss: 0.445 - Test accuracy: 88.44%
Epoch: 32 - Loss: 0.436 - Test accuracy: 88.52%
Epoch: 33 - Loss: 0.429 - Test accuracy: 88.78%
Epoch: 34 - Loss: 0.424 - Test accuracy: 88.78%
Epoch: 35 - Loss: 0.419 - Test accuracy: 89.20%
Epoch: 36 - Loss: 0.415 - Test accuracy: 89.10%
Epoch: 37 - Loss: 0.412 - Test accuracy: 89.49%
Epoch: 38 - Loss: 0.409 - Test accuracy: 89.37%
Epoch: 39 - Loss: 0.406 - Test accuracy: 89.62%
Epoch: 40 - Loss: 0.404 - Test accuracy: 89.58%
Epoch: 41 - Loss: 0.402 - Test accuracy: 89.71%
Epoch: 42 - Loss: 0.399 - Test accuracy: 89.68%
Epoch: 43 - Loss: 0.398 - Test accuracy: 89.85%
Epoch: 44 - Loss: 0.396 - Test accuracy: 89.77%
Epoch: 45 - Loss: 0.394 - Test accuracy: 89.87%
Epoch: 46 - Loss: 0.392 - Test accuracy: 89.89%
Epoch: 47 - Loss: 0.391 - Test accuracy: 89.95%
Epoch: 48 - Loss: 0.389 - Test accuracy: 89.96%
Epoch: 49 - Loss: 0.388 - Test accuracy: 90.05%
Epoch: 50 - Loss: 0.387 - Test accuracy: 90.04%
Epoch: 51 - Loss: 0.385 - Test accuracy: 90.13%
Epoch: 52 - Loss: 0.384 - Test accuracy: 90.11%
Epoch: 53 - Loss: 0.383 - Test accuracy: 90.24%
Epoch: 54 - Loss: 0.381 - Test accuracy: 90.19%
Epoch: 55 - Loss: 0.380 - Test accuracy: 90.27%
Epoch: 56 - Loss: 0.379 - Test accuracy: 90.22%
Epoch: 57 - Loss: 0.378 - Test accuracy: 90.30%
Epoch: 58 - Loss: 0.377 - Test accuracy: 90.28%
Epoch: 59 - Loss: 0.376 - Test accuracy: 90.33%
Epoch: 60 - Loss: 0.375 - Test accuracy: 90.33%
Epoch: 61 - Loss: 0.374 - Test accuracy: 90.37%
Epoch: 62 - Loss: 0.372 - Test accuracy: 90.38%
Epoch: 63 - Loss: 0.371 - Test accuracy: 90.41%
Epoch: 64 - Loss: 0.371 - Test accuracy: 90.42%
Epoch: 65 - Loss: 0.370 - Test accuracy: 90.43%
Epoch: 66 - Loss: 0.369 - Test accuracy: 90.45%
Epoch: 67 - Loss: 0.368 - Test accuracy: 90.51%
Epoch: 68 - Loss: 0.367 - Test accuracy: 90.52%
Epoch: 69 - Loss: 0.366 - Test accuracy: 90.53%
Epoch: 70 - Loss: 0.365 - Test accuracy: 90.53%
Epoch: 71 - Loss: 0.364 - Test accuracy: 90.55%
Epoch: 72 - Loss: 0.363 - Test accuracy: 90.57%
Epoch: 73 - Loss: 0.363 - Test accuracy: 90.57%
Epoch: 74 - Loss: 0.362 - Test accuracy: 90.58%
Epoch: 75 - Loss: 0.361 - Test accuracy: 90.59%
Epoch: 76 - Loss: 0.360 - Test accuracy: 90.59%
Epoch: 77 - Loss: 0.359 - Test accuracy: 90.62%
Epoch: 78 - Loss: 0.359 - Test accuracy: 90.67%
Epoch: 79 - Loss: 0.358 - Test accuracy: 90.67%
Epoch: 80 - Loss: 0.357 - Test accuracy: 90.70%
Epoch: 81 - Loss: 0.357 - Test accuracy: 90.71%
Epoch: 82 - Loss: 0.356 - Test accuracy: 90.72%
Epoch: 83 - Loss: 0.355 - Test accuracy: 90.77%
Epoch: 84 - Loss: 0.355 - Test accuracy: 90.77%
Epoch: 85 - Loss: 0.354 - Test accuracy: 90.78%
Epoch: 86 - Loss: 0.353 - Test accuracy: 90.80%
Epoch: 87 - Loss: 0.353 - Test accuracy: 90.82%
Epoch: 88 - Loss: 0.352 - Test accuracy: 90.83%
Epoch: 89 - Loss: 0.351 - Test accuracy: 90.82%
Epoch: 90 - Loss: 0.351 - Test accuracy: 90.84%
Epoch: 91 - Loss: 0.350 - Test accuracy: 90.87%
Epoch: 92 - Loss: 0.350 - Test accuracy: 90.88%
Epoch: 93 - Loss: 0.349 - Test accuracy: 90.87%
Epoch: 94 - Loss: 0.348 - Test accuracy: 90.89%
Epoch: 95 - Loss: 0.348 - Test accuracy: 90.89%
Epoch: 96 - Loss: 0.347 - Test accuracy: 90.91%
Epoch: 97 - Loss: 0.347 - Test accuracy: 90.94%
Epoch: 98 - Loss: 0.346 - Test accuracy: 90.94%
Epoch: 99 - Loss: 0.346 - Test accuracy: 90.96%
Epoch: 100 - Loss: 0.345 - Test accuracy: 90.96%
Epoch: 101 - Loss: 0.345 - Test accuracy: 90.96%
Epoch: 102 - Loss: 0.344 - Test accuracy: 91.02%
Epoch: 103 - Loss: 0.344 - Test accuracy: 91.02%
Epoch: 104 - Loss: 0.343 - Test accuracy: 91.04%
Epoch: 105 - Loss: 0.343 - Test accuracy: 91.05%
Epoch: 106 - Loss: 0.342 - Test accuracy: 91.05%
Epoch: 107 - Loss: 0.342 - Test accuracy: 91.06%
Epoch: 108 - Loss: 0.341 - Test accuracy: 91.07%
Epoch: 109 - Loss: 0.341 - Test accuracy: 91.08%
Epoch: 110 - Loss: 0.340 - Test accuracy: 91.09%
Epoch: 111 - Loss: 0.340 - Test accuracy: 91.11%
Epoch: 112 - Loss: 0.339 - Test accuracy: 91.14%
Epoch: 113 - Loss: 0.339 - Test accuracy: 91.15%
Epoch: 114 - Loss: 0.339 - Test accuracy: 91.14%
Epoch: 115 - Loss: 0.338 - Test accuracy: 91.15%
Epoch: 116 - Loss: 0.338 - Test accuracy: 91.17%
Epoch: 117 - Loss: 0.337 - Test accuracy: 91.18%
Epoch: 118 - Loss: 0.337 - Test accuracy: 91.19%
Epoch: 119 - Loss: 0.336 - Test accuracy: 91.20%
Epoch: 120 - Loss: 0.336 - Test accuracy: 91.22%
Epoch: 121 - Loss: 0.336 - Test accuracy: 91.23%
Epoch: 122 - Loss: 0.335 - Test accuracy: 91.28%
Epoch: 123 - Loss: 0.335 - Test accuracy: 91.30%
Epoch: 124 - Loss: 0.334 - Test accuracy: 91.29%
Epoch: 125 - Loss: 0.334 - Test accuracy: 91.32%
Epoch: 126 - Loss: 0.334 - Test accuracy: 91.35%
Epoch: 127 - Loss: 0.333 - Test accuracy: 91.37%
Epoch: 128 - Loss: 0.333 - Test accuracy: 91.36%
Epoch: 129 - Loss: 0.333 - Test accuracy: 91.36%
Epoch: 130 - Loss: 0.332 - Test accuracy: 91.36%
Epoch: 131 - Loss: 0.332 - Test accuracy: 91.35%
Epoch: 132 - Loss: 0.332 - Test accuracy: 91.36%
Epoch: 133 - Loss: 0.331 - Test accuracy: 91.37%
Epoch: 134 - Loss: 0.331 - Test accuracy: 91.38%
Epoch: 135 - Loss: 0.330 - Test accuracy: 91.39%
Epoch: 136 - Loss: 0.330 - Test accuracy: 91.40%
Epoch: 137 - Loss: 0.330 - Test accuracy: 91.43%
Epoch: 138 - Loss: 0.329 - Test accuracy: 91.44%
Epoch: 139 - Loss: 0.329 - Test accuracy: 91.45%
Epoch: 140 - Loss: 0.329 - Test accuracy: 91.46%
Epoch: 141 - Loss: 0.328 - Test accuracy: 91.47%
Epoch: 142 - Loss: 0.328 - Test accuracy: 91.46%
Epoch: 143 - Loss: 0.328 - Test accuracy: 91.48%
Epoch: 144 - Loss: 0.328 - Test accuracy: 91.46%
Epoch: 145 - Loss: 0.327 - Test accuracy: 91.46%
Epoch: 146 - Loss: 0.327 - Test accuracy: 91.46%
Epoch: 147 - Loss: 0.327 - Test accuracy: 91.47%
Epoch: 148 - Loss: 0.326 - Test accuracy: 91.47%
Epoch: 149 - Loss: 0.326 - Test accuracy: 91.48%
Epoch: 150 - Loss: 0.326 - Test accuracy: 91.48%
Epoch: 151 - Loss: 0.325 - Test accuracy: 91.50%
Epoch: 152 - Loss: 0.325 - Test accuracy: 91.50%
Epoch: 153 - Loss: 0.325 - Test accuracy: 91.52%
Epoch: 154 - Loss: 0.325 - Test accuracy: 91.52%
Epoch: 155 - Loss: 0.324 - Test accuracy: 91.52%
Epoch: 156 - Loss: 0.324 - Test accuracy: 91.51%
Epoch: 157 - Loss: 0.324 - Test accuracy: 91.51%
Epoch: 158 - Loss: 0.323 - Test accuracy: 91.52%
Epoch: 159 - Loss: 0.323 - Test accuracy: 91.51%
Epoch: 160 - Loss: 0.323 - Test accuracy: 91.51%
Epoch: 161 - Loss: 0.323 - Test accuracy: 91.50%
Epoch: 162 - Loss: 0.322 - Test accuracy: 91.51%
Epoch: 163 - Loss: 0.322 - Test accuracy: 91.53%
Epoch: 164 - Loss: 0.322 - Test accuracy: 91.54%
Epoch: 165 - Loss: 0.322 - Test accuracy: 91.54%
Epoch: 166 - Loss: 0.321 - Test accuracy: 91.54%
Epoch: 167 - Loss: 0.321 - Test accuracy: 91.56%
Epoch: 168 - Loss: 0.321 - Test accuracy: 91.56%
Epoch: 169 - Loss: 0.321 - Test accuracy: 91.56%
Epoch: 170 - Loss: 0.320 - Test accuracy: 91.57%
Epoch: 171 - Loss: 0.320 - Test accuracy: 91.59%
Epoch: 172 - Loss: 0.320 - Test accuracy: 91.59%
Epoch: 173 - Loss: 0.320 - Test accuracy: 91.60%
Epoch: 174 - Loss: 0.319 - Test accuracy: 91.60%
Epoch: 175 - Loss: 0.319 - Test accuracy: 91.61%
Epoch: 176 - Loss: 0.319 - Test accuracy: 91.61%
Epoch: 177 - Loss: 0.319 - Test accuracy: 91.61%
Epoch: 178 - Loss: 0.318 - Test accuracy: 91.60%
Epoch: 179 - Loss: 0.318 - Test accuracy: 91.60%
Epoch: 180 - Loss: 0.318 - Test accuracy: 91.60%
Epoch: 181 - Loss: 0.318 - Test accuracy: 91.60%
Epoch: 182 - Loss: 0.318 - Test accuracy: 91.60%
Epoch: 183 - Loss: 0.317 - Test accuracy: 91.60%
Epoch: 184 - Loss: 0.317 - Test accuracy: 91.60%
Epoch: 185 - Loss: 0.317 - Test accuracy: 91.62%
Epoch: 186 - Loss: 0.317 - Test accuracy: 91.63%
Epoch: 187 - Loss: 0.316 - Test accuracy: 91.66%
Epoch: 188 - Loss: 0.316 - Test accuracy: 91.66%
Epoch: 189 - Loss: 0.316 - Test accuracy: 91.65%
Epoch: 190 - Loss: 0.316 - Test accuracy: 91.65%
Epoch: 191 - Loss: 0.316 - Test accuracy: 91.65%
Epoch: 192 - Loss: 0.315 - Test accuracy: 91.65%
Epoch: 193 - Loss: 0.315 - Test accuracy: 91.67%
Epoch: 194 - Loss: 0.315 - Test accuracy: 91.66%
Epoch: 195 - Loss: 0.315 - Test accuracy: 91.66%
Epoch: 196 - Loss: 0.315 - Test accuracy: 91.67%
Epoch: 197 - Loss: 0.314 - Test accuracy: 91.68%
Epoch: 198 - Loss: 0.314 - Test accuracy: 91.68%
Epoch: 199 - Loss: 0.314 - Test accuracy: 91.68%
```
### NN
```bash
go run -race . -model=nn -device=cpu
Epoch: 0 Loss: 2.313 Test accuracy: 23.07%
Epoch: 1 Loss: 2.247 Test accuracy: 32.13%
Epoch: 2 Loss: 2.182 Test accuracy: 44.69%
Epoch: 3 Loss: 2.116 Test accuracy: 57.08%
Epoch: 4 Loss: 2.047 Test accuracy: 64.63%
Epoch: 5 Loss: 1.976 Test accuracy: 68.40%
Epoch: 6 Loss: 1.903 Test accuracy: 70.94%
Epoch: 7 Loss: 1.831 Test accuracy: 72.40%
Epoch: 8 Loss: 1.758 Test accuracy: 74.03%
Epoch: 9 Loss: 1.686 Test accuracy: 75.31%
Epoch: 10 Loss: 1.614 Test accuracy: 76.54%
Epoch: 11 Loss: 1.544 Test accuracy: 77.83%
Epoch: 12 Loss: 1.475 Test accuracy: 78.68%
Epoch: 13 Loss: 1.408 Test accuracy: 79.54%
Epoch: 14 Loss: 1.343 Test accuracy: 80.20%
Epoch: 15 Loss: 1.281 Test accuracy: 80.76%
Epoch: 16 Loss: 1.221 Test accuracy: 81.47%
Epoch: 17 Loss: 1.165 Test accuracy: 81.90%
Epoch: 18 Loss: 1.110 Test accuracy: 82.19%
Epoch: 19 Loss: 1.059 Test accuracy: 82.67%
Epoch: 20 Loss: 1.010 Test accuracy: 83.07%
Epoch: 21 Loss: 0.965 Test accuracy: 83.39%
Epoch: 22 Loss: 0.922 Test accuracy: 83.74%
Epoch: 23 Loss: 0.882 Test accuracy: 84.00%
Epoch: 24 Loss: 0.845 Test accuracy: 84.25%
Epoch: 25 Loss: 0.810 Test accuracy: 84.41%
Epoch: 26 Loss: 0.778 Test accuracy: 84.77%
Epoch: 27 Loss: 0.748 Test accuracy: 85.00%
Epoch: 28 Loss: 0.721 Test accuracy: 85.28%
Epoch: 29 Loss: 0.696 Test accuracy: 85.60%
Epoch: 30 Loss: 0.672 Test accuracy: 85.85%
Epoch: 31 Loss: 0.651 Test accuracy: 86.05%
Epoch: 32 Loss: 0.631 Test accuracy: 86.31%
Epoch: 33 Loss: 0.612 Test accuracy: 86.48%
Epoch: 34 Loss: 0.595 Test accuracy: 86.74%
Epoch: 35 Loss: 0.579 Test accuracy: 86.89%
Epoch: 36 Loss: 0.564 Test accuracy: 87.17%
Epoch: 37 Loss: 0.551 Test accuracy: 87.23%
Epoch: 38 Loss: 0.538 Test accuracy: 87.34%
Epoch: 39 Loss: 0.526 Test accuracy: 87.55%
Epoch: 40 Loss: 0.515 Test accuracy: 87.74%
Epoch: 41 Loss: 0.504 Test accuracy: 88.01%
Epoch: 42 Loss: 0.495 Test accuracy: 88.23%
Epoch: 43 Loss: 0.485 Test accuracy: 88.37%
Epoch: 44 Loss: 0.477 Test accuracy: 88.55%
Epoch: 45 Loss: 0.469 Test accuracy: 88.71%
Epoch: 46 Loss: 0.461 Test accuracy: 88.89%
Epoch: 47 Loss: 0.454 Test accuracy: 88.98%
Epoch: 48 Loss: 0.447 Test accuracy: 89.06%
Epoch: 49 Loss: 0.440 Test accuracy: 89.17%
Epoch: 50 Loss: 0.434 Test accuracy: 89.30%
Epoch: 51 Loss: 0.428 Test accuracy: 89.39%
Epoch: 52 Loss: 0.422 Test accuracy: 89.51%
Epoch: 53 Loss: 0.417 Test accuracy: 89.57%
Epoch: 54 Loss: 0.412 Test accuracy: 89.69%
Epoch: 55 Loss: 0.407 Test accuracy: 89.85%
Epoch: 56 Loss: 0.403 Test accuracy: 89.89%
Epoch: 57 Loss: 0.398 Test accuracy: 89.94%
Epoch: 58 Loss: 0.394 Test accuracy: 90.02%
Epoch: 59 Loss: 0.390 Test accuracy: 90.13%
Epoch: 60 Loss: 0.386 Test accuracy: 90.25%
Epoch: 61 Loss: 0.382 Test accuracy: 90.30%
Epoch: 62 Loss: 0.379 Test accuracy: 90.36%
Epoch: 63 Loss: 0.375 Test accuracy: 90.43%
Epoch: 64 Loss: 0.372 Test accuracy: 90.47%
Epoch: 65 Loss: 0.368 Test accuracy: 90.52%
Epoch: 66 Loss: 0.365 Test accuracy: 90.59%
Epoch: 67 Loss: 0.362 Test accuracy: 90.63%
Epoch: 68 Loss: 0.360 Test accuracy: 90.66%
Epoch: 69 Loss: 0.357 Test accuracy: 90.72%
Epoch: 70 Loss: 0.354 Test accuracy: 90.76%
Epoch: 71 Loss: 0.351 Test accuracy: 90.80%
Epoch: 72 Loss: 0.349 Test accuracy: 90.83%
Epoch: 73 Loss: 0.346 Test accuracy: 90.89%
Epoch: 74 Loss: 0.344 Test accuracy: 90.93%
Epoch: 75 Loss: 0.342 Test accuracy: 91.04%
Epoch: 76 Loss: 0.340 Test accuracy: 91.11%
Epoch: 77 Loss: 0.337 Test accuracy: 91.16%
Epoch: 78 Loss: 0.335 Test accuracy: 91.20%
Epoch: 79 Loss: 0.333 Test accuracy: 91.25%
Epoch: 80 Loss: 0.331 Test accuracy: 91.29%
Epoch: 81 Loss: 0.329 Test accuracy: 91.31%
Epoch: 82 Loss: 0.327 Test accuracy: 91.34%
Epoch: 83 Loss: 0.325 Test accuracy: 91.36%
Epoch: 84 Loss: 0.324 Test accuracy: 91.42%
Epoch: 85 Loss: 0.322 Test accuracy: 91.46%
Epoch: 86 Loss: 0.320 Test accuracy: 91.49%
Epoch: 87 Loss: 0.318 Test accuracy: 91.52%
Epoch: 88 Loss: 0.317 Test accuracy: 91.53%
Epoch: 89 Loss: 0.315 Test accuracy: 91.55%
Epoch: 90 Loss: 0.313 Test accuracy: 91.62%
Epoch: 91 Loss: 0.312 Test accuracy: 91.68%
Epoch: 92 Loss: 0.310 Test accuracy: 91.72%
Epoch: 93 Loss: 0.309 Test accuracy: 91.77%
Epoch: 94 Loss: 0.307 Test accuracy: 91.82%
Epoch: 95 Loss: 0.306 Test accuracy: 91.87%
Epoch: 96 Loss: 0.304 Test accuracy: 91.89%
Epoch: 97 Loss: 0.303 Test accuracy: 91.90%
Epoch: 98 Loss: 0.302 Test accuracy: 91.92%
Epoch: 99 Loss: 0.300 Test accuracy: 91.95%
Epoch: 100 Loss: 0.299 Test accuracy: 91.99%
Epoch: 101 Loss: 0.298 Test accuracy: 92.04%
Epoch: 102 Loss: 0.296 Test accuracy: 92.07%
Epoch: 103 Loss: 0.295 Test accuracy: 92.11%
Epoch: 104 Loss: 0.294 Test accuracy: 92.13%
Epoch: 105 Loss: 0.292 Test accuracy: 92.16%
Epoch: 106 Loss: 0.291 Test accuracy: 92.18%
Epoch: 107 Loss: 0.290 Test accuracy: 92.20%
Epoch: 108 Loss: 0.289 Test accuracy: 92.20%
Epoch: 109 Loss: 0.287 Test accuracy: 92.24%
Epoch: 110 Loss: 0.286 Test accuracy: 92.25%
Epoch: 111 Loss: 0.285 Test accuracy: 92.26%
Epoch: 112 Loss: 0.284 Test accuracy: 92.26%
Epoch: 113 Loss: 0.283 Test accuracy: 92.30%
Epoch: 114 Loss: 0.282 Test accuracy: 92.32%
Epoch: 115 Loss: 0.281 Test accuracy: 92.34%
Epoch: 116 Loss: 0.279 Test accuracy: 92.40%
Epoch: 117 Loss: 0.278 Test accuracy: 92.41%
Epoch: 118 Loss: 0.277 Test accuracy: 92.43%
Epoch: 119 Loss: 0.276 Test accuracy: 92.45%
Epoch: 120 Loss: 0.275 Test accuracy: 92.47%
Epoch: 121 Loss: 0.274 Test accuracy: 92.52%
Epoch: 122 Loss: 0.273 Test accuracy: 92.55%
Epoch: 123 Loss: 0.272 Test accuracy: 92.55%
Epoch: 124 Loss: 0.271 Test accuracy: 92.57%
Epoch: 125 Loss: 0.270 Test accuracy: 92.61%
Epoch: 126 Loss: 0.269 Test accuracy: 92.61%
Epoch: 127 Loss: 0.268 Test accuracy: 92.61%
Epoch: 128 Loss: 0.267 Test accuracy: 92.63%
Epoch: 129 Loss: 0.266 Test accuracy: 92.65%
Epoch: 130 Loss: 0.265 Test accuracy: 92.66%
Epoch: 131 Loss: 0.264 Test accuracy: 92.72%
Epoch: 132 Loss: 0.263 Test accuracy: 92.78%
Epoch: 133 Loss: 0.262 Test accuracy: 92.77%
Epoch: 134 Loss: 0.261 Test accuracy: 92.77%
Epoch: 135 Loss: 0.260 Test accuracy: 92.82%
Epoch: 136 Loss: 0.259 Test accuracy: 92.81%
Epoch: 137 Loss: 0.258 Test accuracy: 92.84%
Epoch: 138 Loss: 0.257 Test accuracy: 92.86%
Epoch: 139 Loss: 0.256 Test accuracy: 92.88%
Epoch: 140 Loss: 0.255 Test accuracy: 92.89%
Epoch: 141 Loss: 0.254 Test accuracy: 92.90%
Epoch: 142 Loss: 0.253 Test accuracy: 92.90%
Epoch: 143 Loss: 0.253 Test accuracy: 92.94%
Epoch: 144 Loss: 0.252 Test accuracy: 92.96%
Epoch: 145 Loss: 0.251 Test accuracy: 92.99%
Epoch: 146 Loss: 0.250 Test accuracy: 93.01%
Epoch: 147 Loss: 0.249 Test accuracy: 93.03%
Epoch: 148 Loss: 0.248 Test accuracy: 93.08%
Epoch: 149 Loss: 0.247 Test accuracy: 93.11%
Epoch: 150 Loss: 0.246 Test accuracy: 93.12%
Epoch: 151 Loss: 0.245 Test accuracy: 93.14%
Epoch: 152 Loss: 0.245 Test accuracy: 93.15%
Epoch: 153 Loss: 0.244 Test accuracy: 93.16%
Epoch: 154 Loss: 0.243 Test accuracy: 93.16%
Epoch: 155 Loss: 0.242 Test accuracy: 93.16%
Epoch: 156 Loss: 0.241 Test accuracy: 93.18%
Epoch: 157 Loss: 0.240 Test accuracy: 93.18%
Epoch: 158 Loss: 0.240 Test accuracy: 93.22%
Epoch: 159 Loss: 0.239 Test accuracy: 93.26%
Epoch: 160 Loss: 0.238 Test accuracy: 93.28%
Epoch: 161 Loss: 0.237 Test accuracy: 93.30%
Epoch: 162 Loss: 0.236 Test accuracy: 93.30%
Epoch: 163 Loss: 0.235 Test accuracy: 93.33%
Epoch: 164 Loss: 0.235 Test accuracy: 93.34%
Epoch: 165 Loss: 0.234 Test accuracy: 93.40%
Epoch: 166 Loss: 0.233 Test accuracy: 93.42%
Epoch: 167 Loss: 0.232 Test accuracy: 93.43%
Epoch: 168 Loss: 0.231 Test accuracy: 93.43%
Epoch: 169 Loss: 0.231 Test accuracy: 93.45%
Epoch: 170 Loss: 0.230 Test accuracy: 93.46%
Epoch: 171 Loss: 0.229 Test accuracy: 93.45%
Epoch: 172 Loss: 0.228 Test accuracy: 93.46%
Epoch: 173 Loss: 0.227 Test accuracy: 93.48%
Epoch: 174 Loss: 0.227 Test accuracy: 93.49%
Epoch: 175 Loss: 0.226 Test accuracy: 93.54%
Epoch: 176 Loss: 0.225 Test accuracy: 93.56%
Epoch: 177 Loss: 0.224 Test accuracy: 93.57%
Epoch: 178 Loss: 0.224 Test accuracy: 93.57%
Epoch: 179 Loss: 0.223 Test accuracy: 93.59%
Epoch: 180 Loss: 0.222 Test accuracy: 93.60%
Epoch: 181 Loss: 0.221 Test accuracy: 93.60%
Epoch: 182 Loss: 0.221 Test accuracy: 93.62%
Epoch: 183 Loss: 0.220 Test accuracy: 93.64%
Epoch: 184 Loss: 0.219 Test accuracy: 93.63%
Epoch: 185 Loss: 0.218 Test accuracy: 93.64%
Epoch: 186 Loss: 0.218 Test accuracy: 93.67%
Epoch: 187 Loss: 0.217 Test accuracy: 93.68%
Epoch: 188 Loss: 0.216 Test accuracy: 93.71%
Epoch: 189 Loss: 0.215 Test accuracy: 93.73%
Epoch: 190 Loss: 0.215 Test accuracy: 93.75%
Epoch: 191 Loss: 0.214 Test accuracy: 93.77%
Epoch: 192 Loss: 0.213 Test accuracy: 93.78%
Epoch: 193 Loss: 0.213 Test accuracy: 93.79%
Epoch: 194 Loss: 0.212 Test accuracy: 93.83%
Epoch: 195 Loss: 0.211 Test accuracy: 93.85%
Epoch: 196 Loss: 0.211 Test accuracy: 93.89%
Epoch: 197 Loss: 0.210 Test accuracy: 93.96%
Epoch: 198 Loss: 0.209 Test accuracy: 93.99%
Epoch: 199 Loss: 0.209 Test accuracy: 94.01%
```
### CNN
**BatchSize = 256 on CPU**
```
go run . -model=cnn -device=cpu
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.15 Test accuracy: 96.69%
Epoch: 1 Loss: 0.20 Test accuracy: 94.54%
Epoch: 2 Loss: 0.17 Test accuracy: 95.38%
Epoch: 3 Loss: 0.15 Test accuracy: 95.46%
Epoch: 4 Loss: 0.12 Test accuracy: 97.09%
Epoch: 5 Loss: 0.17 Test accuracy: 97.17%
Epoch: 6 Loss: 0.09 Test accuracy: 97.17%
Epoch: 7 Loss: 0.06 Test accuracy: 97.17%
Epoch: 8 Loss: 0.10 Test accuracy: 97.16%
Epoch: 9 Loss: 0.11 Test accuracy: 97.16%
Epoch: 10 Loss: 0.14 Test accuracy: 97.16%
Epoch: 11 Loss: 0.11 Test accuracy: 97.16%
Epoch: 12 Loss: 0.08 Test accuracy: 97.16%
Epoch: 13 Loss: 0.10 Test accuracy: 97.16%
Epoch: 14 Loss: 0.13 Test accuracy: 97.16%
Epoch: 15 Loss: 0.08 Test accuracy: 97.16%
Epoch: 16 Loss: 0.10 Test accuracy: 97.16%
Epoch: 17 Loss: 0.12 Test accuracy: 97.16%
Epoch: 18 Loss: 0.13 Test accuracy: 97.16%
Epoch: 19 Loss: 0.08 Test accuracy: 97.16%
Epoch: 20 Loss: 0.09 Test accuracy: 97.16%
Epoch: 21 Loss: 0.05 Test accuracy: 97.16%
Epoch: 22 Loss: 0.10 Test accuracy: 97.16%
Epoch: 23 Loss: 0.11 Test accuracy: 97.16%
Epoch: 24 Loss: 0.14 Test accuracy: 97.16%
Epoch: 25 Loss: 0.15 Test accuracy: 97.16%
Epoch: 26 Loss: 0.07 Test accuracy: 97.16%
Epoch: 27 Loss: 0.08 Test accuracy: 97.16%
Epoch: 28 Loss: 0.13 Test accuracy: 97.16%
Epoch: 29 Loss: 0.10 Test accuracy: 97.16%
Best test accuracy: 97.17%
Taken time: 9.04 mins
go run . -model=cnn -device=cpu
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.13 Test accuracy: 96.77%
Epoch: 1 Loss: 0.18 Test accuracy: 95.43%
Epoch: 2 Loss: 0.18 Test accuracy: 95.53%
Epoch: 3 Loss: 0.18 Test accuracy: 96.08%
Epoch: 4 Loss: 0.14 Test accuracy: 96.37%
Epoch: 5 Loss: 0.14 Test accuracy: 96.40%
Epoch: 6 Loss: 0.11 Test accuracy: 96.44%
Epoch: 7 Loss: 0.08 Test accuracy: 96.96%
Epoch: 8 Loss: 0.16 Test accuracy: 97.09%
Epoch: 9 Loss: 0.11 Test accuracy: 97.05%
Epoch: 10 Loss: 0.11 Test accuracy: 97.04%
Epoch: 11 Loss: 0.11 Test accuracy: 97.10%
Epoch: 12 Loss: 0.12 Test accuracy: 97.13%
Epoch: 13 Loss: 0.09 Test accuracy: 97.13%
Epoch: 14 Loss: 0.11 Test accuracy: 97.13%
Epoch: 15 Loss: 0.16 Test accuracy: 97.13%
Epoch: 16 Loss: 0.14 Test accuracy: 97.13%
Epoch: 17 Loss: 0.11 Test accuracy: 97.13%
Epoch: 18 Loss: 0.14 Test accuracy: 97.13%
Epoch: 19 Loss: 0.17 Test accuracy: 97.13%
Epoch: 20 Loss: 0.16 Test accuracy: 97.13%
Epoch: 21 Loss: 0.07 Test accuracy: 97.13%
Epoch: 22 Loss: 0.15 Test accuracy: 97.13%
Epoch: 23 Loss: 0.14 Test accuracy: 97.13%
Epoch: 24 Loss: 0.07 Test accuracy: 97.13%
Epoch: 25 Loss: 0.13 Test accuracy: 97.13%
Epoch: 26 Loss: 0.11 Test accuracy: 97.13%
Epoch: 27 Loss: 0.14 Test accuracy: 97.13%
Epoch: 28 Loss: 0.14 Test accuracy: 97.13%
Epoch: 29 Loss: 0.08 Test accuracy: 97.13%
Best test accuracy: 97.13%
Taken time: 9.37 mins
go run . -model=cnn -device=cpu
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.13 Test accuracy: 97.03%
Epoch: 1 Loss: 0.14 Test accuracy: 97.43%
Epoch: 2 Loss: 0.10 Test accuracy: 97.37%
Epoch: 3 Loss: 0.13 Test accuracy: 97.35%
Epoch: 4 Loss: 0.15 Test accuracy: 97.37%
Epoch: 5 Loss: 0.06 Test accuracy: 97.68%
Epoch: 6 Loss: 0.12 Test accuracy: 97.19%
Epoch: 7 Loss: 0.08 Test accuracy: 97.68%
Epoch: 8 Loss: 0.13 Test accuracy: 97.89%
Epoch: 9 Loss: 0.10 Test accuracy: 97.32%
Epoch: 10 Loss: 0.10 Test accuracy: 98.25%
Epoch: 11 Loss: 0.07 Test accuracy: 98.26%
Epoch: 12 Loss: 0.07 Test accuracy: 98.39%
Epoch: 13 Loss: 0.09 Test accuracy: 98.43%
Epoch: 14 Loss: 0.07 Test accuracy: 98.44%
Epoch: 15 Loss: 0.09 Test accuracy: 98.49%
Epoch: 16 Loss: 0.06 Test accuracy: 98.48%
Epoch: 17 Loss: 0.05 Test accuracy: 98.48%
Epoch: 18 Loss: 0.06 Test accuracy: 98.48%
Epoch: 19 Loss: 0.04 Test accuracy: 98.48%
Epoch: 20 Loss: 0.08 Test accuracy: 98.48%
Epoch: 21 Loss: 0.11 Test accuracy: 98.48%
Epoch: 22 Loss: 0.09 Test accuracy: 98.48%
Epoch: 23 Loss: 0.06 Test accuracy: 98.48%
Epoch: 24 Loss: 0.06 Test accuracy: 98.48%
Epoch: 25 Loss: 0.05 Test accuracy: 98.48%
Epoch: 26 Loss: 0.05 Test accuracy: 98.48%
Epoch: 27 Loss: 0.05 Test accuracy: 98.48%
Epoch: 28 Loss: 0.07 Test accuracy: 98.48%
Epoch: 29 Loss: 0.10 Test accuracy: 98.48%
Best test accuracy: 98.49%
Taken time: 9.39 mins
go run . -model=cnn -device=cpu
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.14 Test accuracy: 96.88%
Epoch: 1 Loss: 0.12 Test accuracy: 97.29%
Epoch: 2 Loss: 0.13 Test accuracy: 97.25%
Epoch: 3 Loss: 0.11 Test accuracy: 97.21%
Epoch: 4 Loss: 0.12 Test accuracy: 97.22%
Epoch: 5 Loss: 0.08 Test accuracy: 97.32%
Epoch: 6 Loss: 0.11 Test accuracy: 97.31%
Epoch: 7 Loss: 0.13 Test accuracy: 97.32%
Epoch: 8 Loss: 0.10 Test accuracy: 97.44%
Epoch: 9 Loss: 0.15 Test accuracy: 97.37%
Epoch: 10 Loss: 0.09 Test accuracy: 97.46%
Epoch: 11 Loss: 0.11 Test accuracy: 97.49%
Epoch: 12 Loss: 0.11 Test accuracy: 96.21%
Epoch: 13 Loss: 0.13 Test accuracy: 95.94%
Epoch: 14 Loss: 0.20 Test accuracy: 95.97%
Epoch: 15 Loss: 0.18 Test accuracy: 97.12%
Epoch: 16 Loss: 0.04 Test accuracy: 97.50%
Epoch: 17 Loss: 0.18 Test accuracy: 97.38%
Epoch: 18 Loss: 0.06 Test accuracy: 97.60%
Epoch: 19 Loss: 0.13 Test accuracy: 97.45%
Epoch: 20 Loss: 0.06 Test accuracy: 97.57%
Epoch: 21 Loss: 0.12 Test accuracy: 97.60%
Epoch: 22 Loss: 0.10 Test accuracy: 97.60%
Epoch: 23 Loss: 0.09 Test accuracy: 97.60%
Epoch: 24 Loss: 0.11 Test accuracy: 97.60%
Epoch: 25 Loss: 0.13 Test accuracy: 97.60%
Epoch: 26 Loss: 0.08 Test accuracy: 97.60%
Epoch: 27 Loss: 0.18 Test accuracy: 97.60%
Epoch: 28 Loss: 0.09 Test accuracy: 97.60%
Epoch: 29 Loss: 0.07 Test accuracy: 97.60%
Best test accuracy: 97.60%
Taken time: 9.41 mins
```
**BatchSize = 32 on CUDA**
```bash
go run . -model=cnn -device=cuda
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.28 Test accuracy: 98.41%
Epoch: 1 Loss: 0.01 Test accuracy: 98.55%
Epoch: 2 Loss: 0.09 Test accuracy: 98.53%
Epoch: 3 Loss: 0.01 Test accuracy: 98.64%
Epoch: 4 Loss: 0.01 Test accuracy: 98.74%
Epoch: 5 Loss: 0.01 Test accuracy: 98.81%
Epoch: 6 Loss: 0.10 Test accuracy: 98.91%
Epoch: 7 Loss: 0.02 Test accuracy: 98.86%
Epoch: 8 Loss: 0.00 Test accuracy: 98.64%
Epoch: 9 Loss: 0.17 Test accuracy: 98.84%
Epoch: 10 Loss: 0.01 Test accuracy: 98.83%
Epoch: 11 Loss: 0.00 Test accuracy: 98.88%
Epoch: 12 Loss: 0.05 Test accuracy: 98.90%
Epoch: 13 Loss: 0.01 Test accuracy: 99.01%
Epoch: 14 Loss: 0.09 Test accuracy: 97.85%
Epoch: 15 Loss: 0.10 Test accuracy: 98.24%
Epoch: 16 Loss: 0.00 Test accuracy: 98.53%
Epoch: 17 Loss: 0.00 Test accuracy: 98.49%
Epoch: 18 Loss: 0.16 Test accuracy: 98.49%
Epoch: 19 Loss: 0.13 Test accuracy: 98.49%
Epoch: 20 Loss: 0.00 Test accuracy: 98.49%
Epoch: 21 Loss: 0.01 Test accuracy: 98.49%
Epoch: 22 Loss: 0.17 Test accuracy: 98.49%
Epoch: 23 Loss: 0.06 Test accuracy: 98.49%
Epoch: 24 Loss: 0.00 Test accuracy: 98.49%
Epoch: 25 Loss: 0.12 Test accuracy: 98.49%
Epoch: 26 Loss: 0.08 Test accuracy: 98.49%
Epoch: 27 Loss: 0.19 Test accuracy: 98.49%
Epoch: 28 Loss: 0.02 Test accuracy: 98.49%
Epoch: 29 Loss: 0.01 Test accuracy: 98.49%
Best test accuracy: 99.01%
Taken time: 8.89 mins
go run . -model=cnn -device=cuda
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.05 Test accuracy: 98.40%
Epoch: 1 Loss: 0.01 Test accuracy: 98.92%
Epoch: 2 Loss: 0.10 Test accuracy: 98.97%
Epoch: 3 Loss: 0.03 Test accuracy: 98.79%
Epoch: 4 Loss: 0.02 Test accuracy: 98.81%
Epoch: 5 Loss: 0.15 Test accuracy: 98.85%
Epoch: 6 Loss: 0.01 Test accuracy: 98.82%
Epoch: 7 Loss: 0.03 Test accuracy: 98.83%
Epoch: 8 Loss: 0.01 Test accuracy: 98.56%
Epoch: 9 Loss: 0.00 Test accuracy: 98.85%
Epoch: 10 Loss: 0.22 Test accuracy: 98.51%
Epoch: 11 Loss: 0.78 Test accuracy: 98.37%
Epoch: 12 Loss: 0.01 Test accuracy: 98.47%
Epoch: 13 Loss: 0.55 Test accuracy: 98.48%
Epoch: 14 Loss: 0.00 Test accuracy: 98.45%
Epoch: 15 Loss: 0.13 Test accuracy: 98.47%
Epoch: 16 Loss: 0.01 Test accuracy: 98.49%
Epoch: 17 Loss: 0.00 Test accuracy: 98.35%
Epoch: 18 Loss: 0.08 Test accuracy: 98.41%
Epoch: 19 Loss: 0.63 Test accuracy: 98.58%
Epoch: 20 Loss: 0.22 Test accuracy: 98.59%
Epoch: 21 Loss: 0.00 Test accuracy: 98.63%
Epoch: 22 Loss: 0.80 Test accuracy: 98.63%
Epoch: 23 Loss: 0.19 Test accuracy: 98.63%
Epoch: 24 Loss: 0.00 Test accuracy: 98.63%
Epoch: 25 Loss: 0.00 Test accuracy: 98.63%
Epoch: 26 Loss: 0.00 Test accuracy: 98.63%
Epoch: 27 Loss: 0.00 Test accuracy: 98.63%
Epoch: 28 Loss: 0.09 Test accuracy: 98.63%
Epoch: 29 Loss: 0.02 Test accuracy: 98.63%
Best test accuracy: 98.97%
Taken time: 8.85 mins
go run . -model=cnn -device=cuda
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.39 Test accuracy: 97.83%
Epoch: 1 Loss: 0.01 Test accuracy: 97.95%
Epoch: 2 Loss: 0.00 Test accuracy: 98.74%
Epoch: 3 Loss: 0.00 Test accuracy: 98.64%
Epoch: 4 Loss: 0.07 Test accuracy: 98.62%
Epoch: 5 Loss: 0.01 Test accuracy: 98.75%
Epoch: 6 Loss: 0.01 Test accuracy: 98.76%
Epoch: 7 Loss: 0.26 Test accuracy: 98.33%
Epoch: 8 Loss: 0.04 Test accuracy: 98.44%
Epoch: 9 Loss: 0.12 Test accuracy: 98.60%
Epoch: 10 Loss: 0.00 Test accuracy: 98.60%
Epoch: 11 Loss: 0.51 Test accuracy: 98.60%
Epoch: 12 Loss: 0.05 Test accuracy: 98.60%
Epoch: 13 Loss: 0.12 Test accuracy: 98.60%
Epoch: 14 Loss: 0.00 Test accuracy: 98.60%
Epoch: 15 Loss: 0.03 Test accuracy: 98.60%
Epoch: 16 Loss: 0.03 Test accuracy: 98.60%
Epoch: 17 Loss: 0.25 Test accuracy: 98.60%
Epoch: 18 Loss: 0.18 Test accuracy: 98.35%
Epoch: 19 Loss: 0.18 Test accuracy: 98.42%
Epoch: 20 Loss: 0.01 Test accuracy: 98.40%
Epoch: 21 Loss: 0.01 Test accuracy: 98.66%
Epoch: 22 Loss: 0.11 Test accuracy: 98.71%
Epoch: 23 Loss: 0.17 Test accuracy: 98.72%
Epoch: 24 Loss: 0.21 Test accuracy: 98.72%
Epoch: 25 Loss: 0.00 Test accuracy: 98.72%
Epoch: 26 Loss: 0.00 Test accuracy: 98.72%
Epoch: 27 Loss: 0.00 Test accuracy: 98.72%
Epoch: 28 Loss: 0.06 Test accuracy: 98.72%
Epoch: 29 Loss: 0.11 Test accuracy: 98.72%
Best test accuracy: 98.76%
Taken time: 8.84 mins
go run . -model=cnn -device=cuda
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.15 Test accuracy: 98.48%
Epoch: 1 Loss: 0.23 Test accuracy: 98.95%
Epoch: 2 Loss: 0.02 Test accuracy: 98.94%
Epoch: 3 Loss: 0.01 Test accuracy: 99.06%
Epoch: 4 Loss: 0.16 Test accuracy: 99.03%
Epoch: 5 Loss: 0.01 Test accuracy: 99.07%
Epoch: 6 Loss: 0.22 Test accuracy: 98.25%
Epoch: 7 Loss: 0.06 Test accuracy: 98.23%
Epoch: 8 Loss: 0.26 Test accuracy: 98.25%
Epoch: 9 Loss: 0.07 Test accuracy: 98.25%
Epoch: 10 Loss: 0.02 Test accuracy: 98.25%
Epoch: 11 Loss: 0.04 Test accuracy: 98.35%
Epoch: 12 Loss: 0.01 Test accuracy: 98.36%
Epoch: 13 Loss: 0.01 Test accuracy: 98.36%
Epoch: 14 Loss: 0.04 Test accuracy: 98.42%
Epoch: 15 Loss: 0.04 Test accuracy: 98.54%
Epoch: 16 Loss: 0.11 Test accuracy: 98.53%
Epoch: 17 Loss: 0.07 Test accuracy: 98.53%
Epoch: 18 Loss: 0.45 Test accuracy: 98.53%
Epoch: 19 Loss: 0.07 Test accuracy: 98.53%
Epoch: 20 Loss: 0.15 Test accuracy: 98.53%
Epoch: 21 Loss: 0.20 Test accuracy: 98.53%
Epoch: 22 Loss: 0.02 Test accuracy: 98.53%
Epoch: 23 Loss: 0.02 Test accuracy: 98.53%
Epoch: 24 Loss: 0.00 Test accuracy: 98.53%
Epoch: 25 Loss: 0.01 Test accuracy: 98.53%
Epoch: 26 Loss: 0.12 Test accuracy: 98.53%
Epoch: 27 Loss: 0.01 Test accuracy: 98.53%
Epoch: 28 Loss: 0.04 Test accuracy: 98.53%
Epoch: 29 Loss: 0.18 Test accuracy: 98.53%
Best test accuracy: 99.07%
Taken time: 8.82 mins
testImages: [10000 784]
testLabels: [10000]
Epoch: 0 Loss: 0.02 Test accuracy: 98.37%
Epoch: 1 Loss: 0.01 Test accuracy: 98.26%
Epoch: 2 Loss: 0.02 Test accuracy: 98.51%
Epoch: 3 Loss: 0.17 Test accuracy: 98.56%
Epoch: 4 Loss: 0.02 Test accuracy: 98.60%
Epoch: 5 Loss: 0.00 Test accuracy: 98.66%
Epoch: 6 Loss: 0.01 Test accuracy: 98.85%
Epoch: 7 Loss: 0.02 Test accuracy: 98.86%
Epoch: 8 Loss: 0.01 Test accuracy: 98.42%
Epoch: 9 Loss: 0.00 Test accuracy: 98.44%
Epoch: 10 Loss: 0.02 Test accuracy: 98.50%
Epoch: 11 Loss: 0.00 Test accuracy: 98.50%
Epoch: 12 Loss: 0.05 Test accuracy: 98.50%
Epoch: 13 Loss: 0.13 Test accuracy: 98.50%
Epoch: 14 Loss: 0.00 Test accuracy: 98.50%
Epoch: 15 Loss: 0.12 Test accuracy: 98.50%
Epoch: 16 Loss: 0.00 Test accuracy: 98.50%
Epoch: 17 Loss: 0.03 Test accuracy: 98.50%
Epoch: 18 Loss: 0.41 Test accuracy: 98.50%
Epoch: 19 Loss: 0.17 Test accuracy: 98.50%
Epoch: 20 Loss: 0.26 Test accuracy: 98.50%
Epoch: 21 Loss: 0.00 Test accuracy: 98.50%
Epoch: 22 Loss: 0.29 Test accuracy: 98.50%
Epoch: 23 Loss: 0.00 Test accuracy: 98.50%
Epoch: 24 Loss: 0.20 Test accuracy: 98.50%
Epoch: 25 Loss: 0.01 Test accuracy: 98.50%
Epoch: 26 Loss: 0.18 Test accuracy: 98.50%
Epoch: 27 Loss: 0.01 Test accuracy: 98.50%
Epoch: 28 Loss: 0.12 Test accuracy: 98.50%
Epoch: 29 Loss: 0.04 Test accuracy: 98.50%
Best test accuracy: 98.86%
Taken time: 8.77 mins
```


@ -3,6 +3,8 @@ package main
import (
"fmt"
"log"
"runtime"
"sync"
"time"
"github.com/sugarme/gotch"
@ -16,11 +18,14 @@ const (
epochsCNN = 30
batchCNN = 256
batchSize = 256
// batchSize = 256
batchSize = 32
LrCNN = 3 * 1e-4
)
var mu sync.Mutex
type Net struct {
conv1 *nn.Conv2D
conv2 *nn.Conv2D
@ -43,45 +48,32 @@ func newNet(vs *nn.Path) *Net {
func (n *Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
defer outView1.MustDrop()
outC1 := outView1.Apply(n.conv1)
outMP1 := outC1.MaxPool2DDefault(2, true)
defer outMP1.MustDrop()
outC2 := outMP1.Apply(n.conv2)
outMP2 := outC2.MaxPool2DDefault(2, true)
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
defer outView2.MustDrop()
outFC1 := outView2.Apply(n.fc1)
outRelu := outFC1.MustRelu(true)
defer outRelu.MustDrop()
outRelu := outFC1.MustRelu(false)
outDropout := ts.MustDropout(outRelu, 0.5, train)
defer outDropout.MustDrop()
return outDropout.Apply(n.fc2)
}
func runCNN1() {
var ds *vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
// ds.TrainImages [60000, 784]
// ds.TrainLabels [60000]
testImages := ds.TestImages // [10000, 784]
testLabels := ds.TestLabels // [10000]
trainImages := ds.TrainImages.MustTo(device, false) // [60000, 784]
trainLabels := ds.TrainLabels.MustTo(device, false) // [60000]
testImages := ds.TestImages.MustTo(device, false) // [10000, 784]
testLabels := ds.TestLabels.MustTo(device, false) // [10000]
fmt.Printf("testImages: %v\n", testImages.MustSize())
fmt.Printf("testLabels: %v\n", testLabels.MustSize())
device := gotch.CudaIfAvailable()
vs := nn.NewVarStore(device)
net := newNet(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
// opt, err := nn.DefaultSGDConfig().Build(vs, LrCNN)
@ -96,10 +88,9 @@ func runCNN1() {
totalSize := ds.TrainImages.MustSize()[0]
samples := int(totalSize)
// Shuffling
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
index.MustDrop()
index := ts.MustRandperm(int64(totalSize), gotch.Int64, device)
imagesTs := trainImages.MustIndexSelect(0, index, false)
labelsTs := trainLabels.MustIndexSelect(0, index, false)
batches := samples / batchSize
batchIndex := 0
@ -114,36 +105,29 @@ func runCNN1() {
// Indexing
bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
bImages = bImages.MustTo(vs.Device(), true)
bLabels = bLabels.MustTo(vs.Device(), true)
logits := net.ForwardT(bImages, true)
bImages.MustDrop()
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
loss := logits.CrossEntropyForLogits(bLabels)
logits.MustDrop()
bLabels.MustDrop()
loss = loss.MustSetRequiresGrad(true, true)
opt.BackwardStep(loss)
epocLoss = loss.Float64Values()[0]
loss.MustDrop()
runtime.GC()
}
ts.NoGrad(func() {
fmt.Printf("Start eval...")
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
if testAccuracy > bestAccuracy {
bestAccuracy = testAccuracy
}
})
imagesTs.MustDrop()
labelsTs.MustDrop()
}
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
ts.CleanUp()
}
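
Condensed, the training loop these hunks land on follows one pattern: reshuffle once per epoch on-device, slice mini-batches with `Narrow`, and drop every per-batch tensor promptly. A minimal sketch of that pattern (variable names follow the diff; model/optimizer setup is elided, so this is illustrative rather than the full file):

```go
for epoch := 0; epoch < epochsCNN; epoch++ {
	// Reshuffle on-device once per epoch.
	index := ts.MustRandperm(int64(samples), gotch.Int64, device)
	imagesTs := trainImages.MustIndexSelect(0, index, false)
	labelsTs := trainLabels.MustIndexSelect(0, index, false)
	index.MustDrop()

	for i := 0; i < samples/batchSize; i++ {
		start := int64(i * batchSize)
		bImages := imagesTs.MustNarrow(0, start, int64(batchSize), false)
		bLabels := labelsTs.MustNarrow(0, start, int64(batchSize), false)

		logits := net.ForwardT(bImages, true)
		loss := logits.CrossEntropyForLogits(bLabels)
		opt.BackwardStep(loss)

		// Release per-batch tensors before the next iteration.
		bImages.MustDrop()
		bLabels.MustDrop()
		logits.MustDrop()
		loss.MustDrop()
	}

	imagesTs.MustDrop()
	labelsTs.MustDrop()
}
```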


@ -9,31 +9,37 @@ import (
)
const (
ImageDim int64 = 784
Label int64 = 10
MnistDir string = "../../data/mnist"
ImageDim int64 = 784
Label int64 = 10
epochs = 200
)
func runLinear() {
var ds *vision.Dataset
ds = vision.LoadMNISTDir(MnistDir)
ds = vision.LoadMNISTDir(MnistDirNN)
trainImages := ds.TrainImages.MustTo(device, true)
trainLabels := ds.TrainLabels.MustTo(device, true)
testImages := ds.TestImages.MustTo(device, true)
testLabels := ds.TestLabels.MustTo(device, true)
device := gotch.CPU
dtype := gotch.Float
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true, false)
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true, false)
for epoch := 0; epoch < epochs; epoch++ {
// NOTE(TT). if initiating with random float, result is worse.
// ws := ts.MustRandn([]int64{ImageDim, Label}, dtype, device)
// bs := ts.MustRandn([]int64{Label}, dtype, device)
// ws.MustRequiresGrad_(true)
// bs.MustRequiresGrad_(true)
for epoch := 0; epoch < epochs; epoch++ {
weight := ts.NewTensor()
reduction := int64(1) // Mean of loss
ignoreIndex := int64(-100)
logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, weight, reduction, ignoreIndex, true)
logits := trainImages.MustMm(ws, false).MustAdd(bs, true)
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(trainLabels, weight, reduction, ignoreIndex, true)
ws.ZeroGrad()
bs.ZeroGrad()
@ -42,13 +48,13 @@ func runLinear() {
ts.NoGrad(func() {
ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
ts.CleanUp(100)
})
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
testLogits := testImages.MustMm(ws, false).MustAdd(bs, true)
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(testLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
loss.MustDrop()
}
}
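
The update inside the `NoGrad` block is plain full-batch gradient descent, `w <- w - lr*grad(w)`; multiplying the gradient by the hard-coded `-1.0` scalar amounts to a learning rate of 1.0. A sketch with the rate made explicit (the `lr` value here is illustrative, not from the commit):

```go
// Same update as above, but with a hypothetical explicit learning rate.
lr := 0.1
ts.NoGrad(func() {
	ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-lr), true))
	bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-lr), true))
})
```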


@ -2,19 +2,31 @@ package main
import (
"flag"
"github.com/sugarme/gotch"
)
var model string
var (
model string
deviceOpt string
device gotch.Device
)
func init() {
flag.StringVar(&model, "model", "linear", "specify a model to run")
flag.StringVar(&deviceOpt, "device", "cpu", "specify device to run on. Either 'cpu' or 'cuda'")
}
func main() {
flag.Parse()
if deviceOpt == "cuda" {
device = gotch.CudaIfAvailable()
} else {
device = gotch.CPU
}
switch model {
case "linear":
runLinear()


@ -3,6 +3,7 @@ package main
import (
"fmt"
"log"
"runtime"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
@ -11,16 +12,16 @@ import (
)
const (
ImageDimNN int64 = 784
HiddenNodesNN int64 = 128
LabelNN int64 = 10
MnistDirNN string = "../../data/mnist"
ImageDimNN int64 = 784
HiddenNodesNN int64 = 128
LabelNN int64 = 10
epochsNN = 200
LrNN = 1e-3
)
var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
var l nn.Linear
func netInit(vs *nn.Path) ts.Module {
@ -38,7 +39,6 @@ func netInit(vs *nn.Path) ts.Module {
}
func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
logits := m.Forward(trainX)
loss := logits.CrossEntropyForLogits(trainY)
@ -47,26 +47,29 @@ func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimiz
testLogits := m.Forward(testX)
testAccuracy := testLogits.AccuracyForLogits(testY)
accuracy := testAccuracy.Float64Values()[0] * 100
testAccuracy.MustDrop()
lossVal := loss.Float64Values()[0]
loss.MustDrop()
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
runtime.GC()
}
func runNN() {
var ds *vision.Dataset
ds = vision.LoadMNISTDir(MnistDirNN)
vs := nn.NewVarStore(gotch.CPU)
vs := nn.NewVarStore(device)
net := netInit(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)
}
for epoch := 0; epoch < epochsNN; epoch++ {
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
}
trainImages := ds.TrainImages.MustTo(device, true)
trainLabels := ds.TrainLabels.MustTo(device, true)
testImages := ds.TestImages.MustTo(device, true)
testLabels := ds.TestLabels.MustTo(device, true)
for epoch := 0; epoch < epochsNN; epoch++ {
train(trainImages, trainLabels, testImages, testLabels, net, opt, epoch)
}
}


@ -1,6 +1,7 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
@ -25,8 +26,10 @@ func main() {
panic(err)
}
err = pickle.LoadInfo(modelFile)
m, err := pickle.LoadModelInfo(modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Println(m)
}


@ -113,10 +113,15 @@ var ModelUrls map[string]string = map[string]string{
// 1. Resolves the input string to a full-path cached filename candidate.
// 2. Checks for it at `CachedDir`; if it exists, returns the candidate. If not,
// 3. Retrieves and caches the data to `CachedDir`, then returns the path to the cached file.
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
func CachedPath(filenameOrUrl string, folderOpt ...string) (resolvedPath string, err error) {
filename := path.Base(filenameOrUrl)
// Resolves to "candidate" filename at `CachedDir`
cachedFileCandidate := fmt.Sprintf("%s/%s", CachedDir, filename)
fullPath := CachedDir
if len(folderOpt) > 0 {
fullPath = fmt.Sprintf("%v/%v", CachedDir, folderOpt[0])
}
cachedFileCandidate := fmt.Sprintf("%s/%s", fullPath, filename)
// 1. Cached candidate file exists
if _, err := os.Stat(cachedFileCandidate); err == nil {
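
The new variadic `folderOpt` parameter keeps existing single-argument calls compiling while letting callers namespace downloads under a subfolder of `CachedDir`. A minimal usage sketch (the URL and "models" folder are placeholders; the call assumes `CachedPath` is exported from the root `gotch` package, as the pickle example above suggests):

```go
// Hypothetical usage; the URL and subfolder are illustrative only.
cachedFile, err := gotch.CachedPath("https://example.com/vgg16.ot", "models")
if err != nil {
	log.Fatal(err)
}
fmt.Println(cachedFile) // <CachedDir>/models/vgg16.ot
```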

File diff suppressed because it is too large

gen/gen.ml.1.11 Normal file

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

go.mod

@ -1,8 +1,8 @@
module github.com/sugarme/gotch
go 1.14
go 1.19
require (
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5
golang.org/x/image v0.5.0
)

go.sum

@ -1,5 +1,28 @@
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 h1:QelT11PB4FXiDEXucrfNckHoFxwt8USGY1ajP1ZF5lM=
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI=
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

half/bfloat16.go Normal file

@ -0,0 +1,167 @@
package half
import (
"math"
"math/bits"
)
// A 16-bit floating point type implementing the bfloat16 format.
// Ref. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
// https://github.com/starkat99/half-rs/tree/main/src/bfloat
// The bfloat16 - Google 'brain' floating point format - is a truncated 16-bit version of the IEEE 754 binary32 format.
// bfloat16 keeps approximately the same dynamic range as float32 (8 exponent bits -> ~3.4 × 10^38) at the cost of
// lower precision: while float16 has a 10-bit significand, bfloat16 has only 7 bits.
//
// +------------+------------------------+----------------------------+
// | 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) |
// +------------+------------------------+----------------------------+
type BFloat16 uint16
// Ref.https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1
func Float32ToBFloat16(value float32) uint16 {
// convert to raw bytes
x := math.Float32bits(value)
// Check for NaN
if (x & 0x7FFF_FFFF) > 0x7F80_0000 {
// keep high part of current mantissa but also set most significant mantissa bit
return uint16((x >> 16) | 0x0040)
}
// Round and shift
var roundBit uint32 = 0x0000_8000
if ((x & roundBit) != 0) && ((x & (3*roundBit - 1)) != 0) {
return uint16(x>>16) + 1
} else {
return uint16(x >> 16)
}
}
func Float64ToBFloat16(value float64) uint16 {
// Convert to raw bytes, truncating the last 32 bits of the mantissa;
// that precision will always be lost in half-precision anyway.
val := math.Float64bits(value)
x := uint32(val >> 32)
// Extract IEEE754 components
sign := x & 0x8000_0000
exp := x & 0x7FF0_0000
man := x & 0x000F_FFFF
// Check for all exponent bits being set, which is Infinity or NaN
if exp == 0x7FF0_0000 {
// Set mantissa MSB for NaN and also keep shifted mantissa bits.
// Also check the last 32 bits.
var nanBit uint32 = 0x0040
if man == 0 && (uint32(val) == 0) {
nanBit = 0
}
return uint16((sign >> 16) | 0x7F80 | nanBit | (man >> 13))
}
// The number is normalized, start assembling half precision version
halfSign := sign >> 16
// Unbias the exponent, then bias for bfloat16 precision
unbiasedExp := (int64(exp>>20) - 1023)
halfExp := unbiasedExp + 127
// Check for exponent overflow, return +infinity
if halfExp >= 0xFF {
return uint16(halfSign | 0x7F80)
}
// Check for underflow
if halfExp <= 0 {
// Check mantissa for what we can do
if 7-halfExp > 21 {
// No rounding possibility, so this is a full underflow, return signed zero
return uint16(halfSign)
}
// Don't forget about hidden leading mantissa bit when assembling mantissa
man = man | 0x0010_0000
halfMan := man >> (14 - halfExp)
// Check for rounding
var roundBit uint32 = 1 << (13 - halfExp)
if ((man & roundBit) != 0) && ((man & (3*roundBit - 1)) != 0) {
halfMan += 1
}
// No exponent for subnormals
return uint16(halfSign | halfMan)
}
// Rebias the exponent
halfExp1 := uint32(halfExp) << 7
halfMan1 := man >> 13
// Check for rounding
var roundBit1 uint32 = 0x0000_1000
if ((man & roundBit1) != 0) && ((man & (3*roundBit1 - 1)) != 0) {
// Round it
return uint16((halfSign | halfExp1 | halfMan1) + 1)
} else {
return uint16(halfSign | halfExp1 | halfMan1)
}
}
func BFloat16ToFloat32(i uint16) float32 {
// If NaN, keep current mantissa but also set most significant mantissa bit
if i&0x7FFF > 0x7F80 {
return math.Float32frombits((uint32(i) | 0x0040) << 16)
} else {
return math.Float32frombits(uint32(i) << 16)
}
}
func BFloat16ToFloat64(i uint16) float64 {
// Check for signed zero
if i&0x7FFF == 0 {
return math.Float64frombits(uint64(i) << 48)
}
halfSign := uint64(i & 0x8000)
halfExp := uint64(i & 0x7F80)
halfMan := uint64(i & 0x007F)
// Check for an infinity or NaN when all exponent bits set
if halfExp == 0x7F80 {
// Check for signed infinity if mantissa is zero
if halfMan == 0 {
return math.Float64frombits((halfSign << 48) | 0x7FF0_0000_0000_0000)
} else {
// NaN, keep current mantissa but also set most significant mantissa bit
return math.Float64frombits((halfSign << 48) | 0x7FF8_0000_0000_0000 | (halfMan << 45))
}
}
// Calculate double-precision components with adjusted exponent
sign := halfSign << 48
// Unbias exponent
unbiasedExp := (int64(halfExp) >> 7) - 127
// Check for subnormals, which will be normalized by adjusting exponent
if halfExp == 0 {
// Calculate how much to adjust the exponent by
// leading zeros uint16
e := bits.LeadingZeros16(uint16(halfMan)) - 9
// Rebias and adjust exponent
exp := (uint64(1023-127-e) << 52)
man := (halfMan << (46 + e)) & 0xF_FFFF_FFFF_FFFF
return math.Float64frombits(sign | exp | man)
}
// Rebias exponent for a normalized normal
exp := uint64(unbiasedExp+1023) << 52
man := (halfMan & 0x007F) << 45
return math.Float64frombits(sign | exp | man)
}
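
A quick round-trip sketch of the truncation described above: float32 keeps 23 significand bits, bfloat16 keeps 7, so an exactly representable value like 1.0 survives while pi collapses to the nearest 7-bit significand. This is a standalone example, not part of the commit:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch/half"
)

func main() {
	one := half.Float32ToBFloat16(1.0)
	pi := half.Float32ToBFloat16(3.14159265)

	fmt.Printf("1.0 -> 0x%04x -> %v\n", one, half.BFloat16ToFloat32(one)) // 0x3f80 -> 1
	fmt.Printf("pi  -> 0x%04x -> %v\n", pi, half.BFloat16ToFloat32(pi))   // 0x4049 -> 3.140625
}
```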

half/bfloat16_test.go Normal file

@ -0,0 +1 @@
package half

half/float16.go Normal file

@ -0,0 +1,303 @@
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
//
// Special thanks to Kathryn Long for her Rust implementation
// of float16 at github.com/starkat99/half-rs (MIT license)
// Package half defines support for half-precision floating-point numbers.
package half
import (
"math"
"strconv"
)
// Float16 represents IEEE 754 half-precision floating-point numbers (binary16).
type Float16 uint16
// Precision indicates whether the conversion to Float16 is
// exact, subnormal without dropped bits, inexact, underflow, or overflow.
type Precision int
const (
// PrecisionExact is for non-subnormals that don't drop bits during conversion.
// All of these can round-trip. Should always convert to float16.
PrecisionExact Precision = iota
// PrecisionUnknown is for subnormals that don't drop bits during conversion but
// not all of these can round-trip so precision is unknown without more effort.
// Only 2046 of these can round-trip and the rest cannot round-trip.
PrecisionUnknown
// PrecisionInexact is for dropped significand bits and cannot round-trip.
// Some of these are subnormals. Cannot round-trip float32->float16->float32.
PrecisionInexact
// PrecisionUnderflow is for Underflows. Cannot round-trip float32->float16->float32.
PrecisionUnderflow
// PrecisionOverflow is for Overflows. Cannot round-trip float32->float16->float32.
PrecisionOverflow
)
// PrecisionFromfloat32 returns Precision without performing
// the conversion. Conversions from both Infinity and NaN
// values will always report PrecisionExact even if NaN payload
// or NaN-Quiet-Bit is lost. This function is kept simple to
// allow inlining and run < 0.5 ns/op, to serve as a fast filter.
func PrecisionFromfloat32(f32 float32) Precision {
u32 := math.Float32bits(f32)
if u32 == 0 || u32 == 0x80000000 {
// +- zero will always be exact conversion
return PrecisionExact
}
const COEFMASK uint32 = 0x7fffff // 23 least significant bits
const EXPSHIFT uint32 = 23
const EXPBIAS uint32 = 127
const EXPMASK uint32 = uint32(0xff) << EXPSHIFT
const DROPMASK uint32 = COEFMASK >> 10
exp := int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS)
coef := u32 & COEFMASK
if exp == 128 {
// +- infinity or NaN
// apps may want to do extra checks for NaN separately
return PrecisionExact
}
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format says,
// "Decimals between 2^24 (minimum positive subnormal) and 2^14 (maximum subnormal): fixed interval 2^24"
if exp < -24 {
return PrecisionUnderflow
}
if exp > 15 {
return PrecisionOverflow
}
if (coef & DROPMASK) != uint32(0) {
// these include subnormals and non-subnormals that dropped bits
return PrecisionInexact
}
if exp < -14 {
// Subnormals. Caller may want to test these further.
// There are 2046 subnormals that can successfully round-trip f32->f16->f32
// and 20 of those 2046 have 32-bit input coef == 0.
// RFC 7049 and 7049bis Draft 12 don't precisely define "preserves value"
// so some protocols and libraries will choose to handle subnormals differently
// when deciding to encode them to CBOR float32 vs float16.
return PrecisionUnknown
}
return PrecisionExact
}
// Frombits returns the float16 number corresponding to the IEEE 754 binary16
// representation u16, with the sign bit of u16 and the result in the same bit
// position. Frombits(Bits(x)) == x.
func Frombits(u16 uint16) Float16 {
return Float16(u16)
}
// Fromfloat32 returns a Float16 value converted from f32. Conversion uses
// IEEE default rounding (nearest int, with ties to even).
func Fromfloat32(f32 float32) Float16 {
return Float16(f32bitsToF16bits(math.Float32bits(f32)))
}
// ErrInvalidNaNValue indicates a NaN was not received.
const ErrInvalidNaNValue = float16Error("float16: invalid NaN value, expected IEEE 754 NaN")
type float16Error string
func (e float16Error) Error() string { return string(e) }
// FromNaN32ps converts nan to IEEE binary16 NaN while preserving both
// signaling and payload. Unlike Fromfloat32(), which can only return
// qNaN because it sets quiet bit = 1, this can return both sNaN and qNaN.
// If the result is infinity (sNaN with empty payload), then the
// lowest bit of payload is set to make the result a NaN.
// Returns ErrInvalidNaNValue and 0x7c01 (sNaN) if nan isn't IEEE 754 NaN.
// This function was kept simple to be able to inline.
func FromNaN32ps(nan float32) (Float16, error) {
const SNAN = Float16(uint16(0x7c01)) // signaling NaN
u32 := math.Float32bits(nan)
sign := u32 & 0x80000000
exp := u32 & 0x7f800000
coef := u32 & 0x007fffff
if (exp != 0x7f800000) || (coef == 0) {
return SNAN, ErrInvalidNaNValue
}
u16 := uint16((sign >> 16) | uint32(0x7c00) | (coef >> 13))
if (u16 & 0x03ff) == 0 {
// result became infinity, make it NaN by setting lowest bit in payload
u16 |= 0x0001
}
return Float16(u16), nil
}
// NaN returns a Float16 of IEEE 754 binary16 not-a-number (NaN).
// Returned NaN value 0x7e01 has all exponent bits = 1 with the
// first and last bits = 1 in the significand. This is consistent
// with Go's 64-bit math.NaN(). Canonical CBOR in RFC 7049 uses 0x7e00.
func NaN() Float16 {
return Float16(0x7e01)
}
// Inf returns a Float16 with an infinity value with the specified sign.
// A sign >= 0 returns positive infinity.
// A sign < 0 returns negative infinity.
func Inf(sign int) Float16 {
if sign >= 0 {
return Float16(0x7c00)
}
return Float16(0x8000 | 0x7c00)
}
// Float32 returns a float32 converted from f (Float16).
// This is a lossless conversion.
func (f Float16) Float32() float32 {
u32 := f16bitsToF32bits(uint16(f))
return math.Float32frombits(u32)
}
// Bits returns the IEEE 754 binary16 representation of f, with the sign bit
// of f and the result in the same bit position. Bits(Frombits(x)) == x.
func (f Float16) Bits() uint16 {
return uint16(f)
}
// IsNaN reports whether f is an IEEE 754 binary16 “not-a-number” value.
func (f Float16) IsNaN() bool {
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0)
}
// IsQuietNaN reports whether f is a quiet (non-signaling) IEEE 754 binary16
// “not-a-number” value.
func (f Float16) IsQuietNaN() bool {
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0) && (f&0x0200 != 0)
}
// IsInf reports whether f is an infinity (inf).
// A sign > 0 reports whether f is positive inf.
// A sign < 0 reports whether f is negative inf.
// A sign == 0 reports whether f is either inf.
func (f Float16) IsInf(sign int) bool {
return ((f == 0x7c00) && sign >= 0) ||
(f == 0xfc00 && sign <= 0)
}
// IsFinite returns true if f is neither infinite nor NaN.
func (f Float16) IsFinite() bool {
return (uint16(f) & uint16(0x7c00)) != uint16(0x7c00)
}
// IsNormal returns true if f is neither zero, infinite, subnormal, nor NaN.
func (f Float16) IsNormal() bool {
exp := uint16(f) & uint16(0x7c00)
return (exp != uint16(0x7c00)) && (exp != 0)
}
// Signbit reports whether f is negative or negative zero.
func (f Float16) Signbit() bool {
return (uint16(f) & uint16(0x8000)) != 0
}
// String satisfies the fmt.Stringer interface.
func (f Float16) String() string {
return strconv.FormatFloat(float64(f.Float32()), 'f', -1, 32)
}
// f16bitsToF32bits returns uint32 (float32 bits) converted from specified uint16.
func f16bitsToF32bits(in uint16) uint32 {
// All 65536 conversions with this were confirmed to be correct
// by Montgomery Edwards⁴⁴⁸ (github.com/x448).
sign := uint32(in&0x8000) << 16 // sign for 32-bit
exp := uint32(in&0x7c00) >> 10 // exponent for 16-bit
coef := uint32(in&0x03ff) << 13 // significand for 32-bit
if exp == 0x1f {
if coef == 0 {
// infinity
return sign | 0x7f800000 | coef
}
// NaN
return sign | 0x7fc00000 | coef
}
if exp == 0 {
if coef == 0 {
// zero
return sign
}
// normalize subnormal numbers
exp++
for coef&0x7f800000 == 0 {
coef <<= 1
exp--
}
coef &= 0x007fffff
}
return sign | ((exp + (0x7f - 0xf)) << 23) | coef
}
// f32bitsToF16bits returns uint16 (Float16 bits) converted from the specified float32.
// Conversion rounds to nearest integer with ties to even.
func f32bitsToF16bits(u32 uint32) uint16 {
// Translated from Rust to Go by Montgomery Edwards⁴⁴⁸ (github.com/x448).
// All 4294967296 conversions with this were confirmed to be correct by x448.
// Original Rust implementation is by Kathryn Long (github.com/starkat99) with MIT license.
sign := u32 & 0x80000000
exp := u32 & 0x7f800000
coef := u32 & 0x007fffff
if exp == 0x7f800000 {
// NaN or Infinity
nanBit := uint32(0)
if coef != 0 {
nanBit = uint32(0x0200)
}
return uint16((sign >> 16) | uint32(0x7c00) | nanBit | (coef >> 13))
}
halfSign := sign >> 16
unbiasedExp := int32(exp>>23) - 127
halfExp := unbiasedExp + 15
if halfExp >= 0x1f {
return uint16(halfSign | uint32(0x7c00))
}
if halfExp <= 0 {
if 14-halfExp > 24 {
return uint16(halfSign)
}
c := coef | uint32(0x00800000)
halfCoef := c >> uint32(14-halfExp)
roundBit := uint32(1) << uint32(13-halfExp)
if (c&roundBit) != 0 && (c&(3*roundBit-1)) != 0 {
halfCoef++
}
return uint16(halfSign | halfCoef)
}
uHalfExp := uint32(halfExp) << 10
halfCoef := coef >> 13
roundBit := uint32(0x00001000)
if (coef&roundBit) != 0 && (coef&(3*roundBit-1)) != 0 {
return uint16((halfSign | uHalfExp | halfCoef) + 1)
}
return uint16(halfSign | uHalfExp | halfCoef)
}
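
A short sketch of the public API above, showing an exact and an inexact conversion (standalone example; the expected outputs follow from the test vectors later in this diff):

```go
package main

import (
	"fmt"
	"math"

	"github.com/sugarme/gotch/half"
)

func main() {
	pi := float32(math.Pi)

	f16 := half.Fromfloat32(pi)
	fmt.Println(f16.String())                  // 3.140625: the 10-bit significand drops bits
	fmt.Println(half.PrecisionFromfloat32(pi)) // 2 (PrecisionInexact)
	fmt.Println(f16.Float32() == pi)           // false: this round-trip is lossy

	one := half.Fromfloat32(1.0)
	fmt.Println(one.Float32() == 1.0) // true: 1.0 converts exactly
	fmt.Println(half.NaN().IsNaN())   // true
}
```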


@ -0,0 +1,88 @@
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
package half_test
import (
"math"
"testing"
float16 "github.com/sugarme/gotch/half"
)
// prevent compiler optimizing out code by assigning to these
var resultF16 float16.Float16
var resultF32 float32
var resultStr string
var pcn float16.Precision
func BenchmarkFloat32pi(b *testing.B) {
result := float32(0)
pi32 := float32(math.Pi)
pi16 := float16.Fromfloat32(pi32)
for i := 0; i < b.N; i++ {
f16 := float16.Frombits(uint16(pi16))
result = f16.Float32()
}
resultF32 = result
}
func BenchmarkFrombits(b *testing.B) {
result := float16.Float16(0)
pi32 := float32(math.Pi)
pi16 := float16.Fromfloat32(pi32)
for i := 0; i < b.N; i++ {
result = float16.Frombits(uint16(pi16))
}
resultF16 = result
}
func BenchmarkFromFloat32pi(b *testing.B) {
result := float16.Float16(0)
pi := float32(math.Pi)
for i := 0; i < b.N; i++ {
result = float16.Fromfloat32(pi)
}
resultF16 = result
}
func BenchmarkFromFloat32nan(b *testing.B) {
result := float16.Float16(0)
nan := float32(math.NaN())
for i := 0; i < b.N; i++ {
result = float16.Fromfloat32(nan)
}
resultF16 = result
}
func BenchmarkFromFloat32subnorm(b *testing.B) {
result := float16.Float16(0)
subnorm := math.Float32frombits(0x007fffff)
for i := 0; i < b.N; i++ {
result = float16.Fromfloat32(subnorm)
}
resultF16 = result
}
func BenchmarkPrecisionFromFloat32(b *testing.B) {
var result float16.Precision
for i := 0; i < b.N; i++ {
f32 := float32(0.00001) + float32(0.00001)
result = float16.PrecisionFromfloat32(f32)
}
pcn = result
}
func BenchmarkString(b *testing.B) {
var result string
pi32 := float32(math.Pi)
pi16 := float16.Fromfloat32(pi32)
for i := 0; i < b.N; i++ {
result = pi16.String()
}
resultStr = result
}

half/float16_test.go Normal file

@ -0,0 +1,798 @@
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
package half_test
import (
"bytes"
"crypto/sha512"
"encoding/binary"
"encoding/hex"
"fmt"
"math"
"testing"
float16 "github.com/sugarme/gotch/half"
)
// wantF32toF16bits is a tiny subset of expected values
var wantF32toF16bits = []struct {
in float32
out uint16
}{
// generated to provide 100% code coverage plus additional tests for rounding, etc.
{in: math.Float32frombits(0x00000000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00000001), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00001fff), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00002000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00003fff), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00004000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x007fffff), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x00800000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x33000000), out: 0x0000}, // in f32=0.000000, out f16=0
{in: math.Float32frombits(0x33000001), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
{in: math.Float32frombits(0x33000002), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
{in: math.Float32frombits(0x387fc000), out: 0x03ff}, // in f32=0.000061, out f16=0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
{in: math.Float32frombits(0x387fffff), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
{in: math.Float32frombits(0x38800000), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
{in: math.Float32frombits(0x38801fff), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
{in: math.Float32frombits(0x38802000), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
{in: math.Float32frombits(0x38803fff), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
{in: math.Float32frombits(0x38804000), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
{in: math.Float32frombits(0x33bfffff), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
{in: math.Float32frombits(0x33c00000), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
{in: math.Float32frombits(0x33c00001), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
{in: math.Float32frombits(0x477fffff), out: 0x7c00}, // in f32=65535.996094, out f16=+Inf
{in: math.Float32frombits(0x47800000), out: 0x7c00}, // in f32=65536.000000, out f16=+Inf
{in: math.Float32frombits(0x7f7fffff), out: 0x7c00}, // in f32=340282346638528859811704183484516925440.000000, out f16=+Inf
{in: math.Float32frombits(0x7f800000), out: 0x7c00}, // in f32=+Inf, out f16=+Inf
{in: math.Float32frombits(0x7f801fff), out: 0x7e00}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7f802000), out: 0x7e01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7f803fff), out: 0x7e01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7f804000), out: 0x7e02}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x7fffffff), out: 0x7fff}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0x80000000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80001fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80002000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80003fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80004000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x807fffff), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0x80800000), out: 0x8000}, // in f32=-0.000000, out f16=-0
{in: math.Float32frombits(0xb87fc000), out: 0x83ff}, // in f32=-0.000061, out f16=-0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
{in: math.Float32frombits(0xb87fffff), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
{in: math.Float32frombits(0xb8800000), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
{in: math.Float32frombits(0xb8801fff), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
{in: math.Float32frombits(0xb8802000), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
{in: math.Float32frombits(0xb8803fff), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
{in: math.Float32frombits(0xb8804000), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
{in: math.Float32frombits(0xc77fffff), out: 0xfc00}, // in f32=-65535.996094, out f16=-Inf
{in: math.Float32frombits(0xc7800000), out: 0xfc00}, // in f32=-65536.000000, out f16=-Inf
{in: math.Float32frombits(0xff7fffff), out: 0xfc00}, // in f32=-340282346638528859811704183484516925440.000000, out f16=-Inf
{in: math.Float32frombits(0xff800000), out: 0xfc00}, // in f32=-Inf, out f16=-Inf
{in: math.Float32frombits(0xff801fff), out: 0xfe00}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0xff802000), out: 0xfe01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0xff803fff), out: 0xfe01}, // in f32=NaN, out f16=NaN
{in: math.Float32frombits(0xff804000), out: 0xfe02}, // in f32=NaN, out f16=NaN
// additional tests
{in: math.Float32frombits(0xc77ff000), out: 0xfc00}, // in f32=-65520.000000, out f16=-Inf
{in: math.Float32frombits(0xc77fef00), out: 0xfbff}, // in f32=-65519.000000, out f16=-65504
{in: math.Float32frombits(0xc77fee00), out: 0xfbff}, // in f32=-65518.000000, out f16=-65504
{in: math.Float32frombits(0xc5802000), out: 0xec01}, // in f32=-4100.000000, out f16=-4100
{in: math.Float32frombits(0xc5801800), out: 0xec01}, // in f32=-4099.000000, out f16=-4100
{in: math.Float32frombits(0xc5801000), out: 0xec00}, // in f32=-4098.000000, out f16=-4096
{in: math.Float32frombits(0xc5800800), out: 0xec00}, // in f32=-4097.000000, out f16=-4096
{in: math.Float32frombits(0xc5800000), out: 0xec00}, // in f32=-4096.000000, out f16=-4096
{in: math.Float32frombits(0xc57ff000), out: 0xec00}, // in f32=-4095.000000, out f16=-4096
{in: math.Float32frombits(0xc57fe000), out: 0xebff}, // in f32=-4094.000000, out f16=-4094
{in: math.Float32frombits(0xc57fd000), out: 0xebfe}, // in f32=-4093.000000, out f16=-4092
{in: math.Float32frombits(0xc5002000), out: 0xe801}, // in f32=-2050.000000, out f16=-2050
{in: math.Float32frombits(0xc5001000), out: 0xe800}, // in f32=-2049.000000, out f16=-2048
{in: math.Float32frombits(0xc5000829), out: 0xe800}, // in f32=-2048.510010, out f16=-2048
{in: math.Float32frombits(0xc5000800), out: 0xe800}, // in f32=-2048.500000, out f16=-2048
{in: math.Float32frombits(0xc50007d7), out: 0xe800}, // in f32=-2048.489990, out f16=-2048
{in: math.Float32frombits(0xc5000000), out: 0xe800}, // in f32=-2048.000000, out f16=-2048
{in: math.Float32frombits(0xc4fff052), out: 0xe800}, // in f32=-2047.510010, out f16=-2048
{in: math.Float32frombits(0xc4fff000), out: 0xe800}, // in f32=-2047.500000, out f16=-2048
{in: math.Float32frombits(0xc4ffefae), out: 0xe7ff}, // in f32=-2047.489990, out f16=-2047
{in: math.Float32frombits(0xc4ffe000), out: 0xe7ff}, // in f32=-2047.000000, out f16=-2047
{in: math.Float32frombits(0xc4ffc000), out: 0xe7fe}, // in f32=-2046.000000, out f16=-2046
{in: math.Float32frombits(0xc4ffa000), out: 0xe7fd}, // in f32=-2045.000000, out f16=-2045
{in: math.Float32frombits(0xbf800000), out: 0xbc00}, // in f32=-1.000000, out f16=-1
{in: math.Float32frombits(0xbf028f5c), out: 0xb814}, // in f32=-0.510000, out f16=-0.5097656
{in: math.Float32frombits(0xbf000000), out: 0xb800}, // in f32=-0.500000, out f16=-0.5
{in: math.Float32frombits(0xbefae148), out: 0xb7d7}, // in f32=-0.490000, out f16=-0.48999023
{in: math.Float32frombits(0x3efae148), out: 0x37d7}, // in f32=0.490000, out f16=0.48999023
{in: math.Float32frombits(0x3f000000), out: 0x3800}, // in f32=0.500000, out f16=0.5
{in: math.Float32frombits(0x3f028f5c), out: 0x3814}, // in f32=0.510000, out f16=0.5097656
{in: math.Float32frombits(0x3f800000), out: 0x3c00}, // in f32=1.000000, out f16=1
{in: math.Float32frombits(0x3fbeb852), out: 0x3df6}, // in f32=1.490000, out f16=1.4902344
{in: math.Float32frombits(0x3fc00000), out: 0x3e00}, // in f32=1.500000, out f16=1.5
{in: math.Float32frombits(0x3fc147ae), out: 0x3e0a}, // in f32=1.510000, out f16=1.5097656
{in: math.Float32frombits(0x3fcf1bbd), out: 0x3e79}, // in f32=1.618034, out f16=1.6181641
{in: math.Float32frombits(0x401f5c29), out: 0x40fb}, // in f32=2.490000, out f16=2.4902344
{in: math.Float32frombits(0x40200000), out: 0x4100}, // in f32=2.500000, out f16=2.5
{in: math.Float32frombits(0x4020a3d7), out: 0x4105}, // in f32=2.510000, out f16=2.5097656
{in: math.Float32frombits(0x402df854), out: 0x4170}, // in f32=2.718282, out f16=2.71875
{in: math.Float32frombits(0x40490fdb), out: 0x4248}, // in f32=3.141593, out f16=3.140625
{in: math.Float32frombits(0x40b00000), out: 0x4580}, // in f32=5.500000, out f16=5.5
{in: math.Float32frombits(0x44ffa000), out: 0x67fd}, // in f32=2045.000000, out f16=2045
{in: math.Float32frombits(0x44ffc000), out: 0x67fe}, // in f32=2046.000000, out f16=2046
{in: math.Float32frombits(0x44ffe000), out: 0x67ff}, // in f32=2047.000000, out f16=2047
{in: math.Float32frombits(0x44ffefae), out: 0x67ff}, // in f32=2047.489990, out f16=2047
{in: math.Float32frombits(0x44fff000), out: 0x6800}, // in f32=2047.500000, out f16=2048
{in: math.Float32frombits(0x44fff052), out: 0x6800}, // in f32=2047.510010, out f16=2048
{in: math.Float32frombits(0x45000000), out: 0x6800}, // in f32=2048.000000, out f16=2048
{in: math.Float32frombits(0x450007d7), out: 0x6800}, // in f32=2048.489990, out f16=2048
{in: math.Float32frombits(0x45000800), out: 0x6800}, // in f32=2048.500000, out f16=2048
{in: math.Float32frombits(0x45000829), out: 0x6800}, // in f32=2048.510010, out f16=2048
{in: math.Float32frombits(0x45001000), out: 0x6800}, // in f32=2049.000000, out f16=2048
{in: math.Float32frombits(0x450017d7), out: 0x6801}, // in f32=2049.489990, out f16=2050
{in: math.Float32frombits(0x45001800), out: 0x6801}, // in f32=2049.500000, out f16=2050
{in: math.Float32frombits(0x45001829), out: 0x6801}, // in f32=2049.510010, out f16=2050
{in: math.Float32frombits(0x45002000), out: 0x6801}, // in f32=2050.000000, out f16=2050
{in: math.Float32frombits(0x45003000), out: 0x6802}, // in f32=2051.000000, out f16=2052
{in: math.Float32frombits(0x457fd000), out: 0x6bfe}, // in f32=4093.000000, out f16=4092
{in: math.Float32frombits(0x457fe000), out: 0x6bff}, // in f32=4094.000000, out f16=4094
{in: math.Float32frombits(0x457ff000), out: 0x6c00}, // in f32=4095.000000, out f16=4096
{in: math.Float32frombits(0x45800000), out: 0x6c00}, // in f32=4096.000000, out f16=4096
{in: math.Float32frombits(0x45800800), out: 0x6c00}, // in f32=4097.000000, out f16=4096
{in: math.Float32frombits(0x45801000), out: 0x6c00}, // in f32=4098.000000, out f16=4096
{in: math.Float32frombits(0x45801800), out: 0x6c01}, // in f32=4099.000000, out f16=4100
{in: math.Float32frombits(0x45802000), out: 0x6c01}, // in f32=4100.000000, out f16=4100
{in: math.Float32frombits(0x45ad9c00), out: 0x6d6d}, // in f32=5555.500000, out f16=5556
{in: math.Float32frombits(0x45ffe800), out: 0x6fff}, // in f32=8189.000000, out f16=8188
{in: math.Float32frombits(0x45fff000), out: 0x7000}, // in f32=8190.000000, out f16=8192
{in: math.Float32frombits(0x45fff800), out: 0x7000}, // in f32=8191.000000, out f16=8192
{in: math.Float32frombits(0x46000000), out: 0x7000}, // in f32=8192.000000, out f16=8192
{in: math.Float32frombits(0x46000400), out: 0x7000}, // in f32=8193.000000, out f16=8192
{in: math.Float32frombits(0x46000800), out: 0x7000}, // in f32=8194.000000, out f16=8192
{in: math.Float32frombits(0x46000c00), out: 0x7000}, // in f32=8195.000000, out f16=8192
{in: math.Float32frombits(0x46001000), out: 0x7000}, // in f32=8196.000000, out f16=8192
{in: math.Float32frombits(0x46001400), out: 0x7001}, // in f32=8197.000000, out f16=8200
{in: math.Float32frombits(0x46001800), out: 0x7001}, // in f32=8198.000000, out f16=8200
{in: math.Float32frombits(0x46001c00), out: 0x7001}, // in f32=8199.000000, out f16=8200
{in: math.Float32frombits(0x46002000), out: 0x7001}, // in f32=8200.000000, out f16=8200
{in: math.Float32frombits(0x46002400), out: 0x7001}, // in f32=8201.000000, out f16=8200
{in: math.Float32frombits(0x46002800), out: 0x7001}, // in f32=8202.000000, out f16=8200
{in: math.Float32frombits(0x46002c00), out: 0x7001}, // in f32=8203.000000, out f16=8200
{in: math.Float32frombits(0x46003000), out: 0x7002}, // in f32=8204.000000, out f16=8208
{in: math.Float32frombits(0x467fec00), out: 0x73ff}, // in f32=16379.000000, out f16=16376
{in: math.Float32frombits(0x467ff000), out: 0x7400}, // in f32=16380.000000, out f16=16384
{in: math.Float32frombits(0x467ff400), out: 0x7400}, // in f32=16381.000000, out f16=16384
{in: math.Float32frombits(0x467ff800), out: 0x7400}, // in f32=16382.000000, out f16=16384
{in: math.Float32frombits(0x467ffc00), out: 0x7400}, // in f32=16383.000000, out f16=16384
{in: math.Float32frombits(0x46800000), out: 0x7400}, // in f32=16384.000000, out f16=16384
{in: math.Float32frombits(0x46800200), out: 0x7400}, // in f32=16385.000000, out f16=16384
{in: math.Float32frombits(0x46800400), out: 0x7400}, // in f32=16386.000000, out f16=16384
{in: math.Float32frombits(0x46800600), out: 0x7400}, // in f32=16387.000000, out f16=16384
{in: math.Float32frombits(0x46800800), out: 0x7400}, // in f32=16388.000000, out f16=16384
{in: math.Float32frombits(0x46800a00), out: 0x7400}, // in f32=16389.000000, out f16=16384
{in: math.Float32frombits(0x46800c00), out: 0x7400}, // in f32=16390.000000, out f16=16384
{in: math.Float32frombits(0x46800e00), out: 0x7400}, // in f32=16391.000000, out f16=16384
{in: math.Float32frombits(0x46801000), out: 0x7400}, // in f32=16392.000000, out f16=16384
{in: math.Float32frombits(0x46801200), out: 0x7401}, // in f32=16393.000000, out f16=16400
{in: math.Float32frombits(0x46801400), out: 0x7401}, // in f32=16394.000000, out f16=16400
{in: math.Float32frombits(0x46801600), out: 0x7401}, // in f32=16395.000000, out f16=16400
{in: math.Float32frombits(0x46801800), out: 0x7401}, // in f32=16396.000000, out f16=16400
{in: math.Float32frombits(0x46801a00), out: 0x7401}, // in f32=16397.000000, out f16=16400
{in: math.Float32frombits(0x46801c00), out: 0x7401}, // in f32=16398.000000, out f16=16400
{in: math.Float32frombits(0x46801e00), out: 0x7401}, // in f32=16399.000000, out f16=16400
{in: math.Float32frombits(0x46802000), out: 0x7401}, // in f32=16400.000000, out f16=16400
{in: math.Float32frombits(0x46802200), out: 0x7401}, // in f32=16401.000000, out f16=16400
{in: math.Float32frombits(0x46802400), out: 0x7401}, // in f32=16402.000000, out f16=16400
{in: math.Float32frombits(0x46802600), out: 0x7401}, // in f32=16403.000000, out f16=16400
{in: math.Float32frombits(0x46802800), out: 0x7401}, // in f32=16404.000000, out f16=16400
{in: math.Float32frombits(0x46802a00), out: 0x7401}, // in f32=16405.000000, out f16=16400
{in: math.Float32frombits(0x46802c00), out: 0x7401}, // in f32=16406.000000, out f16=16400
{in: math.Float32frombits(0x46802e00), out: 0x7401}, // in f32=16407.000000, out f16=16400
{in: math.Float32frombits(0x46803000), out: 0x7402}, // in f32=16408.000000, out f16=16416
{in: math.Float32frombits(0x46ffee00), out: 0x77ff}, // in f32=32759.000000, out f16=32752
{in: math.Float32frombits(0x46fff000), out: 0x7800}, // in f32=32760.000000, out f16=32768
{in: math.Float32frombits(0x46fff200), out: 0x7800}, // in f32=32761.000000, out f16=32768
{in: math.Float32frombits(0x46fff400), out: 0x7800}, // in f32=32762.000000, out f16=32768
{in: math.Float32frombits(0x46fff600), out: 0x7800}, // in f32=32763.000000, out f16=32768
{in: math.Float32frombits(0x46fff800), out: 0x7800}, // in f32=32764.000000, out f16=32768
{in: math.Float32frombits(0x46fffa00), out: 0x7800}, // in f32=32765.000000, out f16=32768
{in: math.Float32frombits(0x46fffc00), out: 0x7800}, // in f32=32766.000000, out f16=32768
{in: math.Float32frombits(0x46fffe00), out: 0x7800}, // in f32=32767.000000, out f16=32768
{in: math.Float32frombits(0x47000000), out: 0x7800}, // in f32=32768.000000, out f16=32768
{in: math.Float32frombits(0x47000100), out: 0x7800}, // in f32=32769.000000, out f16=32768
{in: math.Float32frombits(0x47000200), out: 0x7800}, // in f32=32770.000000, out f16=32768
{in: math.Float32frombits(0x47000300), out: 0x7800}, // in f32=32771.000000, out f16=32768
{in: math.Float32frombits(0x47000400), out: 0x7800}, // in f32=32772.000000, out f16=32768
{in: math.Float32frombits(0x47000500), out: 0x7800}, // in f32=32773.000000, out f16=32768
{in: math.Float32frombits(0x47000600), out: 0x7800}, // in f32=32774.000000, out f16=32768
{in: math.Float32frombits(0x47000700), out: 0x7800}, // in f32=32775.000000, out f16=32768
{in: math.Float32frombits(0x47000800), out: 0x7800}, // in f32=32776.000000, out f16=32768
{in: math.Float32frombits(0x47000900), out: 0x7800}, // in f32=32777.000000, out f16=32768
{in: math.Float32frombits(0x47000a00), out: 0x7800}, // in f32=32778.000000, out f16=32768
{in: math.Float32frombits(0x47000b00), out: 0x7800}, // in f32=32779.000000, out f16=32768
{in: math.Float32frombits(0x47000c00), out: 0x7800}, // in f32=32780.000000, out f16=32768
{in: math.Float32frombits(0x47000d00), out: 0x7800}, // in f32=32781.000000, out f16=32768
{in: math.Float32frombits(0x47000e00), out: 0x7800}, // in f32=32782.000000, out f16=32768
{in: math.Float32frombits(0x47000f00), out: 0x7800}, // in f32=32783.000000, out f16=32768
{in: math.Float32frombits(0x47001000), out: 0x7800}, // in f32=32784.000000, out f16=32768
{in: math.Float32frombits(0x47001100), out: 0x7801}, // in f32=32785.000000, out f16=32800
{in: math.Float32frombits(0x47001200), out: 0x7801}, // in f32=32786.000000, out f16=32800
{in: math.Float32frombits(0x47001300), out: 0x7801}, // in f32=32787.000000, out f16=32800
{in: math.Float32frombits(0x47001400), out: 0x7801}, // in f32=32788.000000, out f16=32800
{in: math.Float32frombits(0x47001500), out: 0x7801}, // in f32=32789.000000, out f16=32800
{in: math.Float32frombits(0x47001600), out: 0x7801}, // in f32=32790.000000, out f16=32800
{in: math.Float32frombits(0x47001700), out: 0x7801}, // in f32=32791.000000, out f16=32800
{in: math.Float32frombits(0x47001800), out: 0x7801}, // in f32=32792.000000, out f16=32800
{in: math.Float32frombits(0x47001900), out: 0x7801}, // in f32=32793.000000, out f16=32800
{in: math.Float32frombits(0x47001a00), out: 0x7801}, // in f32=32794.000000, out f16=32800
{in: math.Float32frombits(0x47001b00), out: 0x7801}, // in f32=32795.000000, out f16=32800
{in: math.Float32frombits(0x47001c00), out: 0x7801}, // in f32=32796.000000, out f16=32800
{in: math.Float32frombits(0x47001d00), out: 0x7801}, // in f32=32797.000000, out f16=32800
{in: math.Float32frombits(0x47001e00), out: 0x7801}, // in f32=32798.000000, out f16=32800
{in: math.Float32frombits(0x47001f00), out: 0x7801}, // in f32=32799.000000, out f16=32800
{in: math.Float32frombits(0x47002000), out: 0x7801}, // in f32=32800.000000, out f16=32800
{in: math.Float32frombits(0x47002100), out: 0x7801}, // in f32=32801.000000, out f16=32800
{in: math.Float32frombits(0x47002200), out: 0x7801}, // in f32=32802.000000, out f16=32800
{in: math.Float32frombits(0x47002300), out: 0x7801}, // in f32=32803.000000, out f16=32800
{in: math.Float32frombits(0x47002400), out: 0x7801}, // in f32=32804.000000, out f16=32800
{in: math.Float32frombits(0x47002500), out: 0x7801}, // in f32=32805.000000, out f16=32800
{in: math.Float32frombits(0x47002600), out: 0x7801}, // in f32=32806.000000, out f16=32800
{in: math.Float32frombits(0x47002700), out: 0x7801}, // in f32=32807.000000, out f16=32800
{in: math.Float32frombits(0x47002800), out: 0x7801}, // in f32=32808.000000, out f16=32800
{in: math.Float32frombits(0x47002900), out: 0x7801}, // in f32=32809.000000, out f16=32800
{in: math.Float32frombits(0x47002a00), out: 0x7801}, // in f32=32810.000000, out f16=32800
{in: math.Float32frombits(0x47002b00), out: 0x7801}, // in f32=32811.000000, out f16=32800
{in: math.Float32frombits(0x47002c00), out: 0x7801}, // in f32=32812.000000, out f16=32800
{in: math.Float32frombits(0x47002d00), out: 0x7801}, // in f32=32813.000000, out f16=32800
{in: math.Float32frombits(0x47002e00), out: 0x7801}, // in f32=32814.000000, out f16=32800
{in: math.Float32frombits(0x47002f00), out: 0x7801}, // in f32=32815.000000, out f16=32800
{in: math.Float32frombits(0x47003000), out: 0x7802}, // in f32=32816.000000, out f16=32832
{in: math.Float32frombits(0x477fe500), out: 0x7bff}, // in f32=65509.000000, out f16=65504
{in: math.Float32frombits(0x477fe100), out: 0x7bff}, // in f32=65505.000000, out f16=65504
{in: math.Float32frombits(0x477fee00), out: 0x7bff}, // in f32=65518.000000, out f16=65504
{in: math.Float32frombits(0x477fef00), out: 0x7bff}, // in f32=65519.000000, out f16=65504
{in: math.Float32frombits(0x477feffd), out: 0x7bff}, // in f32=65519.988281, out f16=65504
{in: math.Float32frombits(0x477ff000), out: 0x7c00}, // in f32=65520.000000, out f16=+Inf
}
func TestPrecisionFromfloat32(t *testing.T) {
for i, v := range wantF32toF16bits {
f16 := float16.Fromfloat32(v.in)
u16 := uint16(f16)
if u16 != v.out {
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
}
checkPrecision(t, v.in, f16, uint64(i))
}
f32 := float32(5.5) // value that doesn't drop any bits in the significand, is within normal exponent range
pre := float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionExact {
t.Errorf("f32bits=0x%08x, wanted=PrecisionExact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionExact, pre)
}
f32 = math.Float32frombits(0x38000000) // subnormal value with coef = 0 that can round-trip float32->float16->float32
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnknown {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
}
f32 = math.Float32frombits(0x387fc000) // subnormal value with coef !=0 that can round-trip float32->float16->float32
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnknown {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
}
f32 = math.Float32frombits(0x33c00000) // subnormal value with no dropped bits that cannot round-trip float32->float16->float32
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnknown {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
}
f32 = math.Float32frombits(0x38000001) // subnormal value with dropped non-zero bits > 0
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionInexact {
t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
}
f32 = float32(math.Pi) // value that cannot "preserve value" because it drops bits in the significand
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionInexact {
t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
}
f32 = math.Float32frombits(0x1) // value that will underflow
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnderflow {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
}
f32 = math.Float32frombits(0x33000000) // value that will underflow
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionUnderflow {
t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
}
f32 = math.Float32frombits(0x47800000) // value that will overflow
pre = float16.PrecisionFromfloat32(f32)
if pre != float16.PrecisionOverflow {
t.Errorf("f32bits=0x%08x, wanted=PrecisionOverflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionOverflow, pre)
}
}
func TestFromNaN32ps(t *testing.T) {
for i, v := range wantF32toF16bits {
f16 := float16.Fromfloat32(v.in)
u16 := uint16(f16)
if u16 != v.out {
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
}
checkFromNaN32ps(t, v.in, f16)
}
// since checkFromNaN32ps rejects non-NaN input, try one here
nan, err := float16.FromNaN32ps(float32(math.Pi))
if err != float16.ErrInvalidNaNValue {
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted err float16.ErrInvalidNaNValue, got err = %q", err)
}
if err.Error() != "float16: invalid NaN value, expected IEEE 754 NaN" {
t.Errorf("unexpected string value returned by err.Error() for ErrInvalidNaNValue: %s", err.Error())
}
if uint16(nan) != 0x7c01 { // signaling NaN
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted nan = 0x7c01, got nan = 0x%04x", uint16(nan))
}
}
// Test a small subset of possible conversions from float32 to Float16.
// TestSomeFromFloat32 runs in under 1 second while TestAllFromFloat32 takes about 45 seconds.
func TestSomeFromFloat32(t *testing.T) {
for i, v := range wantF32toF16bits {
f16 := float16.Fromfloat32(v.in)
u16 := uint16(f16)
if u16 != v.out {
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
}
}
}
// Test all possible 4294967296 float32 input values and results for
// Fromfloat32(), FromNaN32ps(), and PrecisionFromfloat32().
func TestAllFromFloat32(t *testing.T) {
if testing.Short() {
t.Skip("skipping TestAllFromFloat32 in short mode.")
}
fmt.Printf("WARNING: TestAllFromFloat32 should take about 1-2 minutes to run on amd64, other platforms may take longer...\n")
// Blake2b is "3f310bc5608a087462d361644fe66feeb4c68145f6f18eb6f1439cd7914888b6df9e30ae5350dce0635162cc6a2f23b31b3e4353ca132a3c552bdbd58baa54e6"
const wantSHA512 = "08670429a475164d6c4a080969e35231c77ef7069b430b5f38af22e013796b7818bbe8f5942a6ddf26de0e1dfc67d02243f483d85729ebc3762fc2948a5ca1f8"
const batchSize uint32 = 16384
results := make([]uint16, batchSize)
buf := new(bytes.Buffer)
h := sha512.New()
for i := uint64(0); i < uint64(0xFFFFFFFF); i += uint64(batchSize) {
// fill results
for j := uint32(0); j < batchSize; j++ {
inF32 := math.Float32frombits(uint32(i) + j)
f16 := float16.Fromfloat32(inF32)
results[j] = uint16(f16)
checkPrecision(t, inF32, f16, i)
checkFromNaN32ps(t, inF32, f16)
}
// convert results to []byte
err := binary.Write(buf, binary.LittleEndian, results)
if err != nil {
panic(err)
}
// update hash with []byte of results
_, err = h.Write(buf.Bytes())
if err != nil {
panic(err)
}
buf.Reset()
}
// display hash digest in hex
digest := h.Sum(nil)
gotSHA512hex := hex.EncodeToString(digest)
if gotSHA512hex != wantSHA512 {
t.Errorf("gotSHA512hex = %s", gotSHA512hex)
}
}
// Test all 65536 conversions from float16 to float32.
// TestAllToFloat32 runs in under 1 second.
func TestAllToFloat32(t *testing.T) {
// Blake2b is "078d8e3fac9480de1493f22c8f9bfc1eb2051537c536f00f621557d70eed1af057a487c3e252f6d593769f5288d5ab66d8e9cd1adba359838802944bdb731f4d"
const wantSHA512 = "1a4ccec9fd7b6e83310c6b4958a25778cd95f8d4f88b19950e4b8d6932a955f7fbd96b1c9bd9b2a79c3a9d34d653f55e671f8f86e6a5a876660cd38479001aa6"
const batchSize uint32 = 16384
results := make([]float32, batchSize)
buf := new(bytes.Buffer)
h := sha512.New()
for i := uint64(0); i < uint64(0xFFFF); i += uint64(batchSize) {
// fill results
for j := uint32(0); j < batchSize; j++ {
inU16 := uint16(i) + uint16(j)
f16 := float16.Float16(inU16)
results[j] = f16.Float32()
}
// convert results to []byte
err := binary.Write(buf, binary.LittleEndian, results)
if err != nil {
panic(err)
}
// update hash with []byte of results
_, err = h.Write(buf.Bytes())
if err != nil {
panic(err)
}
buf.Reset()
}
// display hash digest in hex
digest := h.Sum(nil)
gotSHA512hex := hex.EncodeToString(digest)
if gotSHA512hex != wantSHA512 {
t.Errorf("Float16toFloat32: gotSHA512hex = %s", gotSHA512hex)
}
}
func TestFrombits(t *testing.T) {
x := uint16(0x1234)
f16 := float16.Frombits(x)
if uint16(f16) != f16.Bits() || uint16(f16) != x {
t.Errorf("float16.Frombits(0x7fff) returned %04x, wanted %04x", uint16(f16), x)
}
}
func TestNaN(t *testing.T) {
nan := float16.NaN()
if !nan.IsNaN() {
t.Errorf("nan.IsNaN() returned false, wanted true")
}
}
func TestInf(t *testing.T) {
posInf := float16.Inf(0)
if uint16(posInf) != 0x7c00 {
t.Errorf("float16.Inf(0) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
}
posInf = float16.Inf(1)
if uint16(posInf) != 0x7c00 {
t.Errorf("float16.Inf(1) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
}
negInf := float16.Inf(-1)
if uint16(negInf) != 0xfc00 {
t.Errorf("float16.Inf(-1) returned %04x, wanted %04x", uint16(negInf), 0xfc00)
}
}
func TestBits(t *testing.T) {
x := uint16(0x1234)
f16 := float16.Frombits(x)
if uint16(f16) != f16.Bits() || f16.Bits() != x {
t.Errorf("Bits() returned %04x, wanted %04x", uint16(f16), x)
}
}
func TestIsFinite(t *testing.T) {
// IsFinite returns true if f is neither infinite nor NaN.
finite := float16.Fromfloat32(float32(1.5))
if !finite.IsFinite() {
t.Errorf("finite.IsFinite() returned false, wanted true")
}
posInf := float16.Inf(0)
if posInf.IsFinite() {
t.Errorf("posInf.IsFinite() returned true, wanted false")
}
negInf := float16.Inf(-1)
if negInf.IsFinite() {
t.Errorf("negInf.IsFinite() returned true, wanted false")
}
nan := float16.NaN()
if nan.IsFinite() {
t.Errorf("nan.IsFinite() returned true, wanted false")
}
}
func TestIsNaN(t *testing.T) {
f16 := float16.Float16(0)
if f16.IsNaN() {
t.Errorf("Float16(0).IsNaN() returned true, wanted false")
}
f16 = float16.Float16(0x7e00)
if !f16.IsNaN() {
t.Errorf("Float16(0x7e00).IsNaN() returned false, wanted true")
}
}
func TestIsQuietNaN(t *testing.T) {
f16 := float16.Float16(0)
if f16.IsQuietNaN() {
t.Errorf("Float16(0).IsQuietNaN() returned true, wanted false")
}
f16 = float16.Float16(0x7e00)
if !f16.IsQuietNaN() {
t.Errorf("Float16(0x7e00).IsQuietNaN() returned false, wanted true")
}
f16 = float16.Float16(0x7e00 ^ 0x0200) // clearing the quiet bit here yields 0x7c00, which is +Inf, not a NaN
if f16.IsQuietNaN() {
t.Errorf("Float16(0x7e00 ^ 0x0200).IsQuietNaN() returned true, wanted false")
}
}
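The magic numbers in the NaN tests above follow from the IEEE 754 binary16 layout: an all-ones exponent plus a non-zero significand is a NaN, and the top significand bit marks it as quiet. A minimal reading-aid sketch (not part of the test suite; the constant names are made up here):

```go
package main

import "fmt"

func main() {
	const (
		expMask  = uint16(0x7c00) // 5 exponent bits, all ones => Inf or NaN
		quietBit = uint16(0x0200) // MSB of the 10-bit significand => quiet NaN
	)
	qnan := expMask | quietBit       // 0x7e00, the quiet NaN tested above
	snan := expMask | uint16(0x0001) // 0x7c01, a signaling NaN with payload 1
	inf := expMask                   // 0x7c00, all-zero significand => +Inf
	fmt.Printf("qnan=0x%04x snan=0x%04x +inf=0x%04x\n", qnan, snan, inf)
}
```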
func TestIsNormal(t *testing.T) {
// IsNormal returns true if f is neither zero, infinite, subnormal, nor NaN.
zero := float16.Frombits(0)
if zero.IsNormal() {
t.Errorf("zero.IsNormal() returned true, wanted false")
}
posInf := float16.Inf(0)
if posInf.IsNormal() {
t.Errorf("posInf.IsNormal() returned true, wanted false")
}
negInf := float16.Inf(-1)
if negInf.IsNormal() {
t.Errorf("negInf.IsNormal() returned true, wanted false")
}
nan := float16.NaN()
if nan.IsNormal() {
t.Errorf("nan.IsNormal() returned true, wanted false")
}
subnormal := float16.Frombits(0x0001)
if subnormal.IsNormal() {
t.Errorf("subnormal.IsNormal() returned true, wanted false")
}
normal := float16.Fromfloat32(float32(1.5))
if !normal.IsNormal() {
t.Errorf("normal.IsNormal() returned false, wanted true")
}
}
func TestSignbit(t *testing.T) {
f16 := float16.Fromfloat32(float32(0.0))
if f16.Signbit() {
t.Errorf("float16.Fromfloat32(float32(0)).Signbit() returned true, wanted false")
}
f16 = float16.Fromfloat32(float32(2.0))
if f16.Signbit() {
t.Errorf("float16.Fromfloat32(float32(2)).Signbit() returned true, wanted false")
}
f16 = float16.Fromfloat32(float32(-2.0))
if !f16.Signbit() {
t.Errorf("float16.Fromfloat32(float32(-2)).Signbit() returned false, wanted true")
}
}
func TestString(t *testing.T) {
f16 := float16.Fromfloat32(1.5)
s := f16.String()
if s != "1.5" {
t.Errorf("Float16(1.5).String() returned %s, wanted 1.5", s)
}
f16 = float16.Fromfloat32(3.141593)
s = f16.String()
if s != "3.140625" {
t.Errorf("Float16(3.141593).String() returned %s, wanted 3.140625", s)
}
}
func TestIsInf(t *testing.T) {
f16 := float16.Float16(0)
if f16.IsInf(0) {
t.Errorf("Float16(0).IsInf(0) returned true, wanted false")
}
f16 = float16.Float16(0x7c00)
if !f16.IsInf(0) {
t.Errorf("Float16(0x7c00).IsInf(0) returned false, wanted true")
}
f16 = float16.Float16(0x7c00)
if !f16.IsInf(1) {
t.Errorf("Float16(0x7c00).IsInf(1) returned false, wanted true")
}
f16 = float16.Float16(0x7c00)
if f16.IsInf(-1) {
t.Errorf("Float16(0x7c00).IsInf(-1) returned true, wanted false")
}
f16 = float16.Float16(0xfc00)
if !f16.IsInf(0) {
t.Errorf("Float16(0xfc00).IsInf(0) returned false, wanted true")
}
f16 = float16.Float16(0xfc00)
if f16.IsInf(1) {
t.Errorf("Float16(0xfc00).IsInf(1) returned true, wanted false")
}
f16 = float16.Float16(0xfc00)
if !f16.IsInf(-1) {
t.Errorf("Float16(0xfc00).IsInf(-1) returned false, wanted true")
}
}
func float32parts(f32 float32) (exp int32, coef uint32, dropped uint32) {
const COEFMASK uint32 = 0x7fffff // 23 least significant bits
const EXPSHIFT uint32 = 23
const EXPBIAS uint32 = 127
const EXPMASK uint32 = uint32(0xff) << EXPSHIFT
const DROPMASK uint32 = COEFMASK >> 10
u32 := math.Float32bits(f32)
exp = int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS)
coef = u32 & COEFMASK
dropped = coef & DROPMASK
return exp, coef, dropped
}
func isNaN32(f32 float32) bool {
exp, coef, _ := float32parts(f32)
return (exp == 128) && (coef != 0)
}
func isQuietNaN32(f32 float32) bool {
exp, coef, _ := float32parts(f32)
return (exp == 128) && (coef != 0) && ((coef & 0x00400000) != 0)
}
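For a concrete instance of this decomposition, here is a small sketch (constants inlined from float32parts above) applied to math.Pi, the PrecisionInexact case exercised earlier:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	u32 := math.Float32bits(float32(math.Pi)) // 0x40490fdb
	exp := int32((u32>>23)&0xff) - 127        // 1
	coef := u32 & 0x7fffff                    // 0x490fdb
	dropped := coef & (0x7fffff >> 10)        // 0x1fdb: the 13 low bits float16 cannot keep
	fmt.Printf("exp=%d coef=0x%06x dropped=0x%04x\n", exp, coef, dropped)
}
```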
func checkFromNaN32ps(t *testing.T, f32 float32, f16 float16.Float16) {
if !isNaN32(f32) {
return
}
u32 := math.Float32bits(f32)
nan16, err := float16.FromNaN32ps(f32)
if isQuietNaN32(f32) {
// result should be the same
if err != nil {
t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
}
if uint16(nan16) != uint16(f16) {
t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted nan16 = %v, got nan16 = %v", u32, f32, f16, nan16)
}
} else {
// result should differ only by the signaling/quiet bit unless payload is empty
if err != nil {
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
}
coef := uint16(f16) & uint16(0x03ff)
payload := uint16(f16) & uint16(0x01ff)
diff := uint16(nan16 ^ f16)
if payload == 0 {
// the lowest bit needed to be set to prevent turning sNaN into infinity, so 2 bits differ
if diff != 0x0201 {
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0201, got 0x%04x", u32, f32, diff)
}
} else {
// only the quiet bit was restored, so 1 bit differs
if diff != 0x0200 {
t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0200, got 0x%04x. f16=0x%04x n16=0x%04x coef=0x%04x", u32, f32, diff, uint16(f16), uint16(nan16), coef)
}
}
}
}
func checkPrecision(t *testing.T, f32 float32, f16 float16.Float16, i uint64) {
// TODO: rewrite this test when time allows
u32 := math.Float32bits(f32)
u16 := f16.Bits()
f32bis := f16.Float32()
u32bis := math.Float32bits(f32bis)
pre := float16.PrecisionFromfloat32(f32)
roundtripped := u32 == u32bis
exp32, coef32, dropped32 := float32parts(f32)
if roundtripped {
checkRoundTrippedPrecision(t, u32, u16, u32bis, exp32, coef32, dropped32)
return
}
if pre == float16.PrecisionExact {
// this should only happen if both input and output are NaN
if !(f16.IsNaN() && isNaN32(f32)) {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionExact when roundtrip failed with non-special value", i, u32, f32, u16, u32bis, f32bis)
}
} else if pre == float16.PrecisionUnknown {
if exp32 < -24 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionUnderflow", i, u32, f32, u16, u32bis, f32bis)
}
if dropped32 != 0 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionInexact", i, u32, f32, u16, u32bis, f32bis)
}
} else if pre == float16.PrecisionInexact {
checkPrecisionInexact(t, u32, u16, u32bis, exp32, coef32, dropped32)
} else if pre == float16.PrecisionUnderflow {
if exp32 >= -14 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnderflow when exp32 is >= -14", i, u32, f32, u16, u32bis, f32bis)
}
} else if pre == float16.PrecisionOverflow {
if exp32 <= 15 {
t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionOverflow when exp32 is <= 15", i, u32, f32, u16, u32bis, f32bis)
}
}
}
func checkPrecisionInexact(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
f32 := math.Float32frombits(u32)
f32bis := math.Float32frombits(u32bis)
if exp32 < -24 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionUnderflow", u32, f32, u16, u32bis, f32bis)
}
if exp32 > 15 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionOverflow", u32, f32, u16, u32bis, f32bis)
}
if coef32 == 0 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when coef32 is 0", u32, f32, u16, u32bis, f32bis)
}
if dropped32 == 0 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when dropped32 is 0", u32, f32, u16, u32bis, f32bis)
}
}
func checkRoundTrippedPrecision(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
f32 := math.Float32frombits(u32)
f32bis := math.Float32frombits(u32bis)
pre := float16.PrecisionFromfloat32(f32)
f16 := float16.Frombits(u16)
if dropped32 != 0 {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), dropped32 != 0 with successful roundtrip", u32, f32, u16, u32bis, f32bis)
}
if pre != float16.PrecisionExact {
// there are 2046 values that are subnormal and can round-trip float32->float16->float32
if pre != float16.PrecisionUnknown {
t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%032b) (%f), out f16bits=0x%04x (%v), back=0x%08x (%f), got %v, wanted PrecisionExact, exp=%d, coef=%d, drpd=%d", u32, u32, f32, u16, f16, u32bis, f32bis, pre, exp32, coef32, dropped32)
}
}
}
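The "2046 values" remark above can be checked by brute force: every non-zero float16 subnormal (1023 per sign) is exactly representable in float32 and rounds back to the same bits. A hypothetical extra test, not part of this commit:

```go
func TestSubnormalRoundTripCount(t *testing.T) {
	count := 0
	for bits := uint32(1); bits <= 0xffff; bits++ {
		b := uint16(bits)
		if b&0x7c00 != 0 || b&0x03ff == 0 {
			continue // keep only non-zero subnormals (zero exponent field)
		}
		f32 := float16.Frombits(b).Float32()
		if float16.Fromfloat32(f32).Bits() == b {
			count++
		}
	}
	if count != 2046 {
		t.Errorf("subnormal round-trips = %d, wanted 2046", count)
	}
}
```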

12
init.go
View File

@ -4,11 +4,14 @@ import (
"fmt"
"log"
"os"
"strconv"
)
var (
CachedDir string = "NOT_SETTING"
gotchEnvKey string = "GOTCH_CACHE"
CachedDir string = "NOT_SETTING"
gotchEnvKey string = "GOTCH_CACHE"
gotchDebugKey string = "GOTCH_DEBUG"
Debug bool = false
)
func init() {
@ -16,10 +19,13 @@ func init() {
CachedDir = fmt.Sprintf("%s/.cache/gotch", homeDir) // default dir: "{$HOME}/.cache/gotch"
initEnv()
// log.Printf("INFO: CacheDir=%q\n", CacheDir)
}
func initEnv() {
if v, err := strconv.ParseBool(os.Getenv(gotchDebugKey)); err == nil {
Debug = v
}
val := os.Getenv(gotchEnvKey)
if val != "" {
CachedDir = val

File diff suppressed because it is too large

View File

@ -4,10 +4,25 @@ package libtch
//#include "stdbool.h"
//#include "torch_api.h"
/*
bool is_null(int* pointer) {
if (NULL == pointer) {
return true;
}
return false;
}
*/
import "C"
import "unsafe"
func IsNull(ctensor Ctensor) bool {
// return C.is_null(ctensor)
ret := C.is_null((*C.int)(unsafe.Pointer(ctensor)))
return (bool)(ret)
}
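A hypothetical guard built on the new IsNull helper (mustNewTensor is made up here; NewTensor is the at_new_tensor wrapper shown below):

```go
func mustNewTensor() Ctensor {
	t := NewTensor()
	if IsNull(t) {
		panic("libtch: at_new_tensor returned a null tensor")
	}
	return t
}
```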
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);

View File

@ -59,12 +59,6 @@ func NewTensor() Ctensor {
return C.at_new_tensor()
}
// int at_device(tensor);
func AtDevice(ts Ctensor) int {
cint := C.at_device(ts)
return *(*int)(unsafe.Pointer(&cint))
}
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
@ -88,6 +82,30 @@ func AtDataPtr(t Ctensor) unsafe.Pointer {
return C.at_data_ptr(t)
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
retVal := C.at_defined(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_mkldnn(tensor);
func AtIsMkldnn(ts Ctensor) bool {
retVal := C.at_is_mkldnn(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
retVal := C.at_is_sparse(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_device(tensor);
func AtDevice(ts Ctensor) int {
cint := C.at_device(ts)
return *(*int)(unsafe.Pointer(&cint))
}
// size_t at_dim(tensor);
func AtDim(t Ctensor) uint64 {
result := C.at_dim(t)
@ -100,12 +118,24 @@ func AtShape(t Ctensor, ptr unsafe.Pointer) {
C.at_shape(t, c_ptr)
}
// void at_stride(tensor, int64_t *);
func AtStride(t Ctensor, ptr unsafe.Pointer) {
c_ptr := (*C.int64_t)(ptr)
C.at_stride(t, c_ptr)
}
// int at_scalar_type(tensor);
func AtScalarType(t Ctensor) int32 {
result := C.at_scalar_type(t)
return *(*int32)(unsafe.Pointer(&result))
}
// int at_is_contiguous(tensor);
func AtIsContiguous(ts Ctensor) bool {
retVal := C.at_is_contiguous(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
func GetAndResetLastErr() *C.char {
return C.get_and_reset_last_err()
}
@ -134,6 +164,24 @@ func AtcSetBenchmarkCudnn(b int) {
C.atc_set_benchmark_cudnn(cb)
}
// void atc_synchronize(int64_t device_index);
func AtcSynchronize(deviceIndex int64) {
cDeviceIndex := *(*C.int64_t)(unsafe.Pointer(&deviceIndex))
C.atc_synchronize(cDeviceIndex)
}
// int atc_get_device();
func AtcGetDevice() int {
cDeviceIndex := C.atc_get_device()
return int(cDeviceIndex)
}
// void atc_set_device(int device_index);
func AtcSetDevice(deviceIndex int) int {
cDeviceIndex := C.int(deviceIndex)
C.atc_set_device(cDeviceIndex)
return int(cDeviceIndex)
}
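Taken together, a hypothetical round trip through the new CUDA helpers (assuming a CUDA-enabled build; pinAndSync is made up here):

```go
func pinAndSync(index int) {
	prev := AtcGetDevice()       // current CUDA device index
	AtcSetDevice(index)          // switch to the requested device
	AtcSynchronize(int64(index)) // block until pending work on it finishes
	AtcSetDevice(prev)           // restore the previous device
}
```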
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
ctensor := (C.tensor)(ts)
@ -158,18 +206,6 @@ func AtRequiresGrad(ts Ctensor) bool {
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_defined(tensor);
func AtDefined(ts Ctensor) bool {
retVal := C.at_defined(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// int at_is_sparse(tensor);
func AtIsSparse(ts Ctensor) bool {
retVal := C.at_is_sparse(ts)
return *(*bool)(unsafe.Pointer(&retVal))
}
// void at_backward(tensor, int, int);
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
@ -364,7 +400,7 @@ func AtFree(ts Ctensor) {
C.at_free(ts)
}
//int at_grad_set_enabled(int b);
// int at_grad_set_enabled(int b);
func AtGradSetEnabled(b int) int {
cbool := *(*C.int)(unsafe.Pointer(&b))
cretVal := C.at_grad_set_enabled(cbool)
@ -870,3 +906,13 @@ func AtoConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int, v
cpadLen := *(*C.int)(unsafe.Pointer(&padLen))
C.ato_constant_pad_nd(ptr, self, cpadDataPtr, cpadLen, value)
}
// // NOTE. TT. added to test new API generated
// func AtgRandn1(sizeData []int64, sizeLen int, optionsKind int32, optionsDevice int32) Ctensor {
// csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
// csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
// coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
// coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
//
// return C.atg_randn1(csizeDataPtr, csizeLen, coptionsKind, coptionsDevice)
// }

View File

@ -1,14 +1,14 @@
#include<torch/csrc/autograd/engine.h>
#include<torch/csrc/jit/runtime/graph_executor.h>
#include "torch_api.h"
#include "ATen/core/interned_strings.h"
#include <ATen/autocast_mode.h>
#include <stdexcept>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include<torch/torch.h>
#include<ATen/autocast_mode.h>
#include<torch/script.h>
#include<stdexcept>
#include<vector>
#include "torch_api.h"
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <vector>
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
@ -36,15 +36,20 @@ vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len) {
return result;
}
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs, int len) {
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
int len) {
vector<c10::optional<torch::Tensor>> result;
for (int i = 0; i < len; ++i) {
result.push_back(vs[i] != nullptr ? c10::optional<torch::Tensor>(*(vs[i])) : c10::nullopt);
result.push_back(vs[i] != nullptr ? c10::optional<torch::Tensor>(*(vs[i]))
: c10::nullopt);
}
return c10::List<c10::optional<torch::Tensor>>(result);
}
at::Device device_of_int(int d) {
if (d == -3)
return at::Device(at::kVulkan);
// if (d == -2) return at::Device(at::kMPS);
if (d < 0)
return at::Device(at::kCPU);
return at::Device(at::kCUDA, /*index=*/d);
@ -81,19 +86,20 @@ tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
void at_copy_data(tensor tensor, void *vs, size_t numel,
size_t elt_size_in_bytes) {
PROTECT(if (elt_size_in_bytes != tensor->element_size()) throw std::
invalid_argument("incoherent element sizes in bytes");
if (numel > tensor->numel()) throw std::invalid_argument(
"target numel is larger than tensor numel");
if (tensor->device().type() != at::kCPU) {
torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
void *tensor_data = tmp_tensor.data_ptr();
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
} else {
auto tmp_tensor = tensor->contiguous();
void *tensor_data = tmp_tensor.data_ptr();
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
})
PROTECT(
if (elt_size_in_bytes != tensor->element_size()) throw std::
invalid_argument("incoherent element sizes in bytes");
if (numel > tensor->numel()) throw std::invalid_argument(
"target numel is larger than tensor numel");
if (tensor->device().type() != at::kCPU) {
torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
void *tensor_data = tmp_tensor.data_ptr();
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
} else {
auto tmp_tensor = tensor->contiguous();
void *tensor_data = tmp_tensor.data_ptr();
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
})
}
tensor at_shallow_clone(tensor t) {
@ -139,15 +145,20 @@ int at_scalar_type(tensor t) {
return -1;
}
int at_is_contiguous(tensor t) {
PROTECT(return t->is_contiguous();)
return -1;
}
// void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor
// inf_scale) { PROTECT( at::_amp_non_finite_check_and_unscale_(*t, *found_inf,
// *inf_scale);
// )
// }
void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor inf_scale) {
PROTECT(
at::_amp_foreach_non_finite_check_and_unscale_(*t, *found_inf, *inf_scale);
)
void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf,
tensor inf_scale) {
PROTECT(at::_amp_foreach_non_finite_check_and_unscale_(*t, *found_inf,
*inf_scale);)
}
void at_autocast_clear_cache() { at::autocast::clear_cache(); }
@ -176,7 +187,7 @@ bool at_autocast_set_enabled(bool b) {
int at_device(tensor t) {
PROTECT(auto device = t->device(); if (device.type() == at::kCPU) return -1;
if (device.type() == at::kCUDA) return device.index();)
return -2;
return -99; // error
}
void at_backward(tensor t, int keep_graph, int create_graph) {
@ -407,13 +418,13 @@ void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
for (int i = 0; i < ntensors; ++i)
grads.push_back(torch::ones_like(*tensors[i]));
auto vl = torch::autograd::Engine::get_default_engine().execute(roots, grads, keep_graph, create_graph, false, inputs_);
auto vl = torch::autograd::Engine::get_default_engine().execute(
roots, grads, keep_graph, create_graph, false, inputs_);
for (int i = 0; i < ninputs; ++i) {
outputs[i] = static_cast<tensor>(new torch::autograd::Variable(vl[i]));
})
}
optimizer ato_adam(double learning_rate, double beta1, double beta2,
double weight_decay) {
PROTECT(auto options = torch::optim::AdamOptions(learning_rate)
@ -510,11 +521,11 @@ void ato_set_learning_rate_group(optimizer t, size_t group,
set_lr_group<torch::optim::SGDOptions>(t, group, learning_rate);)
}
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data, int pad_len, scalar value) {
PROTECT(
auto outputs__ = torch::constant_pad_nd(*self, torch::IntArrayRef(pad_data, pad_len), *value);
out__[0] = new torch::Tensor(outputs__);
)
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data,
int pad_len, scalar value) {
PROTECT(auto outputs__ = torch::constant_pad_nd(
*self, torch::IntArrayRef(pad_data, pad_len), *value);
out__[0] = new torch::Tensor(outputs__);)
}
// ============ set/get learning rates ==============================
@ -536,15 +547,17 @@ template <class T> void set_lrs(optimizer t, double *learning_rates) {
}
void ato_set_learning_rates(optimizer t, double *lrs, int lrs_num) {
PROTECT(int ngroup = t->param_groups().size(); if (lrs == nullptr) {
throw std::invalid_argument("Input learning rates should not be null");
} if (ngroup != lrs_num) {
throw std::invalid_argument("Size of input learning rates is unequal to "
"number of parameter groups.");
} set_lrs<torch::optim::AdamOptions>(t, lrs);
set_lrs<torch::optim::AdamWOptions>(t, lrs);
set_lrs<torch::optim::RMSpropOptions>(t, lrs);
set_lrs<torch::optim::SGDOptions>(t, lrs);)
PROTECT(
int ngroup = t->param_groups().size(); if (lrs == nullptr) {
throw std::invalid_argument("Input learning rates should not be null");
} if (ngroup != lrs_num) {
throw std::invalid_argument(
"Size of input learning rates is unequal to "
"number of parameter groups.");
} set_lrs<torch::optim::AdamOptions>(t, lrs);
set_lrs<torch::optim::AdamWOptions>(t, lrs);
set_lrs<torch::optim::RMSpropOptions>(t, lrs);
set_lrs<torch::optim::SGDOptions>(t, lrs);)
}
template <class T> void get_lrs(optimizer t, vector<double> &lrs) {
@ -753,6 +766,26 @@ void atc_set_benchmark_cudnn(int b) {
at::globalContext().setBenchmarkCuDNN(b);
}
void atc_synchronize(int64_t device_index) {
PROTECT(return torch::cuda::synchronize(device_index);)
}
// returns current CUDA device index.
int atc_get_device() {
PROTECT(at::Device d(at::kCUDA);
auto *g = c10::impl::getDeviceGuardImpl(d.type()); d = g->getDevice();
return d.index();)
return -99; // error
}
// set new cuda device with input device index.
void atc_set_device(int device_index) {
PROTECT(at::Device new_device(at::kCUDA);
new_device = device_of_int(device_index);
auto *g = c10::impl::getDeviceGuardImpl(new_device.type());
g->setDevice(new_device);)
}
module atm_load(char *filename) {
PROTECT(return new torch::jit::script::Module(torch::jit::load(filename));)
return nullptr;
@ -1007,13 +1040,14 @@ void ati_to_generic_list(ivalue i, ivalue *outputs, int noutputs) {
}
void ati_to_generic_dict(ivalue i, ivalue *outputs, int noutputs) {
PROTECT(auto dict = i->toGenericDict(); if (dict.size() != noutputs) {
throw std::invalid_argument("unexpected dict size");
} int k = 0;
for (auto it = dict.begin(); it != dict.end(); ++it) {
outputs[k++] = new torch::jit::IValue(it->key());
outputs[k++] = new torch::jit::IValue(it->value());
})
PROTECT(
auto dict = i->toGenericDict(); if (dict.size() != noutputs) {
throw std::invalid_argument("unexpected dict size");
} int k = 0;
for (auto it = dict.begin(); it != dict.end(); ++it) {
outputs[k++] = new torch::jit::IValue(it->key());
outputs[k++] = new torch::jit::IValue(it->value());
})
}
void ati_to_int_list(ivalue i, int64_t *outputs, int noutputs) {

View File

@ -3,6 +3,10 @@
#include <stdint.h>
#ifdef __cplusplus
#include <stdexcept>
#include <torch/torch.h>
using namespace std;
thread_local char *torch_last_err = nullptr;
extern "C" {
@ -46,6 +50,7 @@ size_t at_dim(tensor);
void at_shape(tensor, int64_t *);
void at_stride(tensor, int64_t *);
int at_scalar_type(tensor);
int at_is_contiguous(tensor);
void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);
@ -105,13 +110,8 @@ void at_set_num_threads(int n_threads);
void at_free(tensor);
void at_run_backward(tensor *tensors,
int ntensors,
tensor *inputs,
int ninputs,
tensor *outputs,
int keep_graph,
int create_graph);
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
tensor *outputs, int keep_graph, int create_graph);
optimizer ato_adam(double learning_rate, double beta1, double beta2,
double weight_decay);
@ -141,8 +141,10 @@ int64_t ato_param_group_num(optimizer);
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
void ato_add_param_group(optimizer, tensor *params, int param_num);
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no option of adding pad value.
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len, scalar value);
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no
// option of adding pad value.
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len,
scalar value);
scalar ats_int(int64_t);
scalar ats_float(double);
@ -155,6 +157,12 @@ int atc_cuda_device_count();
int atc_cuda_is_available();
int atc_cudnn_is_available();
void atc_set_benchmark_cudnn(int b);
void atc_synchronize(int64_t device_index);
// TT. added for testing qt
// ref. https://github.com/pytorch/pytorch/issues/14959
int atc_get_device();
void atc_set_device(int device_index);
module atm_load(char *);
module atm_load_on_device(char *, int device);
@ -212,6 +220,11 @@ void ati_free(ivalue);
#ifdef __cplusplus
}; // extern "C"
#endif
std::vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len);
at::Device device_of_int(int d);
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
int len);
#endif
#endif

File diff suppressed because it is too large

File diff suppressed because it is too large

146
mem-util.go Normal file
View File

@ -0,0 +1,146 @@
package gotch
// helper to debug memory blow-up
import (
"fmt"
"os"
"runtime"
"strings"
"text/tabwriter"
)
func PrintMemStats(messageOpt ...string) {
message := "Memory Stats"
if len(messageOpt) > 0 {
message = fmt.Sprintf("%s: %s", message, messageOpt[0])
}
var rtm runtime.MemStats
runtime.ReadMemStats(&rtm)
tp := newTablePrinter()
tp.title = message
tp.AddRecord("|", "Allocated heap objects", padRight(fmt.Sprintf("%v", rtm.Mallocs), 10), "|")
tp.AddRecord("|", "Released heap objects", padRight(fmt.Sprintf("%v", rtm.Frees), 10), "|")
tp.AddRecord("|", "Living heap objects", padRight(fmt.Sprintf("%v", rtm.HeapObjects), 10), "|")
tp.AddRecord("|", "Memory in use by heap objects (bytes)", padRight(fmt.Sprintf("%v", rtm.HeapAlloc), 10), "|")
tp.AddRecord("|", "Reserved memory (by Go runtime for heap, stack,...) (bytes)", padRight(fmt.Sprintf("%v", rtm.Sys), 10), "|")
tp.AddRecord("|", "Total pause time by GC (nanoseconds)", padRight(fmt.Sprintf("%v", rtm.PauseTotalNs), 10), "|")
tp.AddRecord("|", "Number of GC called", padRight(fmt.Sprintf("%v", rtm.NumGC), 10), "|")
// tp.AddRecord("Last GC called", fmt.Sprintf("%v", time.UnixMilli(int64(rtm.LastGC/1_000_000))))
tp.Print()
}
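A minimal usage sketch for the helper above; the output is the bordered table rendered by tablePrinter:

```go
package main

import "github.com/sugarme/gotch"

func main() {
	gotch.PrintMemStats("before workload")
	// ... allocate and drop tensors here ...
	gotch.PrintMemStats("after cleanup")
}
```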
type tablePrinter struct {
w *tabwriter.Writer
maxLength int
title string
}
type printItem struct {
val string
alignRight bool
}
func item(val string, alignRightOpt ...bool) printItem {
alignRight := false
if len(alignRightOpt) > 0 {
alignRight = alignRightOpt[0]
}
return printItem{
val: val,
alignRight: alignRight,
}
}
func newTablePrinter() *tablePrinter {
w := tabwriter.NewWriter(
os.Stdout, //output
0, // min width
1, // tabwidth
2, // padding
' ', // padding character
0, // align left
)
return &tablePrinter{
w: w,
maxLength: 0,
}
}
func (tp *tablePrinter) AddRecord(items ...string) {
tp.printRecord(items...)
}
func (tp *tablePrinter) AlignRight() {
tp.w.Init(
os.Stdout, //output
0, // min width
1, // tabwidth
2, // padding
' ', // padding character
tabwriter.AlignRight,
) // flags
}
func (tp *tablePrinter) AlignLeft() {
tp.w.Init(
os.Stdout, //output
0, // min width
1, // tabwidth
2, // padding
' ', // padding character
0, // align left
) // flags
}
func (tp *tablePrinter) printRecord(rec ...string) {
var val string
for i, item := range rec {
switch i {
case 0:
val = item
case len(rec) - 1:
val += fmt.Sprintf("\t%s\n", item)
default:
val += fmt.Sprintf("\t%s", item)
}
}
nbytes, err := tp.w.Write([]byte(val))
if err != nil {
panic(err)
}
if nbytes > tp.maxLength {
tp.maxLength = nbytes
}
}
func (tp *tablePrinter) Print() {
printBorder(tp.maxLength)
printLine(tp.maxLength, tp.title)
printBorder(tp.maxLength)
tp.w.Flush()
printBorder(tp.maxLength)
}
func padRight(val interface{}, rightEnd int) string {
value := fmt.Sprintf("%v", val)
pad := strings.Repeat(" ", rightEnd-len(value))
return fmt.Sprintf("%s%s", pad, value)
}
func printLine(lineLength int, value string) {
fmt.Printf("| %s %s\n", value, padRight("|", lineLength-len(value)-1))
}
func printBorder(length int) {
line := strings.Repeat("-", length)
fmt.Printf("+%s+\n", line)
}

View File

@ -12,7 +12,7 @@ import (
type Init interface {
// creates a new tensor with specified initiation
InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)
// re-initializes (in-place) an existing tensor with the specified initiation
Set(tensor *ts.Tensor)
@ -25,18 +25,24 @@ type constInit struct {
value float64
}
var _ Init = new(constInit)
func NewConstInit(v float64) constInit {
return constInit{v}
}
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
var err error
kind := gotch.Float
switch {
case c.value == 0.0:
retVal = ts.MustZeros(dims, kind, device)
retVal = ts.MustZeros(dims, dtype, device)
case c.value == 1.0:
retVal = ts.MustOnes(dims, kind, device)
retVal = ts.MustOnes(dims, dtype, device)
default:
data := make([]float64, ts.FlattenDim(dims))
for i := range data {
@ -68,18 +74,29 @@ type randnInit struct {
stdev float64
}
var _ Init = new(randnInit)
func NewRandnInit(mean, stdev float64) randnInit {
return randnInit{mean, stdev}
}
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
if r.mean == 0 {
return ts.MustRandn(dims, gotch.Float, device)
func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
initTs := ts.MustRandn(dims, gotch.Float, device)
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
ts.NoGrad(func() {
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
if r.mean == 0 {
retVal = ts.MustRandn(dims, dtype, device)
return // without this early return, retVal is recomputed below and the first tensor leaks
}
initTs := ts.MustRandn(dims, dtype, device)
retVal = initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
})
return retVal
}
func (r randnInit) Set(tensor *ts.Tensor) {
@ -88,9 +105,11 @@ func (r randnInit) Set(tensor *ts.Tensor) {
log.Fatalf("randInit - Set method call error: %v\n", err)
}
initTs := r.InitTensor(dims, tensor.MustDevice())
tensor.Copy_(initTs)
initTs.MustDrop()
ts.NoGrad(func() {
initTs := r.InitTensor(dims, tensor.MustDevice())
tensor.Copy_(initTs)
initTs.MustDrop()
})
}
// uniformInit :
@ -101,18 +120,26 @@ type uniformInit struct {
up float64
}
var _ Init = new(uniformInit)
func NewUniformInit(lo, up float64) uniformInit {
return uniformInit{lo, up}
}
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
var err error
kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device)
retVal.Uniform_(u.lo, u.up)
if err != nil {
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
var err error
ts.NoGrad(func() {
retVal = ts.MustZeros(dims, dtype, device)
retVal.Uniform_(u.lo, u.up)
if err != nil {
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
}
})
return retVal
}
@ -174,6 +201,8 @@ type kaimingUniformInit struct {
NonLinearity string
}
var _ Init = new(kaimingUniformInit)
func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
o := DefaultKaimingOptions()
for _, opt := range opts {
@ -187,26 +216,37 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
}
}
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
fanIn, _, err := CalculateFans(dims)
if err != nil {
panic(err)
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
if err != nil {
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
panic(err)
}
/*
fanIn, _, err := CalculateFans(dims)
if err != nil {
panic(err)
}
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
if err != nil {
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
panic(err)
}
// Calculate uniform bounds from standard deviation
bound := math.Sqrt(3.0) * std
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device)
retVal.Uniform_(-bound, bound)
// Calculate uniform bounds from standard deviation
bound := math.Sqrt(3.0) * std
// NOTE. This is a well-known memory leak!!!
// Avoid using it for now!!!
retVal = ts.MustZeros(dims, dtype, device)
retVal.Uniform_(-bound, bound)
*/
// NOTE. For now, just fall back to a random normal initialization.
retVal = ts.MustRandn(dims, dtype, device)
return retVal
}
@ -347,3 +387,36 @@ func contains(items []string, item string) bool {
}
return false
}
// XavierUniform_ fills the input tensor with values according to the method
// described in the paper `Understanding the difficulty of training deep feedforward neural networks`
// using a uniform distribution
//
// Also known as Glorot initialization.
//
// Paper: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
// Pytorch implementation: https://github.com/pytorch/pytorch/blob/df50f91571891ec3f87977a2bdd4a2b609d70afc/torch/nn/init.py#L310
func XavierUniform_(x *ts.Tensor, gainOpt ...float64) {
gain := 1.0
if len(gainOpt) > 0 {
gain = gainOpt[0]
}
size := x.MustSize()
dtype := x.DType()
device := x.MustDevice()
fanIn, fanOut, err := CalculateFans(size)
if err != nil {
panic(err)
}
std := gain * math.Sqrt(2.0/float64(fanIn+fanOut))
// calculate uniform bounds from standard deviation
a := math.Sqrt(3.0) * std
uniformInit := NewUniformInit(-a, a)
src := uniformInit.InitTensor(size, device, dtype)
x.Copy_(src)
src.MustDrop()
}
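A hypothetical call site for XavierUniform_ (newXavierWeight is made up here; the dims are arbitrary):

```go
func newXavierWeight(device gotch.Device) *ts.Tensor {
	w := ts.MustZeros([]int64{256, 128}, gotch.DefaultDType, device)
	XavierUniform_(w) // default gain = 1.0; pass e.g. 5.0/3 for tanh layers
	return w
}
```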

44
nn/init_test.go Normal file
View File

@ -0,0 +1,44 @@
package nn
import (
"fmt"
"testing"
"time"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)
// Test whether InitTensor() can cause memory blow-up due to accumulate gradient.
func TestInitTensor_Memcheck(t *testing.T) {
gotch.PrintMemStats("Start")
device := gotch.CPU
vs := NewVarStore(device)
params := 500
path := vs.Root()
dims := []int64{1024, 1024}
for i := 0; i < params; i++ {
ts.NoGrad(func() {
name := fmt.Sprintf("param_%v", i)
x := ts.MustRandn(dims, gotch.DefaultDType, device)
path.MustAdd(name, x, false)
x.MustDrop()
})
}
// vs.Summary()
fmt.Printf("vs created...\n")
// printMemStats("After varstore created")
vs.Destroy()
ts.CleanUp()
fmt.Printf("vs deleted...\n")
// printMemStats("After varstore deleted")
time.Sleep(time.Second * 10)
gotch.PrintMemStats("Final")
}

View File

@ -79,7 +79,7 @@ func (m *TrainableCModule) Save(file string) error {
// ForwardT implements ModuleT for TrainableCModule.
// NOTE: train parameter will not be used.
func (m *TrainableCModule) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
retVal, err := m.Inner.ForwardTs([]ts.Tensor{*x})
retVal, err := m.Inner.ForwardTs([]*ts.Tensor{x})
if err != nil {
log.Fatal(err)
}

View File

@ -6,7 +6,6 @@ import (
"fmt"
"math"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)
@ -22,6 +21,8 @@ type LinearConfig struct {
func DefaultLinearConfig() *LinearConfig {
negSlope := math.Sqrt(5)
return &LinearConfig{
// NOTE. KaimingUniform causes a mem leak due to ts.Uniform()!!!
// Avoid using it now.
WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
BsInit: nil,
Bias: true,
@ -41,11 +42,7 @@ type Linear struct {
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
var bs *ts.Tensor
// bs has size of output dimension
switch c.Bias {
case false:
bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
case true:
if c.Bias {
switch {
case c.BsInit == nil:
shape := []int64{inDim, outDim}
@ -65,8 +62,10 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
}
}
ws := vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false)
return &Linear{
Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
Ws: ws,
Bs: bs,
}
}
@ -87,29 +86,31 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
//
// Example:
//
// inDim := 3
// outDim := 2
// batchSize := 4
// weights: 2x3
// [ 1 1 1
// 1 1 1 ]
// inDim := 3
// outDim := 2
// batchSize := 4
// weights: 2x3
// [ 1 1 1
// 1 1 1 ]
//
// input node: 4x3 (batchSize x inDim)
// [ 1 1 1
// 1 1 1
// 1 1 1
// 1 1 1 ]
// input node: 4x3 (batchSize x inDim)
// [ 1 1 1
// 1 1 1
// 1 1 1
// 1 1 1 ]
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false)
return mul.MustAdd(l.Bs, true)
if l.Bs != nil {
return mul.MustAdd(l.Bs, true)
} else {
return mul
}
}
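With the nil check above, a bias-free layer becomes possible; a hedged sketch (newBiasFreeLinear is made up here):

```go
func newBiasFreeLinear(vs *Path, inDim, outDim int64) *Linear {
	cfg := DefaultLinearConfig()
	cfg.Bias = false // Forward then returns the bare matmul
	return NewLinear(vs, inDim, outDim, cfg)
}
```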
// ForwardT implements ModuleT interface for Linear layer.
//
// NOTE: train param will not be used.
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false)
if l.Bs != nil {
return mul.MustAdd(l.Bs, true)
}
return mul
}

View File

@ -1,7 +1,6 @@
package nn
import (
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)
@ -68,7 +67,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tenso
reduction := options.Reduction
ignoreIndex := options.IgnoreIndex
logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
logSm := logits.MustLogSoftmax(-1, dtype, false)
loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
ws.MustDrop()

View File

@ -7,7 +7,6 @@ import (
"log"
"math"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)
@ -387,12 +386,12 @@ func WithErrorIfNonFinite(v bool) ClipOpt {
}
}
/// Clips gradient L2 norm over all trainable parameters.
// / Clips gradient L2 norm over all trainable parameters.
//
// The norm is computed over all gradients together, as if they were
// concatenated into a single vector.
//
/// Args:
// / Args:
// - max: max norm of the gradient
// - o.NormType. Type of the used p-norm, can be "inf" for infinity norm. Default= 2.0
// - o.ErrorIfNonFinite bool. If true, throw error if total norm of the gradients from parameters is "nan", "inf" or "-inf". Default=false
@ -413,15 +412,19 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
}
var (
norms []ts.Tensor
norms []*ts.Tensor
totalNorm *ts.Tensor
)
device := opt.varstore.device
// FIXME. What about mixed-precision?
dtype := parameters[0].DType()
if o.NormType == math.Inf(1) {
for _, v := range opt.varstore.vars {
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
norms = append(norms, *n)
norms = append(norms, n)
}
// total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
totalNorm = ts.MustStack(norms, 0).MustMax(true)
@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
norms = append(norms, *x)
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
norms = append(norms, x)
}
}
// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
for _, x := range norms {
x.MustDrop()
}
@ -556,7 +559,7 @@ func (opt *Optimizer) ParamGroupNum() int {
return int(ngroup)
}
func (opt *Optimizer) AddParamGroup(tensors []ts.Tensor) {
func (opt *Optimizer) AddParamGroup(tensors []*ts.Tensor) {
err := opt.opt.AddParamGroup(tensors)
if err != nil {
log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)

View File

@ -74,7 +74,7 @@ func DefaultRNNConfig() *RNNConfig {
//
// https://en.wikipedia.org/wiki/Long_short-term_memory
type LSTM struct {
flatWeights []ts.Tensor
flatWeights []*ts.Tensor
hiddenDim int64
config *RNNConfig
device gotch.Device
@ -89,7 +89,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
}
gateDim := 4 * hiddenDim
flatWeights := make([]ts.Tensor, 0)
flatWeights := make([]*ts.Tensor, 0)
for i := 0; i < int(cfg.NumLayers); i++ {
if i != 0 {
@ -102,7 +102,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
case 2: // bi-directional
// forward
@ -110,14 +110,14 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
// reverse
wIhR := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
wHhR := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
bIhR := vs.MustZeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
bHhR := vs.MustZeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
flatWeights = append(flatWeights, wIhR, wHhR, bIhR, bHhR)
}
}
@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
layerDim := l.config.NumLayers * numDirections
shape := []int64{layerDim, batchDim, l.hiddenDim}
zeros := ts.MustZeros(shape, gotch.Float, l.device)
dtype := l.flatWeights[0].DType()
zeros := ts.MustZeros(shape, dtype, l.device)
retVal := &LSTMState{
Tensor1: zeros.MustShallowClone(),
@ -188,7 +190,7 @@ func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {
func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
output, h, c := input.MustLstm([]*ts.Tensor{inState.(*LSTMState).Tensor1, inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
return output, &LSTMState{
Tensor1: h,
@ -209,7 +211,7 @@ func (gs *GRUState) Value() *ts.Tensor {
//
// https://en.wikipedia.org/wiki/Gated_recurrent_unit
type GRU struct {
flatWeights []ts.Tensor
flatWeights []*ts.Tensor
hiddenDim int64
config *RNNConfig
device gotch.Device
@ -223,7 +225,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
}
gateDim := 3 * hiddenDim
flatWeights := make([]ts.Tensor, 0)
flatWeights := make([]*ts.Tensor, 0)
for i := 0; i < int(cfg.NumLayers); i++ {
for n := 0; n < int(numDirections); n++ {
@ -239,7 +241,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
bIh := vs.MustZeros("b_ih", []int64{gateDim})
bHh := vs.MustZeros("b_hh", []int64{gateDim})
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
}
}
@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
layerDim := g.config.NumLayers * numDirections
shape := []int64{layerDim, batchDim, g.hiddenDim}
tensor := ts.MustZeros(shape, gotch.Float, g.device)
dtype := g.flatWeights[0].DType()
tensor := ts.MustZeros(shape, dtype, g.device)
return &GRUState{Tensor: tensor}
}

View File

@ -46,7 +46,7 @@ func (s *Sequential) AddFn(fn ts.Module) {
}
// ForwardAll applies the forward pass and returns the output for each layer.
func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []ts.Tensor) {
func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []*ts.Tensor) {
var n uint8 = uint8(len(s.layers))
if len(opts) > 0 {
@ -54,11 +54,11 @@ func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []ts.Tenso
}
if s.IsEmpty() {
return []ts.Tensor{*xs.MustShallowClone()}
return []*ts.Tensor{xs.MustShallowClone()}
}
for i := 0; i < int(n); i++ {
retVal = append(retVal, *s.layers[i].Forward(xs))
retVal = append(retVal, s.layers[i].Forward(xs))
}
return retVal
@ -85,15 +85,15 @@ func (s *Sequential) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
}
// forward sequentially
outs := make([]ts.Tensor, len(s.layers))
outs := make([]*ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ {
if i == 0 {
outs[0] = *s.layers[i].Forward(xs)
outs[0] = s.layers[i].Forward(xs)
defer outs[0].MustDrop()
} else if i == len(s.layers)-1 {
return s.layers[i].Forward(&outs[i-1])
return s.layers[i].Forward(outs[i-1])
} else {
outs[i] = *s.layers[i].Forward(&outs[i-1])
outs[i] = s.layers[i].Forward(outs[i-1])
defer outs[i].MustDrop()
}
}
@ -106,7 +106,7 @@ type SequentialT struct {
layers []ts.ModuleT
}
/// SeqT creates a new empty sequential layer.
// / SeqT creates a new empty sequential layer.
func SeqT() *SequentialT {
return &SequentialT{
layers: make([]ts.ModuleT, 0),
@ -139,15 +139,15 @@ func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
}
// forward sequentially
outs := make([]ts.Tensor, len(s.layers))
outs := make([]*ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ {
if i == 0 {
outs[0] = *s.layers[i].ForwardT(xs, train)
outs[0] = s.layers[i].ForwardT(xs, train)
defer outs[0].MustDrop()
} else if i == len(s.layers)-1 {
return s.layers[i].ForwardT(&outs[i-1], train)
return s.layers[i].ForwardT(outs[i-1], train)
} else {
outs[i] = *s.layers[i].ForwardT(&outs[i-1], train)
outs[i] = s.layers[i].ForwardT(outs[i-1], train)
defer outs[i].MustDrop()
}
}
@ -179,7 +179,7 @@ func (s *SequentialT) AddFnT(fn ts.ModuleT) {
}
// ForwardAll applies the forward pass and returns the output for each layer.
func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []ts.Tensor) {
func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []*ts.Tensor) {
var n uint8 = uint8(len(s.layers))
if len(opts) > 0 {
@ -187,13 +187,13 @@ func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (ret
}
if s.IsEmpty() {
return []ts.Tensor{*xs.MustShallowClone()}
return []*ts.Tensor{xs.MustShallowClone()}
}
currTs := xs
for i := 0; i < int(n); i++ {
res := s.layers[i].ForwardT(currTs, train)
retVal = append(retVal, *res)
retVal = append(retVal, res)
currTs = res
}

View File

@ -78,15 +78,15 @@ func (vs *VarStore) IsEmpty() bool {
}
// TrainableVariables returns references to all trainable variables kept in VarStore.
func (vs *VarStore) TrainableVariables() []ts.Tensor {
func (vs *VarStore) TrainableVariables() []*ts.Tensor {
vs.Lock()
defer vs.Unlock()
var trainables []ts.Tensor
var trainables []*ts.Tensor
for _, v := range vs.vars {
x := v.Tensor
if x.MustRequiresGrad() {
trainables = append(trainables, *x)
trainables = append(trainables, x)
}
}
@ -192,6 +192,8 @@ func (vs *VarStore) Load(filepath string) error {
x.Tensor.MustDrop()
}
ts.CleanUp()
return nil
}
@ -227,6 +229,9 @@ func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
v.Tensor.Copy_(currTs)
})
}
ts.CleanUp()
return nil
}
@ -286,6 +291,8 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
x.Tensor.MustDrop()
}
ts.CleanUp()
return missingVariables, nil
}
@ -336,6 +343,8 @@ func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string,
})
}
ts.CleanUp()
return missingVariables, nil
}
@ -406,6 +415,8 @@ func (vs *VarStore) Copy(src *VarStore) error {
srcDevTs.MustDrop()
}
ts.CleanUp()
return nil
}
@ -417,12 +428,21 @@ func (vs *VarStore) Summary() {
layers = append(layers, name)
}
sort.Strings(layers)
var dtype gotch.DType
isFirst := true
for _, l := range layers {
var x *ts.Tensor
var isBuffer bool
for name, v := range vars {
if name == l {
x = v.Tensor
// Get DType of first tensor for representation only
if isFirst {
dtype = x.DType()
}
isFirst = false
isBuffer = v.Type == "buffer"
break
}
@ -435,25 +455,33 @@ func (vs *VarStore) Summary() {
}
fmt.Printf("Num of layers: %v\n", len(vars))
fmt.Printf("DType: %v\n", dtype)
}
// Destroy deletes all tensors in varstore and set it to nil.
func (vs *VarStore) Destroy() {
vs.Lock()
for n, v := range vs.vars {
v.Tensor.MustDrop()
delete(vs.vars, n)
}
vs.Unlock()
vs = nil
}
// ToDType casts all variables in VarStore to specified DType.
//
// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
// NOTE. only float-like types (Half, BFloat16, Float, Double) are guaranteed to be convertible.
func (vs *VarStore) ToDType(dtype gotch.DType) {
vs.Root().ToDType(dtype)
}
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
//
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
func (vs *VarStore) ToHalf() {
vs.Root().ToHalf()
}
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
//
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
// NOTE. float-like includes `Half`,`BFloat16`, `Float` and `Double` dtype.
func (vs *VarStore) ToFloat() {
vs.Root().ToFloat()
}
@ -465,6 +493,25 @@ func (vs *VarStore) ToDouble() {
vs.Root().ToDouble()
}
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
//
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
func (vs *VarStore) ToHalf() {
vs.Root().ToHalf()
}
// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
//
// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
func (vs *VarStore) ToBFloat16() {
vs.Root().ToBFloat16()
}
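A hypothetical mixed-precision workflow using the new casting helpers (castForInference is made up here):

```go
func castForInference(vs *VarStore) {
	vs.ToBFloat16() // alternatives: vs.ToHalf(), vs.ToDType(gotch.BFloat16)
	vs.Summary()    // Summary now also reports the variables' DType
}
```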
func (vs *VarStore) ToDevice(device gotch.Device) {
p := vs.Root()
p.ToDevice(device)
}
// Path methods:
// =============
@ -520,6 +567,7 @@ func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string
tensor *ts.Tensor
err error
)
if trainable {
tensor, err = newTs.SetRequiresGrad(true, false)
if err != nil {
@ -664,7 +712,7 @@ func (p *Path) SetGroup(g uint) {
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
//
// NOTE. this method should be used for floating-point conversion, i.e.,
// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
func (p *Path) ToDType(dtype gotch.DType) {
p.varstore.Lock()
defer p.varstore.Unlock()
@ -686,28 +734,63 @@ func (p *Path) toFloat(dtype gotch.DType) {
for name, v := range p.varstore.vars {
if strings.Contains(name, path) {
currDType := v.Tensor.DType() // NOTE: renamed so it no longer shadows the target dtype parameter
if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
if gotch.IsFloatDType(currDType) {
newVar := v
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
p.varstore.vars[name] = newVar
}
}
}
ts.CleanUp()
}
// ToHalf casts all variables in current path and subpaths to `Half` precision.
// ToFloat casts all variables in current path and subpaths to `Float` precision.
func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
dtype := gotch.Float
if len(floatDTypeOpt) > 0 {
dt := floatDTypeOpt[0]
if !gotch.IsFloatDType(dt) {
// Ignore the option
if gotch.Debug {
log.Printf("WARNING: nn.Path.ToFloat() input dtype is invalid float DType %v. Just ignoring...\n", dt)
}
} else {
dtype = dt
}
}
p.toFloat(dtype)
}
// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
func (p *Path) ToDouble() {
p.toFloat(gotch.Double)
}
// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
func (p *Path) ToHalf() {
p.toFloat(gotch.Half)
}
// ToFloat casts all variables in current path and subpaths to `Float` precision.
func (p *Path) ToFloat() {
p.toFloat(gotch.Float)
// ToBFloat16 casts all variables in current path and subpaths to `BFloat16` precision dtype.
func (p *Path) ToBFloat16() {
p.toFloat(gotch.BFloat16)
}
// ToDouble casts all variables in current path and subpaths to `Double` precision.
func (p *Path) ToDouble() {
p.toFloat(gotch.Double)
func (p *Path) ToDevice(device gotch.Device) {
p.varstore.Lock()
defer p.varstore.Unlock()
path := strings.Join(p.path, SEP)
for name, v := range p.varstore.vars {
if strings.Contains(name, path) {
newVar := v
newVar.Tensor = v.Tensor.MustTo(device, true)
p.varstore.vars[name] = newVar
}
}
ts.CleanUp()
}
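For instance, moving every variable under the root path to GPU when one is available (a sketch; `vs` is a previously built VarStore):

```go
device := gotch.CudaIfAvailable()
vs.Root().ToDevice(device) // moves every variable under the root path
```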
// ZerosNoTrain creates a new variable initialized with zeros.
@ -718,7 +801,8 @@ func (p *Path) ToDouble() {
// The variable uses a float tensor initialized with zeros.
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
device := p.Device()
z, err := ts.Zeros(dims, gotch.Float, device)
dtype := gotch.DefaultDType
z, err := ts.Zeros(dims, dtype, device)
if err != nil {
err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
return nil, err
@ -755,7 +839,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.T
// The variable uses a float tensor initialized with ones.
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
device := p.Device()
z, err := ts.Ones(dims, gotch.Float, device)
dtype := gotch.DefaultDType
z, err := ts.Ones(dims, dtype, device)
if err != nil {
err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
return nil, err
@ -792,12 +877,19 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
// The variable uses a float tensor initialized as per the
// related argument.
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
v := ini.InitTensor(dims, p.varstore.device)
dtype := gotch.DefaultDType
v := ini.InitTensor(dims, p.varstore.device, dtype)
out, err := p.Add(name, v, true, opts...)
if err != nil {
return nil, err
}
v.MustDrop()
return out, err
}
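For illustration, creating a trainable variable through a `Path` might look like this (a sketch; `nn.NewConstInit` is assumed to be one of gotch's `Init` implementations, and the variable name and dims here are made up):

```go
vs := nn.NewVarStore(gotch.CPU)
path := vs.Root()
// NewVar allocates on the varstore's device with gotch.DefaultDType.
w, err := path.NewVar("weight", []int64{10, 5}, nn.NewConstInit(0.0))
if err != nil {
	log.Fatal(err)
}
fmt.Println(w.MustSize()) // [10 5]
```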
@ -1098,7 +1190,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
// OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
o := ts.MustOnes(dims, gotch.Float, e.path.Device())
dtype := gotch.DefaultDType
o := ts.MustOnes(dims, dtype, e.path.Device())
out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
if err != nil {
return nil, err
@ -1161,7 +1254,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.
// OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
dtype := gotch.DefaultDType
z := ts.MustZeros(dims, dtype, e.path.Device())
out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
if err != nil {
return nil, err

View File

@ -1,10 +1,12 @@
package nn_test
import (
"fmt"
"os"
"path/filepath"
"reflect"
"testing"
"time"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
@ -133,3 +135,43 @@ func TestSaveLoad(t *testing.T) {
t.Errorf("Failed deleting varstore saved file: %v\n", filenameAbs)
}
}
// Test whether creating params in a varstore causes memory blow-up due to accumulated gradients.
func TestVarstore_Memcheck(t *testing.T) {
gotch.PrintMemStats("Start")
device := gotch.CPU
vs := nn.NewVarStore(device)
params := 1000
path := vs.Root()
// dims := []int64{1024, 1024}
config := nn.DefaultLinearConfig()
inDim := int64(1024)
outDim := int64(1024)
var layers []nn.Linear
for i := 0; i < params; i++ {
ts.NoGrad(func() {
name := fmt.Sprintf("param_%v", i)
l := nn.NewLinear(path.Sub(name), inDim, outDim, config)
layers = append(layers, *l)
// x := ts.MustRandn(dims, gotch.DefaultDType, device)
// path.MustAdd(name, x, false)
// x.MustDrop()
})
}
// vs.Summary()
fmt.Printf("vs created...\n")
// printMemStats("After varstore created")
vs.Destroy()
ts.CleanUp()
fmt.Printf("vs deleted...\n")
// printMemStats("After varstore deleted")
time.Sleep(time.Second * 10)
gotch.PrintMemStats("Final")
}

View File

@ -1,6 +1,7 @@
package pickle_test
import (
"fmt"
"log"
"github.com/sugarme/gotch"
@ -18,11 +19,13 @@ func ExampleLoadInfo() {
panic(err)
}
err = pickle.LoadInfo(modelFile)
m, err := pickle.LoadModelInfo(modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Println(m)
// Output:
// classifier.0.bias - [4096]
// classifier.0.weight - [4096 25088]
@ -57,4 +60,25 @@ func ExampleLoadInfo() {
// features.7.bias - [128]
// features.7.weight - [128 128 3 3]
// Num of variables: 32
// Tensor DType: Float
}
func ExampleModelFloat16() {
modelName := "HuggingFaceH4/tiny-random-LlamaForCausalLM"
url := "https://huggingface.co/HuggingFaceH4/tiny-random-LlamaForCausalLM/resolve/main/pytorch_model.bin"
modelFile, err := gotch.CachedPath(url, modelName)
if err != nil {
panic(err)
}
m, err := pickle.LoadModelInfo(modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Model DType: %v\n", m.DType())
// Output:
// Model DType: Half
}

View File

@ -20,6 +20,7 @@ import (
"reflect"
"sort"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/ts"
)
@ -60,83 +61,41 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
dictResult := *result.(*Dict)
for _, item := range dictResult {
name := item.Key
itemTyp := reflect.TypeOf(item.Value).String()
switch itemTyp {
case "*pickle.Dict": // Nested *pickle.Dict case
subResult := *item.Value.(*Dict)
for _, subItem := range subResult {
subName := subItem.Key
x, ok := subItem.Value.(*StorageTensor)
if !ok {
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
continue
}
data := x.Source.GetData()
size := x.Size
dtype := x.Source.DType()
device := x.Source.Device()
stride := x.Stride
storageOffset := x.StorageOffset
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if x.RequiresGrad {
x1.MustRequiresGrad_(x.RequiresGrad)
}
namedTensors[name.(string)] = x1
}
default:
sx, isStorageTensor := item.Value.(*StorageTensor)
// if !isStorageTensor {
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
// return nil, err
// }
if !isStorageTensor {
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
continue
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
sx, isStorageTensor := item.Value.(*StorageTensor)
if !isStorageTensor {
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
return nil, err
}
data := sx.Source.GetData()
size := sx.Size
dtype := sx.Source.DType()
device := sx.Source.Device()
stride := sx.Stride
storageOffset := sx.StorageOffset
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
// log.Printf("data: %v\n", data)
// Dealing with Pytorch `..._tracked` variables.
if reflect.ValueOf(data).Len() == 0 {
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
continue
}
// TODO. should we just skip them?
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
size = []int64{1}
stride = []int64{1}
}
x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
if sx.RequiresGrad {
x.MustRequiresGrad_(sx.RequiresGrad)
}
namedTensors[name.(string)] = x
}
case "*pickle.OrderedDict":
dictResult := result.(*OrderedDict)
@ -581,35 +540,74 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
return missingVariables, nil
}
// LoadInfo loads pretrained weights and prints out name and shape of weights.
func LoadInfo(modelFile string) error {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadInfo() failed: %w", err)
return err
}
type ModelInfor struct {
weights map[string][]int64
dtype gotch.DType
}
layers := make([]string, 0, len(weights))
for tsName := range weights {
func NewModelInfor(weights map[string][]int64, dtype gotch.DType) *ModelInfor {
return &ModelInfor{
weights: weights,
dtype: dtype,
}
}
func (m *ModelInfor) String() string {
var summary string
layers := make([]string, 0, len(m.weights))
for tsName := range m.weights {
layers = append(layers, tsName)
}
sort.Strings(layers)
for _, l := range layers {
var x *ts.Tensor
for tsName, tsVal := range weights {
var x []int64
for tsName, shape := range m.weights {
if tsName == l {
x = tsVal
x = shape
break
}
}
fmt.Printf("%s - %+v\n", l, x.MustSize())
summary += fmt.Sprintf("%s - %+v\n", l, x)
}
fmt.Printf("Num of variables: %v\n", len(weights))
summary += fmt.Sprintf("Num of variables: %v\n", len(m.weights))
summary += fmt.Sprintf("Tensor DType: %v\n", m.dtype)
for _, x := range weights {
x.MustDrop()
}
return nil
return summary
}
func (m *ModelInfor) DType() gotch.DType {
return m.dtype
}
func (m *ModelInfor) Parameters() int {
return len(m.weights)
}
// LoadModelInfo loads pretrained weights and returns a ModelInfor holding weight names, shapes and dtype.
func LoadModelInfo(modelFile string) (*ModelInfor, error) {
weights, err := Decode(modelFile)
if err != nil {
err = fmt.Errorf("LoadInfo() failed: %w", err)
return nil, err
}
w := make(map[string][]int64)
var dtype gotch.DType
isFirst := true
for n, x := range weights {
w[n] = x.MustSize()
if isFirst {
dtype = x.DType()
isFirst = false
}
}
m := NewModelInfor(w, dtype)
ts.CleanUp()
return m, nil
}

View File

@ -7,6 +7,7 @@ import (
"math"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/half"
)
// This file implements Pytorch storage data types.
@ -67,7 +68,8 @@ func (s *HalfStorageClass) New(size int, location string) Storage {
type HalfStorage struct {
BaseStorage
Data []float32
// Data []float32
Data []half.Float16
}
var _ Storage = &HalfStorage{}
@ -77,7 +79,7 @@ func (s *HalfStorage) SetFromFile(r io.Reader) error {
}
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]float32, size)
data := make([]half.Float16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
@ -85,7 +87,7 @@ func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
return err
}
u16 := binary.LittleEndian.Uint16(bytes)
data[i] = math.Float32frombits(FloatBits16to32(u16))
data[i] = half.Float16(u16)
}
s.Data = data
return nil
@ -96,7 +98,7 @@ func (s *HalfStorage) GetData() interface{} {
}
func (s *HalfStorage) DType() gotch.DType {
return gotch.Float
return gotch.Half
}
func (s *HalfStorage) Device() gotch.Device {
@ -108,6 +110,62 @@ func (s *HalfStorage) Device() gotch.Device {
}
}
// BFloat16Storage:
// ================
type BFloat16StorageClass struct{}
var _ StorageClass = &BFloat16StorageClass{}
func (s *BFloat16StorageClass) New(size int, location string) Storage {
return &BFloat16Storage{
BaseStorage: BaseStorage{Size: size, Location: location},
Data: nil,
}
}
type BFloat16Storage struct {
BaseStorage
Data []half.BFloat16
}
var _ Storage = &BFloat16Storage{}
func (s *BFloat16Storage) SetFromFile(r io.Reader) error {
return setFromFile(s, r)
}
func (s *BFloat16Storage) SetFromFileWithSize(r io.Reader, size int) error {
data := make([]half.BFloat16, size)
br := NewLimitedBufferReader(r, size, 2, 512)
for i := 0; i < size; i++ {
bytes, err := br.ReadNext()
if err != nil {
return err
}
u16 := binary.LittleEndian.Uint16(bytes)
data[i] = half.BFloat16(u16)
}
s.Data = data
return nil
}
func (s *BFloat16Storage) GetData() interface{} {
return s.Data
}
func (s *BFloat16Storage) DType() gotch.DType {
return gotch.BFloat16
}
func (s *BFloat16Storage) Device() gotch.Device {
switch s.Location {
case "cuda":
return gotch.CudaIfAvailable()
default:
return gotch.CPU
}
}
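As context for the storage above: bfloat16 keeps the sign, the full 8-bit exponent, and the top 7 mantissa bits of a float32, so a raw `uint16` maps to and from float32 by a 16-bit shift. A sketch of that bit layout (not part of the `half` package; round-to-nearest is omitted for brevity):

```go
package main

import (
	"fmt"
	"math"
)

// bf16FromFloat32 truncates a float32 to its upper 16 bits.
func bf16FromFloat32(f float32) uint16 {
	return uint16(math.Float32bits(f) >> 16)
}

// bf16ToFloat32 widens the raw bits back to float32.
func bf16ToFloat32(b uint16) float32 {
	return math.Float32frombits(uint32(b) << 16)
}

func main() {
	x := float32(3.14159)
	b := bf16FromFloat32(x)
	fmt.Println(bf16ToFloat32(b)) // 3.140625: precision lost in the low mantissa bits
}
```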
// FloatStorage:
// =============

View File

@ -1,7 +1,7 @@
#!/bin/bash
LIBTORCH_VERSION="${LIBTORCH_VER:-1.11.0}"
CUDA_VERSION="${CUDA_VER:-11.3}"
LIBTORCH_VERSION="${LIBTORCH_VER:-2.0.1}"
CUDA_VERSION="${CUDA_VER:-11.7}"
if [ "${CUDA_VERSION}" == "cpu" ]; then
CU_VERSION="cpu"

View File

@ -232,6 +232,7 @@ func (tdi *TextDataIter) Progress() float32 {
progress := float32(startIndex) / float32(availableIndices)
return progress
}
// Labels returns the number of different characters (runes) used by the dataset.
func (td *TextData) Labels() (retVal int64) {
return int64(len(td.CharForLabel))
@ -281,12 +282,12 @@ func (tdi *TextDataIter) Next() (*Tensor, bool) {
indexes := indexesTs.Int64Values()
indexesTs.MustDrop()
var batch []Tensor
var batch []*Tensor
for _, idx := range indexes {
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
idxTs := tdi.Data.Idx(narrowIdx)
batch = append(batch, *idxTs)
batch = append(batch, idxTs)
}
retVal := MustStack(batch, 0)

View File

@ -20,8 +20,8 @@ func TestTextData_NewTextData(t *testing.T) {
log.Fatal(err)
}
txt := `héllo`
// txt := "h\xC3\xA9llo"
// txt := `héllo`
txt := "h\xC3\xA9llo"
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
if err != nil {
log.Fatal(err)
@ -33,7 +33,7 @@ func TestTextData_NewTextData(t *testing.T) {
}
wantData := []float64{0, 1, 2, 3, 3, 4}
gotData := textData.CloneData().Float64Values()
gotData := textData.Data.Float64Values()
if !reflect.DeepEqual(wantData, gotData) {
t.Errorf("Want data: %v\n", wantData)
@ -111,5 +111,4 @@ func TestTextDataIter(t *testing.T) {
vals := sum.Int64Values()
t.Logf("sum: %v\n", vals)
}
}

View File

@ -5,6 +5,7 @@ import "C"
import (
"fmt"
"runtime/debug"
"unsafe"
lib "github.com/sugarme/gotch/libtch"
@ -41,7 +42,8 @@ func TorchErr() error {
cptr := (*C.char)(lib.GetAndResetLastErr())
errStr := ptrToString(cptr)
if errStr != "" {
return fmt.Errorf("Libtorch API Error: %v\n", errStr)
trace := string(debug.Stack())
return fmt.Errorf("Libtorch API Error: %v\n%v\n", errStr, trace)
}
return nil

View File

@ -17,7 +17,7 @@ func LoadHwc(path string) (*Tensor, error) {
return nil, err
}
return &Tensor{ctensor}, nil
return newTensor(ctensor), nil
}
// SaveHwc save an image from tensor. It expects a tensor of shape [height,
@ -38,5 +38,5 @@ func ResizeHwc(ts *Tensor, outWidth, outHeight int64) (*Tensor, error) {
return nil, err
}
return &Tensor{ctensor}, nil
return newTensor(ctensor), nil
}

View File

@ -85,26 +85,26 @@ type InsertNewAxis struct{}
// NewSelect creates a tensor indexer with the given index.
// `index` must be within the range of the indexed dimension. E.g. for a tensor
// of shape [2,8], dim 0 has size 2, hence `index` should be in the range [0,2).
func NewSelect(index int64) Select {
return Select{index}
func NewSelect(index int64) *Select {
return &Select{index}
}
func NewNarrow(start, end int64) Narrow {
return Narrow{Start: start, End: end}
func NewNarrow(start, end int64) *Narrow {
return &Narrow{Start: start, End: end}
}
func NewIndexSelect(ts *Tensor) IndexSelect {
return IndexSelect{Index: ts}
func NewIndexSelect(ts *Tensor) *IndexSelect {
return &IndexSelect{Index: ts}
}
func NewInsertNewAxis() InsertNewAxis {
return InsertNewAxis{}
func NewInsertNewAxis() *InsertNewAxis {
return &InsertNewAxis{}
}
func NewSliceIndex(sl []int64) IndexSelect {
func NewSliceIndex(sl []int64) *IndexSelect {
ts := MustOfSlice(sl)
return IndexSelect{Index: ts}
return &IndexSelect{Index: ts}
}
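With the pointer-returning constructors, a typical indexing call composes like this (a sketch of the changed API; the shapes in comments are what the ops should produce):

```go
x := ts.MustArange(ts.IntScalar(12), gotch.Float, gotch.CPU).MustView([]int64{3, 4}, true)
row := x.Idx(ts.NewSelect(1)) // shape [4]: row 1, dim squeezed away
sub := x.Idx([]ts.TensorIndexer{ts.NewNarrow(0, 2), ts.NewNarrow(0, 3)}) // shape [2 3]
```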
// type SelectFn func(int64)
@ -120,7 +120,7 @@ func NewSliceIndex(sl []int64) IndexSelect {
// )
type IndexOp interface {
Idx(index interface{}) Tensor
Idx(index interface{}) *Tensor
}
// implement IndexOp for Tensor:
@ -129,7 +129,7 @@ type IndexOp interface {
// Idx implements `IndexOp` interface for Tensor
//
// NOTE:
// - `index`: expects type `TensorIndexer` or `[]TensorIndexer`
// - `index`: expects type `TensorIndexer` or `[]*TensorIndexer`
func (ts *Tensor) Idx(index interface{}) (retVal *Tensor) {
// indexTyp := reflect.TypeOf(index)
@ -137,8 +137,9 @@ func (ts *Tensor) Idx(index interface{}) (retVal *Tensor) {
var indexes []TensorIndexer
switch indexVal.Kind().String() {
case "struct": // T: A
typ := indexVal.Kind().String()
switch typ {
case "ptr": // T: A
indexes = append(indexes, index.(TensorIndexer))
case "slice": // T: []TensorIndexer
switch len(index.([]TensorIndexer)) {
@ -201,7 +202,7 @@ func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error)
// Make sure number of non-newaxis is not exceed number of dimensions
var numNewAxis int = 0
for _, ti := range indexSpec {
if reflect.TypeOf(ti).Name() == "InsertNewAxis" {
if reflect.TypeOf(ti).String() == "*ts.InsertNewAxis" {
numNewAxis += 1
}
}
@ -218,10 +219,10 @@ func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error)
// Make sure tensor conforms the format
for _, spec := range indexSpec {
// If `spec` is `IndexSelect` type and
if reflect.TypeOf(spec).Name() == "IndexSelect" {
if reflect.ValueOf(spec).Kind() == reflect.Struct {
inputTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(*Tensor)
// If `spec` is `*IndexSelect` type and
if reflect.TypeOf(spec).String() == "*ts.IndexSelect" {
if reflect.ValueOf(spec).Kind() == reflect.Ptr {
inputTensor := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Index").Interface().(*Tensor)
// 1. Either its input tensor has dimension > 1, throw error.
inputTensorShape, err := inputTensor.Size()
@ -257,33 +258,32 @@ func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error)
// `spec` is a function type implements `TensorIndexer`
for _, spec := range indexSpec {
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
switch reflect.TypeOf(spec).String() {
case "*ts.InsertNewAxis":
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
if err != nil {
return retVal, err
}
nextIdx = currIdx + 1
case "Select": // 1 field: `Index`
index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64)
case "*ts.Select": // 1 field: `Index`
index := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Index").Interface().(int64)
nextTensor, err = currTensor.Select(currIdx, index, true) // TODO: double-check whether it should be `*index` or `index`
if err != nil {
return retVal, err
}
nextIdx = currIdx // not advanced because select() squeezes dimension
case "Narrow": // 2 fields: `(Start, End int64)`
case "*ts.Narrow": // 2 fields: `(Start, End int64)`
// TODO: implement for `Unbounded`, `Included`, `Excluded` ranges
// NOTE: for now, just implement (Included(start), Excluded(end))` case
start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64)
end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64)
start := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Start").Interface().(int64)
end := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("End").Interface().(int64)
nextTensor, err = currTensor.Narrow(currIdx, start, end-start, true)
if err != nil {
return retVal, err
}
nextIdx = currIdx + 1
case "IndexSelect": // 1 field `(Index *Tensor)`
indexTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(*Tensor)
case "*ts.IndexSelect": // 1 field `(Index *Tensor)`
indexTensor := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Index").Interface().(*Tensor)
device, err := currTensor.Device()
if err != nil {
return retVal, err

10
ts/init.go Normal file
View File

@ -0,0 +1,10 @@
package ts
import (
// "runtime/debug"
)
func init() {
// debug.SetMemoryLimit()
// debug.SetGCPercent(100) // ratio of freshly allocated data to live data remaining after the previous collection.
}

View File

@ -26,7 +26,7 @@ func (it *Iterable) Next() (item interface{}, ok bool) {
}
var err error
switch it.ItemKind.Kind().String() {
switch it.ItemKind.GoKind().String() {
case "int64":
item, err = it.Content.Int64Value([]int64{it.Index})
if err != nil {

106
ts/jit.go
View File

@ -22,25 +22,23 @@ type CIValue struct {
civalue lib.Civalue
}
type IValueKind struct {
reflect.Type
}
type IValueKind int
var (
NoneVal IValueKind = IValueKind{reflect.TypeOf(nil)}
TensorVal IValueKind = IValueKind{reflect.TypeOf(Tensor{})}
DoubleVal IValueKind = IValueKind{reflect.TypeOf(float64(1))}
IntVal IValueKind = IValueKind{reflect.TypeOf(int64(1))}
BoolVal IValueKind = IValueKind{reflect.TypeOf(true)}
TupleVal IValueKind = IValueKind{reflect.TypeOf([]IValue{})}
IntListVal IValueKind = IValueKind{reflect.TypeOf([]int64{})}
DoubleListVal IValueKind = IValueKind{reflect.TypeOf([]float64{})}
BoolListVal IValueKind = IValueKind{reflect.TypeOf([]bool{})}
StringVal IValueKind = IValueKind{reflect.TypeOf("")}
TensorListVal IValueKind = IValueKind{reflect.TypeOf([]Tensor{})}
GenericListVal IValueKind = IValueKind{reflect.TypeOf([]IValue{})}
GenericDictVal IValueKind = IValueKind{reflect.TypeOf(map[IValue]IValue{})} // 2 elements. ? map[IValue]IValue
GenericVal IValueKind = IValueKind{reflect.TypeOf(IValue{})}
const (
NoneVal IValueKind = iota
TensorVal // *Tensor
DoubleVal // float64
IntVal // int64
BoolVal // bool
TupleVal // []*IValue
IntListVal // []int64
DoubleListVal // []float64
BoolListVal // []bool
StringVal // string
TensorListVal // []*Tensor
GenericListVal // []*IValue
GenericDictVal // map[IValue]IValue - 2 elements
GenericVal // *IValue
)
type IValue struct {
@ -51,7 +49,6 @@ type IValue struct {
// NewIValue creates a new IValue from given value of various types.
func NewIValue(v interface{}) *IValue {
retVal := &IValue{value: v}
if v == nil {
retVal.kind = NoneVal
@ -62,7 +59,7 @@ func NewIValue(v interface{}) *IValue {
inputTypeStr := reflect.TypeOf(v).Kind().String()
switch inputTypeStr {
case "Tensor":
case "*Tensor":
retVal.kind = TensorVal
retVal.name = "Tensor"
case "float64":
@ -87,17 +84,7 @@ func NewIValue(v interface{}) *IValue {
retVal.kind = StringVal
retVal.name = "String"
case "slice":
fmt.Printf("slice elem type: %q\n", reflect.TypeOf(v).Elem().Kind().String())
switch reflect.TypeOf(v).Elem().Kind().String() {
case "IValue":
switch len(v.([]IValue)) {
case 2:
retVal.kind = TupleVal
retVal.name = "Tuple"
default:
retVal.kind = GenericListVal
retVal.name = "GenericList"
}
case "int64":
retVal.kind = IntListVal
retVal.name = "IntList"
@ -119,31 +106,38 @@ func NewIValue(v interface{}) *IValue {
case "bool":
retVal.kind = BoolListVal
retVal.name = "BoolList"
case "struct": // NOTE: only supported `Tensor` type
case "ptr": // NOTE: only supported `*Tensor` type
val := reflect.Indirect(reflect.ValueOf(v))
switch {
// 1. Tuple (Tensor, Tensor)
case val.Type() == reflect.TypeOf([]Tensor{}) && val.Len() == 2:
// 1. Tuple (*Tensor, *Tensor)
case val.Type().String() == "[]*ts.Tensor" && val.Len() == 2:
retVal.kind = TensorListVal
retVal.name = "Tuple"
retVal.value = v.([]Tensor)
retVal.value = v.([]*Tensor)
// 2. List (Tensor, Tensor, ...)
case val.Type() == reflect.TypeOf([]Tensor{}) && val.Len() > 2:
// 2. List (*Tensor, *Tensor, ...)
case val.Type().String() == "[]*ts.Tensor" && val.Len() > 2:
retVal.kind = TensorListVal
retVal.name = "TensorList"
retVal.value = v.([]Tensor)
retVal.value = v.([]*Tensor)
case val.Type().String() == "[]*ts.IValue" && val.Len() == 2:
retVal.kind = TupleVal
retVal.name = "Tuple"
retVal.value = v.([]*IValue)
case val.Type().String() == "[]*ts.IValue" && val.Len() > 2, val.Type().String() == "[]*ts.IValue" && val.Len() == 1:
retVal.kind = GenericListVal
retVal.name = "GenericList"
default:
log.Fatalf("NewIValue method call - 'slice -> struct' case - Unsupported type (%v)\n", reflect.TypeOf(v).Kind().String())
log.Fatalf("NewIValue method call - 'slice -> struct' case - Unsupported type (%v)\n", val.Type().String())
}
}
case "map":
// TODO: exclude map of type other than IValue type
retVal.kind = GenericDictVal
retVal.name = "GenericDict"
case "struct":
case "ptr":
val := reflect.Indirect(reflect.ValueOf(v))
fieldName := val.Type().Field(0).Name
fieldName := val.Type().Field(2).Name
switch fieldName {
case "ctensor":
retVal.kind = TensorVal
@ -172,7 +166,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
return &CIValue{civalue: cval}, nil
case "Tensor":
cval := lib.AtiTensor(iv.value.(Tensor).ctensor)
cval := lib.AtiTensor(iv.value.(*Tensor).ctensor)
if err := TorchErr(); err != nil {
return nil, err
}
@ -203,7 +197,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
case "Tuple":
val := reflect.Indirect(reflect.ValueOf(iv.value))
switch {
// 1. Tuple is (Tensor, Tensor)
// 1. Tuple is (*Tensor, *Tensor)
case val.Type() == reflect.TypeOf([]Tensor{}):
var v []Tensor = iv.value.([]Tensor)
var cvals []lib.Civalue
@ -223,9 +217,9 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
}
return &CIValue{civalue: tuple}, nil
// 2. Tuple is (IValue, IValue)
// 2. Tuple is (*IValue, *IValue)
default:
var v []IValue = iv.value.([]IValue)
var v []*IValue = iv.value.([]*IValue)
var cvals []lib.Civalue
for _, i := range v {
cval, err := i.ToCIValue()
@ -328,7 +322,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
return &CIValue{civalue: cval}, nil
case "TensorList":
var vals []Tensor = iv.value.([]Tensor)
var vals []*Tensor = iv.value.([]*Tensor)
var cvals []lib.Ctensor
for _, i := range vals {
cvals = append(cvals, i.ctensor)
@ -450,7 +444,7 @@ func IValueFromC(cval *CIValue) (*IValue, error) {
return nil, err
}
return &IValue{
value: Tensor{tensor},
value: newTensor(tensor),
kind: TensorVal,
name: "Tensor",
}, nil
@ -518,14 +512,14 @@ func IValueFromC(cval *CIValue) (*IValue, error) {
elemName := v.Name()
switch elemName {
case "Tensor":
var vals []Tensor
var vals []*Tensor
for _, civalue := range civalues {
v, err := IValueFromC(&civalue)
if err != nil {
return nil, err
}
vals = append(vals, v.Value().(Tensor))
vals = append(vals, v.Value().(*Tensor))
}
if len == 2 {
return &IValue{
@ -725,12 +719,12 @@ func IValueFromC(cval *CIValue) (*IValue, error) {
}
// 3. Get values
var tensors []Tensor
tensors = append(tensors, Tensor{ctensor: *ptr1})
var tensors []*Tensor
tensors = append(tensors, newTensor(*ptr1))
currPtr := ptr1
for i := 1; i < int(len); i++ {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1)))
tensors = append(tensors, Tensor{ctensor: *nextPtr})
tensors = append(tensors, newTensor(*nextPtr))
currPtr = nextPtr
}
@ -1017,7 +1011,7 @@ func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (*CModule, er
}
// ForwardTs performs the forward pass for a model on some specified tensor inputs.
func (cm *CModule) ForwardTs(tensors []Tensor) (*Tensor, error) {
func (cm *CModule) ForwardTs(tensors []*Tensor) (*Tensor, error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
@ -1061,11 +1055,11 @@ func (cm *CModule) ForwardTs(tensors []Tensor) (*Tensor, error) {
return nil, err
}
return &Tensor{ctensor}, nil
return newTensor(ctensor), nil
}
// ForwardIs performs the forward pass for a model on some specified ivalue input.
func (cm *CModule) ForwardIs(ivalues []IValue) (*IValue, error) {
func (cm *CModule) ForwardIs(ivalues []*IValue) (*IValue, error) {
var civalues []lib.Civalue
for _, i := range ivalues {
@ -1145,7 +1139,7 @@ func (cm *CModule) NamedParameters() ([]NamedTensor, error) {
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: &Tensor{v.Ctensor},
Tensor: newTensor(v.Ctensor),
}
namedTensors = append(namedTensors, namedTensor)
@ -1194,7 +1188,7 @@ func (cm *CModule) SetEval() {
// Forwad implements Module interface for CModule.
func (cm *CModule) Forward(tensor *Tensor) (*Tensor, error) {
var tensors []Tensor = []Tensor{*tensor}
var tensors []*Tensor = []*Tensor{tensor}
return cm.ForwardTs(tensors)
}
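Putting the pointer-based API together, running a scripted module might look like this (a sketch; the `model.pt` path is made up, and `ts.ModuleLoad` is assumed from gotch's jit API):

```go
model, err := ts.ModuleLoad("model.pt") // hypothetical TorchScript file
if err != nil {
	log.Fatal(err)
}
x := ts.TensorFrom([]int64{42})
out, err := model.ForwardTs([]*ts.Tensor{x})
if err != nil {
	log.Fatal(err)
}
fmt.Println(out)
```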

View File

@ -59,7 +59,7 @@ func TestModuleForwardTs(t *testing.T) {
ts1 := ts.TensorFrom([]int64{42})
ts2 := ts.TensorFrom([]int64{1337})
res, err := foo.ForwardTs([]ts.Tensor{*ts1, *ts2})
res, err := foo.ForwardTs([]*ts.Tensor{ts1, ts2})
if err != nil {
t.Error(err)
}
@ -83,17 +83,17 @@ func TestModuleForwardIValue(t *testing.T) {
ts1 := ts.TensorFrom([]int64{42})
ts2 := ts.TensorFrom([]int64{1337})
iv1 := ts.NewIValue(*ts1)
iv2 := ts.NewIValue(*ts2)
iv1 := ts.NewIValue(ts1)
iv2 := ts.NewIValue(ts2)
got, err := foo.ForwardIs([]ts.IValue{*iv1, *iv2})
got, err := foo.ForwardIs([]*ts.IValue{iv1, iv2})
if err != nil {
t.Error(err)
}
expectedTs1 := ts.TensorFrom([]int64{1421})
expectedTs2 := ts.TensorFrom([]int64{-1295})
want := ts.NewIValue([]ts.Tensor{*expectedTs1, *expectedTs2})
want := ts.NewIValue([]*ts.Tensor{expectedTs1, expectedTs2})
if !reflect.DeepEqual(want.Name(), got.Name()) {
t.Errorf("Expected Ivalue Name: %v\n", want.Name())

15
ts/layout.go Normal file
View File

@ -0,0 +1,15 @@
package ts
// include/c10/core/Layout.h
type Layout int8
const (
Strided Layout = iota // 0
Sparse // 1
SparseCsr // 2
Mkldnn // 3
SparseCsc // 4
SparseBsr // 5
SparseBsc // 6
NumOptions // 7
)
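For debugging, one might attach a String method mirroring the c10 names (a sketch; this method is not part of this diff):

```go
// String maps a Layout value to its c10 name; NumOptions and out-of-range
// values fall through to "Unknown".
func (l Layout) String() string {
	names := []string{"Strided", "Sparse", "SparseCsr", "Mkldnn", "SparseCsc", "SparseBsr", "SparseBsc"}
	if l < 0 || int(l) >= len(names) {
		return "Unknown"
	}
	return names[l]
}
```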

File diff suppressed because it is too large

View File

@ -101,22 +101,22 @@ func (h *NpyHeader) ToString() (string, error) {
shape := strings.Join(shapeStr, ",")
var descr string
switch h.descr.Kind().String() {
// case "float16": // NOTE. No float16 in Go primary types. TODO. implement
// descr = "f2"
case "float32":
switch h.descr {
case gotch.Half:
descr = "f2"
case gotch.Float:
descr = "f4"
case "float64":
case gotch.Double:
descr = "f8"
case "int":
case gotch.Int:
descr = "i4"
case "int64":
case gotch.Int64:
descr = "i8"
case "int16":
case gotch.Int16:
descr = "i2"
case "int8":
case gotch.Int8:
descr = "i1"
case "uint8":
case gotch.Uint8:
descr = "u1"
default:
err := fmt.Errorf("Unsupported kind: %v\n", h.descr)
@ -306,6 +306,11 @@ func ReadNpy(filepath string) (*Tensor, error) {
return nil, err
}
// NOTE(TT.). handle the case of a 1-element tensor with shape = []
if len(data) > 0 && len(header.shape) == 0 {
header.shape = []int64{1}
}
return OfDataSize(data, header.shape, header.descr)
}
@ -348,6 +353,11 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
return nil, err
}
// NOTE(TT.). handle the case of a 1-element tensor with shape = []
if len(data) > 0 && len(header.shape) == 0 {
header.shape = []int64{1}
}
tensor, err := OfDataSize(data, header.shape, header.descr)
if err != nil {
return nil, err

View File

@ -69,7 +69,7 @@ func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (*COptimizer, error
}
// AddParameters adds parameters as a slice of tensors to optimizer
func (co *COptimizer) AddParameters(tensors []Tensor) error {
func (co *COptimizer) AddParameters(tensors []*Tensor) error {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
@ -127,7 +127,7 @@ func (co *COptimizer) ParamGroupNum() (int64, error) {
return ngroup, nil
}
func (co *COptimizer) AddParamGroup(tensors []Tensor) error {
func (co *COptimizer) AddParamGroup(tensors []*Tensor) error {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)

View File

@ -2,17 +2,14 @@ package ts
// Other tensor methods
import (
"github.com/sugarme/gotch"
)
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
weight := NewTensor()
reduction := int64(1) // Mean of loss
ignoreIndex := int64(-100)
logSm := ts.MustLogSoftmax(-1, gotch.Float, false)
dtype := ts.DType()
logSm := ts.MustLogSoftmax(-1, dtype, false)
return logSm.MustNllLoss(targets, weight, reduction, ignoreIndex, true)
}
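Cross-entropy here is log-softmax followed by NLL, now computed at the logits' own dtype instead of hard-coded Float. A usage sketch (random logits and made-up targets):

```go
logits := ts.MustRandn([]int64{4, 10}, gotch.Float, gotch.CPU) // 4 samples, 10 classes
targets := ts.MustOfSlice([]int64{1, 0, 3, 9})                 // class index per sample
loss := logits.CrossEntropyForLogits(targets)
fmt.Println(loss.Float64Values()[0]) // mean loss over the batch
```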
@ -21,7 +18,8 @@ func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
argmax := ts.MustArgmax([]int64{-1}, false, false)
eq1 := argmax.MustEqTensor(targets, true)
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
dtype := ts.DType()
return eq1.MustTotype(dtype, true).MustMean(dtype, true)
}
func (ts *Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal *Tensor) {

View File

@ -43,3 +43,20 @@ func ExampleTensorSplitWithSizes(t *testing.T) {
// 8 9
// [ CPUFloatType{4,2} ]
}
// Test Unbind op specifically for BFloat16/Half
func TestTensorUnbind(t *testing.T) {
// device := gotch.CudaIfAvailable()
device := gotch.CPU
dtype := gotch.BFloat16
// dtype := gotch.Half // <- NOTE. Libtorch API Error: "arange_cpu" not implemented for 'Half'
x := ts.MustArange(ts.IntScalar(60), dtype, device).MustView([]int64{3, 4, 5}, true)
out := x.MustUnbind(0, true)
if len(out) != 3 {
t.Errorf("Want 3, got %v\n", len(out))
}
}

View File

@ -11,9 +11,9 @@ import (
)
// NOTE. This is temporarily patched to make it run.
// TODO. make change at generator for []Tensor input
// TODO. make change at generator for []*Tensor input
func (ts *Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
func (ts *Tensor) Lstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
// NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first
// Ctensor will have address given by `ctensorPtr1` here.
@ -55,11 +55,11 @@ func (ts *Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, num
return output, h, c, err
}
return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, &Tensor{ctensor: *ctensorPtr3}, nil
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), newTensor(*ctensorPtr3), nil
}
func (ts *Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
func (ts *Tensor) MustLstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
output, h, c, err := ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
if err != nil {
@ -69,7 +69,7 @@ func (ts *Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool,
return output, h, c
}
func (ts *Tensor) Gru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
func (ts *Tensor) Gru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
// NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land.
// The first Ctensor will have address given by `ctensorPtr1` here.
@ -105,11 +105,11 @@ func (ts *Tensor) Gru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers
return output, h, err
}
return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, nil
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
}
func (ts *Tensor) MustGru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
func (ts *Tensor) MustGru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
output, h, err := ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
if err != nil {
log.Fatal(err)
@ -139,7 +139,7 @@ func (ts *Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2
return ts1, ts2, err
}
return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, nil
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
}
func (ts *Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor) {
@ -169,7 +169,7 @@ func (ts *Tensor) NLLLoss(target *Tensor, del bool) (retVal *Tensor, err error)
return retVal, err
}
retVal = &Tensor{ctensor: *ptr}
retVal = newTensor(*ptr)
return retVal, nil
}
@ -200,7 +200,7 @@ func (ts *Tensor) MustNLLLoss(target *Tensor, del bool) (retVal *Tensor) {
// tensor *atg_where(tensor condition);
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
@ -213,21 +213,21 @@ func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
func MustAlignTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
if del {
for _, t := range tensors {
defer t.MustDrop()
@ -242,7 +242,7 @@ func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
}
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
@ -255,21 +255,21 @@ func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
func MustBroadcastTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
if del {
for _, t := range tensors {
defer t.MustDrop()
@ -285,29 +285,28 @@ func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
}
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) {
func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []*Tensor) {
if del {
defer ts.MustDrop()
}
@ -321,7 +320,7 @@ func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor)
}
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
@ -334,21 +333,21 @@ func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
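A usage sketch for the slice-of-pointer return pattern above, using Meshgrid (values illustrative; default 'ij' indexing assumed):

```go
xs := ts.MustArange(ts.IntScalar(3), gotch.Int64, gotch.CPU)
ys := ts.MustArange(ts.IntScalar(2), gotch.Int64, gotch.CPU)
grids := ts.MustMeshgrid([]*ts.Tensor{xs, ys})
// grids[0] and grids[1] both have shape [3 2]
```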
func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
func MustMeshgrid(tensors []*Tensor) (retVal []*Tensor) {
retVal, err := Meshgrid(tensors)
if err != nil {
log.Fatal(err)
@ -358,29 +357,28 @@ func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
}
// tensor *atg_nonzero_numpy(tensor self);
func (ts *Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []*Tensor) {
if del {
defer ts.MustDrop()
}
@ -396,10 +394,11 @@ func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
// Split splits tensor into chunks
//
// Parameters:
// - splitSize size of a single chunk
// - dim dimension along which to split the tensor.
// - splitSize size of a single chunk
// - dim dimension along which to split the tensor.
//
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts *Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
if err = TorchErr(); err != nil {
@ -411,22 +410,21 @@ func (ts *Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
// calculated from there. The vector of tensors will end if the calculated
// pointer value is `null`.
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []*Tensor) {
if del {
defer ts.MustDrop()
}
@ -442,10 +440,11 @@ func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
// SplitWithSizes splits tensor into chunks
//
// Parameters:
// - splitSizes slice of sizes for each chunk
// - dim dimension along which to split the tensor.
// - splitSizes slice of sizes for each chunk
// - dim dimension along which to split the tensor.
//
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) {
func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
if err = TorchErr(); err != nil {
@ -457,22 +456,21 @@ func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor
// calculated from there. The vector of tensors will end if the calculated
// pointer value is `null`.
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) {
func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []*Tensor) {
if del {
defer ts.MustDrop()
}
@ -486,29 +484,29 @@ func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (r
}
// tensor *atg_unbind(tensor self, int64_t dim);
func (ts *Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []*Tensor) {
if del {
defer ts.MustDrop()
}
@ -522,29 +520,28 @@ func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
}
// tensor *atg_where(tensor condition);
func Where(condition Tensor) (retVal []Tensor, err error) {
func Where(condition Tensor) (retVal []*Tensor, err error) {
ctensorsPtr := lib.AtgWhere(condition.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
retVal = append(retVal, newTensor(*currentPtr))
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
retVal = append(retVal, newTensor(*nextPtr))
currentPtr = nextPtr
}
return retVal, nil
}
func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
func MustWhere(condition Tensor, del bool) (retVal []*Tensor) {
if del {
defer condition.MustDrop()
}

View File

@ -11,41 +11,7 @@ import (
)
func (ts *Tensor) ValueGo() interface{} {
dtype := ts.DType()
numel := ts.Numel()
var dst interface{}
switch dtype {
case gotch.Uint8:
dst = make([]uint8, numel)
case gotch.Int8:
dst = make([]int8, numel)
case gotch.Int16:
dst = make([]int16, numel)
case gotch.Int:
dst = make([]int32, numel)
case gotch.Int64:
dst = make([]int64, numel)
case gotch.Float:
dst = make([]float32, numel)
case gotch.Double:
dst = make([]float64, numel)
case gotch.Bool:
dst = make([]bool, numel)
default:
err := fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
log.Fatal(err)
}
err := ts.CopyData(dst, ts.Numel())
if err != nil {
log.Fatal(err)
}
// convert []int32 -> int
if reflect.TypeOf(dst).String() == "[]int32" {
dst = sliceInt32ToInt(dst.([]int32))
}
return dst
return ts.Vals()
}
func shapeToSize(shape []int64) int {
@ -250,8 +216,8 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
// var typ T
typ := ts.DType()
switch typ.String() {
case "float32", "float64":
switch typ {
case gotch.Half, gotch.BFloat16, gotch.Float, gotch.Double:
switch f.verb {
case 'f', 'e', 'E', 'G', 'b':
// accepted. Do nothing
@ -259,7 +225,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
f.verb = 'g'
}
case "uint8", "int8", "int16", "int32", "int64":
case gotch.Uint8, gotch.Int8, gotch.Int16, gotch.Int, gotch.Int64:
switch f.verb {
case 'b':
f.base = 2
@ -273,7 +239,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
f.base = 10
f.verb = 'd'
}
case "bool":
case gotch.Bool:
f.verb = 't'
default:
f.verb = 'v'
@ -318,7 +284,7 @@ func (ts *Tensor) Format(s fmt.State, verb rune) {
shape := toSliceInt(ts.MustSize())
strides := shapeToStrides(shape)
device := ts.MustDevice()
dtype := ts.DType().String()
dtype := ts.DType()
defined := ts.MustDefined()
if verb == 'i' {
fmt.Fprintf(

View File

@ -1,26 +1,87 @@
package ts
import (
// "unsafe"
"fmt"
"log"
"runtime"
"sync/atomic"
"github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
type Scalar struct {
name string
cscalar lib.Cscalar
}
// free releases C allocated memory.
func freeCScalar(x *Scalar) error {
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, -nbytes)
lock.Lock()
delete(ExistingScalars, x.name)
lock.Unlock()
lib.AtsFree(x.cscalar)
if err := TorchErr(); err != nil {
return err
}
if gotch.Debug {
log.Printf("INFO: Released scalar %q - C memory: %d bytes.\n", x.name, nbytes)
}
return nil
}
func newScalarName(nameOpt ...string) string {
var name string
if len(nameOpt) > 0 {
name = nameOpt[0]
} else {
name = fmt.Sprintf("tensor_%09d", TensorCount)
}
return name
}
func newScalar(cscalar lib.Cscalar, nameOpt ...string) *Scalar {
x := &Scalar{
cscalar: cscalar,
name: newScalarName(nameOpt...),
}
atomic.AddInt64(&ScalarCount, 1)
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, nbytes)
lock.Lock()
ExistingScalars[x.name] = struct{}{}
lock.Unlock()
if gotch.Debug {
log.Printf("INFO: scalar %q added - Allocated memory (%d bytes).\n", x.name, nbytes)
}
runtime.SetFinalizer(x, freeCScalar)
return x
}
func (sc *Scalar) nbytes() int64 {
return 8 // either an Int64 or a Float64 scalar -> 8 bytes
}
// IntScalar creates a integer scalar
func IntScalar(v int64) *Scalar {
cscalar := lib.AtsInt(v)
return &Scalar{cscalar}
return newScalar(cscalar)
}
// FloatScalar creates a float scalar
func FloatScalar(v float64) *Scalar {
cscalar := lib.AtsFloat(v)
return &Scalar{cscalar}
return newScalar(cscalar)
}
// ToInt returns a integer value
@ -61,14 +122,17 @@ func (sc *Scalar) ToString() (retVal string, err error) {
// TODO: Really? After running s.Drop(), s.ToInt() returns zero.
func (sc *Scalar) Drop() (err error) {
lib.AtsFree(sc.cscalar)
return TorchErr()
// TODO. FIXME either remove or rewind for specific scenario
return nil
// lib.AtsFree(sc.cscalar)
// return TorchErr()
}
func (sc *Scalar) MustDrop() {
lib.AtsFree(sc.cscalar)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
// TODO. FIXME either remove or rewind for specific scenario
return
// lib.AtsFree(sc.cscalar)
// if err := TorchErr(); err != nil {
// log.Fatal(err)
// }
}

File diff suppressed because it is too large

View File

@ -11,34 +11,194 @@ import (
"fmt"
"log"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
gotch "github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
type Tensor struct {
ctensor lib.Ctensor
}
var (
TensorCount int64 // incremental counting created tensors
ScalarCount int64 // incremental counting created scalars
AllocatedMem int64 // bytes - keeping track of memory created and still occupied by gotch/tensor (excluding mem allocated by libtorch at C side)
func (ts *Tensor) Ctensor() unsafe.Pointer {
return unsafe.Pointer(ts.ctensor)
}
ExistingTensors map[string]struct{} = make(map[string]struct{}) // keep track of existing tensors by name
ExistingScalars map[string]struct{} = make(map[string]struct{}) // keep track of existing scalar by name
lock sync.Mutex
)
// None is an undefined tensor.
// NOTE. None is an undefined tensor.
// It can be used for optional tensor parameters where a 'None' value is expected.
// The `ts.MustDefined()` method is used to check for 'null'.
var None = NewTensor()
type bigStruct struct {
lots [1e5]byte // 100k - always on host memory.
}
// Tensor is a Go wrapper of a "C tensor pointer" - 8 Bytes (64-bits OS)
// or 4 Bytes (32-bits OS).
// `ctensor` is just a "C pointer" to `torch::Tensor` (torch::Tensor *lib.Ctensor)
//
// NOTE. Tensor should be big enough to be in heap memory.
// (yes, we choose to place tensor consistently in heap memory so that
// it can be targeted by Go garbage collector).
//
// For heap allocation see. https://stackoverflow.com/questions/10866195
type Tensor struct {
d *bigStruct
name string
ctensor lib.Ctensor
calledFrom string
}
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
if len(nameOpt) == 0 {
nameOpt = []string{}
}
name := newName(nameOpt...)
x := new(Tensor)
x.ctensor = ctensor
x.d = new(bigStruct)
atomic.AddInt64(&TensorCount, 1)
nbytes := x.nbytes()
atomic.AddInt64(&AllocatedMem, nbytes)
lock.Lock()
if _, ok := ExistingTensors[name]; ok {
name = fmt.Sprintf("%s_%09d", name, TensorCount)
}
ExistingTensors[name] = struct{}{}
lock.Unlock()
x.name = name
if gotch.Debug {
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
}
x.calledFrom = "newTensor()"
runtime.SetFinalizer(x, freeCTensor)
return x
}
// New creates a new tensor from a C tensor.
func New(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
return newTensor(ctensor, nameOpt...)
}
func CheckCMemLeak() string {
tensors := []string{}
lock.Lock()
for n := range ExistingTensors {
tensors = append(tensors, n)
}
memUsed := AllocatedMem
lock.Unlock()
var msg string
msg += fmt.Sprintf("============================= C MEMORY CHECK RESULT ==================================\n")
msg += fmt.Sprintf("C memory allocated not been released: %v bytes\n", memUsed)
msg += fmt.Sprintf("Tensors not been released: %q\n", tensors)
msg += fmt.Sprintf("======================================================================================\n")
return msg
}
// CleanUp calls runtime.GC() twice with a sleep time in between.
func CleanUp(sleepTimeOpt ...int) {
sleepTime := time.Duration(1000) // 1 second
if len(sleepTimeOpt) > 0 {
sleepTime = time.Duration(sleepTimeOpt[0])
}
runtime.GC()
time.Sleep(time.Millisecond * sleepTime)
runtime.GC()
}
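A typical pattern with the finalizer-based tracking above: force collection between phases and inspect what is still held (a sketch):

```go
// ... heavy tensor work ...
ts.CleanUp()                    // two GC passes with a 1s pause by default
fmt.Println(ts.CheckCMemLeak()) // bytes still allocated and names of live tensors
```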
// Ctensor returns the C pointer value.
func (ts *Tensor) Ctensor() unsafe.Pointer {
return unsafe.Pointer(ts.ctensor)
}
// free releases C allocated memory.
func freeCTensor(ts *Tensor) error {
if ts == nil || ts.ctensor == nil {
return nil
}
lock.Lock()
defer lock.Unlock()
if _, ok := ExistingTensors[ts.name]; !ok {
log.Printf("WARNING: Probably double free tensor %q. Called from %q. Just skipping...\n", ts.name, ts.calledFrom)
return nil
}
if gotch.Debug {
shape, err := ts.Size()
if err != nil {
err = fmt.Errorf("ERROR: failed to release tensor %q: %w", ts.name, err)
log.Println(err)
return err
}
numel := uint(FlattenDim(shape))
dtype := ts.DType()
nbytes := int64(numel * dtype.Size())
atomic.AddInt64(&AllocatedMem, -nbytes)
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
}
lib.AtFree(ts.ctensor)
if err := TorchErr(); err != nil {
err := fmt.Errorf("ERROR: failed to release tensor %q - %w", ts.name, err)
return err
}
delete(ExistingTensors, ts.name)
// IMPORTANT. make it nil so won't double free.
ts.ctensor = nil
return nil
}
func newName(nameOpt ...string) string {
var name string
if len(nameOpt) > 0 {
name = nameOpt[0]
} else {
name = fmt.Sprintf("tensor_%09d", TensorCount)
}
return name
}
// NewTensor creates a new tensor
func NewTensor(nameOpt ...string) *Tensor {
ctensor := lib.AtNewTensor()
return newTensor(ctensor, nameOpt...)
}
// FromCtensor creates a tensor from a C tensor pointer.
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
cts := (lib.Ctensor)(ctensor)
return newTensor(cts)
}
// Name returns the tensor's name.
func (ts *Tensor) Name() string {
return ts.name
}
func (ts *Tensor) Dim() uint64 {
@ -56,6 +216,11 @@ func (ts *Tensor) Dim() uint64 {
// to that slice.
func (ts *Tensor) Size() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
if dim < 0 || dim > 100 {
err := fmt.Errorf("Invalid dim: %v\n", dim)
return nil, err
}
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
@ -82,6 +247,34 @@ func (ts *Tensor) MustSize() []int64 {
return shape
}
// Stride returns the stride (the number of elements, not bytes, to step in
// each dimension when traversing the tensor).
func (ts *Tensor) Stride() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
}
defer C.free(unsafe.Pointer(szPtr))
lib.AtStride(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
}
strides := decodeSize(szPtr, dim)
return strides, nil
}
// MustStride returns the strides of the tensor. It panics on error.
func (ts *Tensor) MustStride() []int64 {
strides, err := ts.Stride()
if err != nil {
log.Fatal(err)
}
return strides
}
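// For example, a freshly created contiguous tensor of shape [2, 3, 4] has
// strides [12, 4, 1]: advancing one step along dimension 0 skips 3*4 = 12
// elements, and so on.
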
// Size1 returns the tensor size for 1D tensors.
func (ts *Tensor) Size1() (int64, error) {
shape, err := ts.Size()
@ -142,16 +335,19 @@ func (ts *Tensor) Size4() ([]int64, error) {
return shape, nil
}
// nbytes calculates the tensor data size in bytes.
func (ts *Tensor) nbytes() int64 {
numel := ts.Numel()
if numel == 0 {
return 0 // ts.None
}
return int64(numel * ts.DType().Size())
}
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
dtype := gotch.Int64 // tensor size dtype = int64
nbytes := int(nsize) * int(dtype.Size())
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
r := bytes.NewReader(dataSlice)
dataIn := make([]int64, nsize)
@ -162,20 +358,54 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
return dataIn
}
// TensorOptions holds options used to build/rebuild a tensor.
type TensorOptions struct {
Name string
DType gotch.DType
Quantized bool
// TODO. can expand as needed
}
type TensorOpt func(*TensorOptions)
func DefaultTensorOptions() *TensorOptions {
return &TensorOptions{
Name: "",
DType: gotch.Float,
Quantized: false,
}
}
func WithName(v string) TensorOpt {
return func(o *TensorOptions) {
o.Name = v
}
}
func WithDType(v gotch.DType) TensorOpt {
return func(o *TensorOptions) {
o.DType = v
}
}
func WithQuantized(v bool) TensorOpt {
return func(o *TensorOptions) {
o.Quantized = v
}
}
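// A minimal usage sketch of the options above (assuming a float64 slice as input):
//
//	x, err := OfSlice([]float64{1, 2, 3}, WithName("myTensor"))
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(x.Name()) // "myTensor"
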
// OfSlice creates a tensor from slice data.
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
// Convert []int -> []int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
if reflect.TypeOf(data).String() == "[]int" {
data = sliceIntToInt32(data.([]int))
}
v := reflect.ValueOf(data)
kind := v.Kind().String()
if kind != "slice" && kind != "array" {
@ -183,10 +413,10 @@ func OfSlice(data interface{}) (*Tensor, error) {
return nil, err
}
elementKind := reflect.TypeOf(data).Elem().Kind()
dataLen := v.Len()
dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
if err != nil {
return nil, err
}
@ -194,12 +424,7 @@ func OfSlice(data interface{}) (*Tensor, error) {
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
nbytes := int(dtype.Size()) * elementNum
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
@ -208,29 +433,25 @@ func OfSlice(data interface{}) (*Tensor, error) {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, o.Name), nil
}
// OfDataSize creates a Tensor from input byte data, shape and dtype.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
elementNum := ElementCount(shape)
nbytes := elementNum * int(dtype.Size())
if nbytes != len(data) {
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
@ -244,23 +465,19 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err := TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, o.Name), nil
}
// MustOfDataSize creates a Tensor from input byte data with the specified
// shape and dtype. It panics on error.
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
ts, err := OfDataSize(data, size, dtype, opts...)
if err != nil {
log.Fatal(err)
}
@ -269,8 +486,8 @@ func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
}
// MustOfSlice creates a tensor from a slice of data. It panics on error.
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
log.Fatal(err)
}
@ -279,8 +496,8 @@ func MustOfSlice(data interface{}) *Tensor {
}
// TensorFrom creates a tensor from a slice of data. It panics on error.
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
ts, err := OfSlice(data, opts...)
if err != nil {
log.Fatal(err)
}
@ -299,7 +516,12 @@ func (ts *Tensor) Print() {
}
// NewTensorFromData creates a tensor from the given data and shape.
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
o := DefaultTensorOptions()
for _, opt := range opts {
opt(o)
}
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
@ -326,34 +548,18 @@ func NewTensorFromData(data interface{}, shape []int64) (*Tensor, error) {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
if err = TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, o.Name), nil
}
func (ts *Tensor) DType() gotch.DType {
cint := lib.AtScalarType(ts.ctensor)
return gotch.CKind2DType(cint)
}
func (ts *Tensor) Device() (gotch.Device, error) {
@ -409,6 +615,7 @@ func (ts *Tensor) MustDevice() gotch.Device {
* return retVal
* }
* */
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
@ -498,6 +705,15 @@ func (ts *Tensor) DataPtr() (unsafe.Pointer, error) {
return datPtr, nil
}
// MustDataPtr returns a pointer to the underlying tensor data. It panics on error.
func (ts *Tensor) MustDataPtr() unsafe.Pointer {
p, err := ts.DataPtr()
if err != nil {
panic(err)
}
return p
}
// Defined returns true if the tensor is defined.
func (ts *Tensor) Defined() (bool, error) {
state := lib.AtDefined(ts.ctensor)
@ -528,6 +744,52 @@ func (ts *Tensor) IsSparse() (bool, error) {
return state, nil
}
// MustIsSparse returns true if the tensor is sparse. It panics on error.
func (ts *Tensor) MustIsSparse() bool {
state, err := ts.IsSparse()
if err != nil {
log.Fatal(err)
}
return state
}
// IsContiguous returns true if the tensor is contiguous.
func (ts *Tensor) IsContiguous() (bool, error) {
state := lib.AtIsContiguous(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
// MustIsContiguous returns true if the tensor is contiguous. It panics on error.
func (ts *Tensor) MustIsContiguous() bool {
state, err := ts.IsContiguous()
if err != nil {
log.Fatal(err)
}
return state
}
// IsMkldnn returns true if the tensor is an MKL-DNN tensor.
func (ts *Tensor) IsMkldnn() (bool, error) {
state := lib.AtIsMkldnn(ts.ctensor)
if err := TorchErr(); err != nil {
return false, err
}
return state, nil
}
// MustIsMkldnn returns true if the tensor is an MKL-DNN tensor. It panics on error.
func (ts *Tensor) MustIsMkldnn() bool {
state, err := ts.IsMkldnn()
if err != nil {
log.Fatal(err)
}
return state
}
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
func (ts *Tensor) ZeroGrad() {
@ -558,7 +820,7 @@ func (ts *Tensor) MustBackward() {
}
// RunBackward runs the backward pass, computing gradients of the given
// tensors with respect to the given inputs.
func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGraphB bool) ([]*Tensor, error) {
// NOTE: outputs is a slice of tensors with length = len(inputs)
var outputsPtr []*lib.Ctensor
// Are they allocated contiguously??? Definitely not.
@ -599,10 +861,10 @@ func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraph
return nil, err
}
var oTensors []*Tensor
for i := 0; i < len(inputs); i++ {
outputPtr := outputsPtr[i]
oTensors = append(oTensors, newTensor(*outputPtr))
}
return oTensors, nil
@ -612,7 +874,6 @@ func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraph
//
// NOTE: `dst` located in Go memory. Should it be?
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
// NOTE: we must make sure that `dst` has at least `numel` elements.
// Otherwise there will be a memory leak and/or an out-of-range error.
if len(dst) < int(numel) {
@ -621,12 +882,9 @@ func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
}
vs := unsafe.Pointer(&dst[0])
dtype := gotch.Uint8
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
if err := TorchErr(); err != nil {
return err
}
@ -648,55 +906,20 @@ func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
// and number of elements to C land. This may break in the future
// if Go policy changes.
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
dtype, dlen, err := DataCheck(dst)
if err != nil {
return err
}
if dlen < int(numel) {
err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
return err
}
if ts.DType() != dtype {
err = fmt.Errorf("ts.CopyData() failed: type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
return err
}
// Get the data pointer of the destination slice.
dataPtr := reflect.ValueOf(dst).UnsafePointer()
lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
if err = TorchErr(); err != nil {
return err
}
@ -717,20 +940,24 @@ func (ts *Tensor) MustCopyData(dst interface{}, numel uint) {
// Numel returns the total number of elements stored in a tensor.
func (ts *Tensor) Numel() uint {
if !ts.MustDefined() {
return 0 // ts.None case
}
shape := ts.MustSize()
return uint(FlattenDim(shape))
}
// ShallowClone returns a new tensor that shares storage with the input tensor.
func (ts *Tensor) ShallowClone() (*Tensor, error) {
ctensor := lib.AtShallowClone(ts.ctensor)
if err := TorchErr(); err != nil {
return nil, err
}
name := fmt.Sprintf("%s_cloned", ts.name)
return newTensor(ctensor, name), nil
}
// MustShallowClone returns a new tensor that shares storage with the input
@ -752,7 +979,7 @@ func (ts *Tensor) Get(index int) (*Tensor, error) {
return nil, err
}
return newTensor(ctensor), nil
}
// MustGet gets the sub-tensor at the given index. It panics on error.
@ -804,19 +1031,19 @@ func (ts *Tensor) MustSave(path string) {
}
// Load loads a tensor from a file.
func Load(path string, nameOpt ...string) (*Tensor, error) {
ctensor := lib.AtLoad(path)
if err := TorchErr(); err != nil {
return nil, err
}
return newTensor(ctensor, nameOpt...), nil
}
// MustLoad loads a tensor from a file. It panics on error.
func MustLoad(path string, nameOpt ...string) *Tensor {
ts, err := Load(path, nameOpt...)
if err != nil {
log.Fatal(err)
}
@ -877,7 +1104,7 @@ func LoadMulti(path string) ([]NamedTensor, error) {
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
}
namedTensors = append(namedTensors, namedTensor)
@ -912,7 +1139,7 @@ func LoadMultiWithDevice(path string, device gotch.Device) ([]NamedTensor, error
for _, v := range data.NamedCtensors {
namedTensor := NamedTensor{
Name: v.Name,
Tensor: newTensor(v.Ctensor, v.Name),
}
namedTensors = append(namedTensors, namedTensor)
@ -959,18 +1186,22 @@ func (ts *Tensor) MustToString(lw int64) string {
// Drop drops (frees) the tensor.
func (ts *Tensor) Drop() error {
if ts.ctensor == nil {
return nil
}
// Clear the finalizer on ts so the tensor is not double freed.
// Ref. https://pkg.go.dev/runtime#SetFinalizer
runtime.SetFinalizer(ts, nil)
ts.calledFrom = "ts.Drop()"
return freeCTensor(ts)
}
// MustDrop drops the tensor. It panics on error.
func (ts *Tensor) MustDrop() {
if err := ts.Drop(); err != nil {
panic(err)
}
}
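// A minimal usage sketch (freeing C memory explicitly instead of waiting for
// the GC finalizer installed by newTensor()):
//
//	x := MustOfSlice([]float64{1, 2, 3})
//	defer x.MustDrop()
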
@ -1025,31 +1256,14 @@ func MustGradSetEnabled(b bool) bool {
}
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn func()) {
// Switch off Grad
MustGradSetEnabled(false)
fn()
// Switch on Grad
MustGradSetEnabled(true)
}
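// A minimal usage sketch (operations inside the closure are not tracked by
// autograd; assumes a tensor `x` created elsewhere):
//
//	NoGrad(func() {
//		y := x.MustMulScalar(FloatScalar(2.0), false)
//		y.MustDrop()
//	})
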
func NoGrad1(fn func() interface{}) interface{} {
@ -1140,7 +1354,7 @@ func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
float64Ts := ts.MustTotype(gotch.Double, false)
float64Ts.MustCopyData(vec, numel)
// float64Ts.MustDrop()
if del {
ts.MustDrop()
@ -1175,33 +1389,18 @@ func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
// E.g. res := xs.Vals().([]int64)
func (ts *Tensor) Vals() interface{} {
dtype := ts.DType()
numel := int(ts.Numel())
typ, err := dtype.GoType()
if err != nil {
log.Fatal(err)
}
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
ts.CopyData(dataSlice, uint(numel))
return dataSlice
}
// FlatView flattens a tensor.
@ -1292,7 +1491,8 @@ func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (re
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = newTensor(*ptr)
return retVal, err
}
@ -1306,3 +1506,51 @@ func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool)
return retVal
}
// TT. Added some torch.cuda APIs for handling CUDA devices.
// CudaCurrentDevice get device index of current CUDA device.
func CudaCurrentDevice() (int, error) {
currentDeviceIndex := lib.AtcGetDevice()
if err := TorchErr(); err != nil {
err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
return -99, err
}
return currentDeviceIndex, nil
}
// CudaSetDevice set new cuda device index and returns previous cuda index.
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
currentDeviceIndex, err := CudaCurrentDevice()
if err != nil {
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
return -99, err
}
lib.AtcSetDevice(cudaDeviceIndex)
if err := TorchErr(); err != nil {
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
return -99, err
}
return currentDeviceIndex, nil
}
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
var cudaDeviceIndex int
var err error
if len(cudaDeviceIndexOpt) > 0 {
cudaDeviceIndex = cudaDeviceIndexOpt[0]
} else {
cudaDeviceIndex, err = CudaCurrentDevice()
if err != nil {
err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
return err
}
}
lib.AtcSynchronize(int64(cudaDeviceIndex))
return TorchErr()
}

ts/tensor_mem_test.go Normal file

@ -0,0 +1,83 @@
package ts
import (
"fmt"
"log"
"math/rand"
"runtime"
"testing"
"time"
"github.com/sugarme/gotch"
)
var n int = 10
func newData() []float32 {
n := 3 * 224 * 224 * 12
data := make([]float32, n)
for i := 0; i < n; i++ {
data[i] = rand.Float32()
}
return data
}
func printMemStats(message string, rtm runtime.MemStats) {
fmt.Println("\n===", message, "===")
fmt.Println("Mallocs: ", rtm.Mallocs)
fmt.Println("Frees: ", rtm.Frees)
fmt.Println("LiveObjects: ", rtm.Mallocs-rtm.Frees)
fmt.Println("PauseTotalNs: ", rtm.PauseTotalNs)
fmt.Println("NumGC: ", rtm.NumGC)
fmt.Println("LastGC: ", time.UnixMilli(int64(rtm.LastGC/1_000_000)))
fmt.Println("HeapObjects: ", rtm.HeapObjects)
fmt.Println("HeapAlloc: ", rtm.HeapAlloc)
}
func TestMem(t *testing.T) {
var rtm runtime.MemStats
runtime.ReadMemStats(&rtm)
printMemStats("Start", rtm)
for i := 0; i < n; i++ {
x := MustOfSlice(newData())
log.Printf("created tensor : %q\n", x.Name())
}
runtime.ReadMemStats(&rtm)
printMemStats("After completing loop", rtm)
runtime.GC()
runtime.ReadMemStats(&rtm)
printMemStats("After forced GC", rtm)
fmt.Print(CheckCMemLeak())
}
func TestMem1(t *testing.T) {
var rtm runtime.MemStats
runtime.ReadMemStats(&rtm)
printMemStats("Start", rtm)
for i := 0; i < n; i++ {
x, err := Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
if err != nil {
panic(err)
}
log.Printf("created tensor : %q\n", x.Name())
time.Sleep(time.Millisecond * 3000) // 3secs
}
CleanUp()
runtime.ReadMemStats(&rtm)
printMemStats("After completing loop", rtm)
runtime.GC()
runtime.ReadMemStats(&rtm)
printMemStats("After forced GC", rtm)
fmt.Print(CheckCMemLeak())
}


@ -131,3 +131,61 @@ func TestOfSlice(t *testing.T) {
t.Errorf("Got dtype: %v\n", got)
}
}
func TestCudaCurrentDevice(t *testing.T) {
cudaIdx, err := ts.CudaCurrentDevice()
if err != nil {
t.Error(err)
}
t.Logf("current CUDA index: %v\n", cudaIdx) // should be 0 if having 1 GPU device
x := ts.MustZeros([]int64{2, 3, 4}, gotch.Float, gotch.CudaIfAvailable())
currentCudaIndex := x.MustDevice().Value
t.Logf("x current cuda index: %v\n", currentCudaIndex) // 0
previousCudaIndex, err := ts.CudaSetDevice(currentCudaIndex)
if err != nil {
t.Error(err)
}
t.Logf("Cuda index BEFORE set: %v\n", previousCudaIndex) // 0
cudaIdxAfter, err := ts.CudaCurrentDevice()
if err != nil {
t.Error(err)
}
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
}
func TestTensor_Stride(t *testing.T) {
shape := []int64{2, 3, 4}
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
got := x.MustStride()
want := []int64{12, 4, 1}
if !reflect.DeepEqual(want, got) {
t.Errorf("want %v, got %v\n", want, got)
}
}
func TestTensor_IsContiguous(t *testing.T) {
shape := []int64{2, 3, 4}
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
got := x.MustIsContiguous()
want := true
if !reflect.DeepEqual(want, got) {
t.Errorf("want %v, got %v\n", want, got)
}
}
func TestTensor_IsMkldnn(t *testing.T) {
shape := []int64{2, 3, 4}
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
got := x.MustIsMkldnn()
want := false
if !reflect.DeepEqual(want, got) {
t.Errorf("want %v, got %v\n", want, got)
}
}


@ -75,7 +75,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if err := w.WriteByte(b); err != nil {
return err
}
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
}
@ -86,14 +86,14 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if v.Kind() == reflect.Slice {
expected := int(shape[0])
if v.Len() != expected {
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
return fmt.Errorf("EncodeTensor() failed: mismatched slice lengths: %d and %d", v.Len(), expected)
}
}
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
if len(shape) == 1 && v.Len() > 0 {
switch v.Index(0).Kind() {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
return binary.Write(w, nativeEndian, v.Interface())
}
}
@ -107,7 +107,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
}
default:
return fmt.Errorf("unsupported type %v", v.Type())
return fmt.Errorf("EncodeTensor() failed: unsupported type %v", v.Type())
}
return nil
}
@ -122,7 +122,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
return err
}
ptr.Elem().SetBool(b == 1)
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}
@ -134,7 +134,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
if len(shape) == 1 && val.Len() > 0 {
switch val.Index(0).Kind() {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
return binary.Read(r, nativeEndian, val.Interface())
}
}
@ -152,10 +152,10 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
}
// ElementCount counts the number of elements in a tensor with the given shape.
func ElementCount(shape []int64) int {
n := 1
for _, d := range shape {
n *= int(d)
}
return n
}
@ -163,54 +163,48 @@ func ElementCount(shape []int64) int64 {
// DataDim returns the number of elements in data.
// NOTE: only scalars and (nested) slices/arrays of scalar types are supported.
func DataDim(data interface{}) (retVal int, err error) {
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
return count, err
}
// DataCheck checks the input data for its element dtype and number of elements.
// It will return an error if the element dtype is not supported.
func DataCheck(data interface{}) (dtype gotch.DType, n int, err error) {
return dataCheck(reflect.ValueOf(data).Interface(), 0)
}
func dataCheck(data interface{}, count int) (dtype gotch.DType, n int, err error) {
v := reflect.ValueOf(data)
var total int = count
var round = 0
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
if round == 0 {
round = v.Len()
}
for i := 0; i < v.Len(); i++ {
round--
dtype, total, err = dataCheck(v.Index(i).Interface(), total)
if err != nil {
return gotch.Invalid, 0, err
}
}
return dtype, total, nil
}
total += 1
dtype, err = gotch.GoKind2DType(v.Kind())
if err != nil {
err = fmt.Errorf("DataCheck() failed: unsupported data structure or type: %v\n", v.Kind())
return gotch.Invalid, 0, err
}
return dtype, total, nil
}
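// For example, DataCheck([][]float64{{1, 2}, {3, 4}}) reports
// (gotch.Double, 4, nil): nested slices are flattened when counting elements.
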
// DataAsPtr writes to C memory and returns a C pointer.
@ -219,25 +213,19 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
// Supported data types are scalar, slice/array of scalar type equivalent to
// DType.
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
// 1. Count number of elements in data
elementNum, err := DataDim(data)
if err != nil {
return nil, err
}
// 2. Number of bytes
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return nil, err
}
nbytes := int(dtype.Size()) * int(elementNum)
// 3. Get C pointer and prepare C memory buffer for writing
dataPtr, buff := CMalloc(nbytes)


@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
}
rgbTs := x.MustUnbind(-3, false)
r := rgbTs[0]
g := rgbTs[1]
b := rgbTs[2]
// This implementation closely follows the TF one:
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
@ -311,9 +311,9 @@ func adjustSaturation(x *ts.Tensor, sat float64) *ts.Tensor {
func rgb2HSV(x *ts.Tensor) *ts.Tensor {
rgbTs := x.MustUnbind(-3, false)
r := rgbTs[0]
g := rgbTs[1]
b := rgbTs[2]
// # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
// # src/libImaging/Convert.c#L330
@ -383,7 +383,7 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
h3 := h2.MustFmod(ts.FloatScalar(1.0), true) // delete h2
// torch.stack((h, s, maxc), dim=-3)
out := ts.MustStack([]*ts.Tensor{h3, s, maxC}, -3)
// Delete intermediate tensors
r.MustDrop()
@ -409,9 +409,9 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
func hsv2RGB(x *ts.Tensor) *ts.Tensor {
hsvTs := x.MustUnbind(-3, false)
h := hsvTs[0]
s := hsvTs[1]
v := hsvTs[2]
// i = torch.floor(h * 6.0)
i := h.MustMulScalar(ts.FloatScalar(6.0), false).MustFloor(true)
// f = (h * 6.0) - i
@ -448,12 +448,12 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
// a2 = torch.stack((t, v, v, q, p, p), dim=-3)
// a3 = torch.stack((p, p, t, v, v, q), dim=-3)
// a4 = torch.stack((a1, a2, a3), dim=-4)
a1 := ts.MustStack([]*ts.Tensor{v, q, p, p, t, v}, -3)
a2 := ts.MustStack([]*ts.Tensor{t, v, v, q, p, p}, -3)
a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1})
// Delete intermediate tensors
h.MustDrop()
@ -491,13 +491,13 @@ func adjustHue(x *ts.Tensor, hue float64) *ts.Tensor {
hsvImg := rgb2HSV(imgFl)
hsvTs := hsvImg.MustUnbind(-3, true)
h := hsvTs[0]
s := hsvTs[1]
v := hsvTs[2]
// h = (h + hue_factor) % 1.0
hAdj := h.MustAddScalar(ts.FloatScalar(hue), false).MustRemainder(ts.FloatScalar(1.0), true)
hsvAdj := ts.MustStack([]*ts.Tensor{hAdj, s, v}, -3)
imgHueAdj := hsv2RGB(hsvAdj)
@ -568,7 +568,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
dim := x.MustSize()
c := dim[0]
var chans []*ts.Tensor = make([]*ts.Tensor, c)
hNar := ts.NewNarrow(top, top+height)
wNar := ts.NewNarrow(left, left+width)
for i := 0; i < int(c); i++ {
@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
x2 := x1T.Idx(wNar)
x1T.MustDrop()
out := x2.MustT(true)
chans[i] = out
}
cropTs := ts.MustStack(chans, 0)
@ -728,7 +728,7 @@ func applyGridTransform(x, gridInput *ts.Tensor, mode string, fillValue []float6
// dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
// img = torch.cat((img, dummy), dim=1)
dummy := ts.MustOnes([]int64{img.MustSize()[0], 1, img.MustSize()[2], img.MustSize()[3]}, img.DType(), img.MustDevice())
imgCat := ts.MustCat([]*ts.Tensor{img, dummy}, 1)
dummy.MustDrop()
img.MustDrop()
@ -779,9 +779,9 @@ func applyGridTransform(x, gridInput *ts.Tensor, mode string, fillValue []float6
// (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
// Args:
// - startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
// “[top-left, top-right, bottom-right, bottom-left]“ of the original image.
// - endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
// “[top-left, top-right, bottom-right, bottom-left]“ of the transformed image.
// Returns:
// - octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
func perspectiveCoeff(startPoints, endPoints [][]int64) []float64 {
@ -929,7 +929,7 @@ func perspective(x *ts.Tensor, startPoints, endPoints [][]int64, mode string, fi
// Apply affine transformation on the image keeping image center invariant.
//
// If the image is a torch Tensor, it is expected
// to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
// Args:
// - img (Tensor): image to transform.
@ -940,8 +940,8 @@ func perspective(x *ts.Tensor, startPoints, endPoints [][]int64, mode string, fi
// If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
// the second value corresponds to a shear parallel to the y axis.
// - interpolation (InterpolationMode): Desired interpolation enum defined by
// :class:`torchvision.transforms.InterpolationMode`. Default is “InterpolationMode.NEAREST“.
// If input is Tensor, only “InterpolationMode.NEAREST“, “InterpolationMode.BILINEAR“ are supported.
// - fill (sequence or number, optional): Pixel fill value for the area outside the transformed
// image. If given a number, the value is used for all bands respectively.
func affine(img *ts.Tensor, angle float64, translations []int64, scale float64, shear []float64, interpolationMode string, fillValue []float64) *ts.Tensor {
@ -982,17 +982,19 @@ func affine(img *ts.Tensor, angle float64, translations []int64, scale float64,
// As it is explained in PIL.Image.rotate
// We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
// where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
//
// C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
// RSS is rotation with scale and shear matrix
// RSS(a, s, (sx, sy)) =
// = R(a) * S(s) * SHy(sy) * SHx(sx)
// = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
//   [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
//   [ 0                    , 0                                      , 1 ]
//
// where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
// SHx(s) = [1, -tan(s)]    and    SHy(s) = [1      , 0]
//          [0, 1      ]                    [-tan(s), 1]
//
// Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
func getInverseAffineMatrix(center []float64, angle float64, translate []float64, scale float64, shear []float64) []float64 {
@ -1345,11 +1347,11 @@ func equalize(img *ts.Tensor) *ts.Tensor {
}
// batched images
var images []*ts.Tensor
for i := 0; i < int(shape[0]); i++ {
x := img.MustSelect(0, int64(i), false)
o := equalizeSingleImage(x)
images = append(images, o)
x.MustDrop()
}
@ -1363,12 +1365,12 @@ func equalize(img *ts.Tensor) *ts.Tensor {
func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
dim := img.MustSize()
var scaledChans []*ts.Tensor = make([]*ts.Tensor, int(dim[0]))
for i := 0; i < int(dim[0]); i++ {
cTs := img.MustSelect(0, int64(i), false)
scaledChan := scaleChannel(cTs)
cTs.MustDrop()
scaledChans[i] = scaledChan
}
out := ts.MustStack(scaledChans, 0)


@ -83,8 +83,8 @@ func CFLoadDir(dir string) *Dataset {
testImages, testLabels := readFile(fmt.Sprintf("%v/test_batch.bin", dirAbs))
var trainImages []*ts.Tensor
var trainLabels []*ts.Tensor
trainFiles := []string{
"data_batch_1.bin",
@ -96,8 +96,8 @@ func CFLoadDir(dir string) *Dataset {
for _, f := range trainFiles {
img, l := readFile(fmt.Sprintf("%v/%v", dirAbs, f))
trainImages = append(trainImages, img)
trainLabels = append(trainLabels, l)
}
return &Dataset{


@ -39,7 +39,7 @@ func (l *denseLayer) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
ys := ys5.Apply(l.Conv2)
ys5.MustDrop()
res := ts.MustCat([]*ts.Tensor{xs, ys}, 1)
ys.MustDrop()
return res


@ -212,7 +212,7 @@ func LoadAndResize(path string, outW int64, outH int64) (*ts.Tensor, error) {
// LoadDir loads all the images in a directory.
func LoadDir(dir string, outW int64, outH int64) (*ts.Tensor, error) {
var filePaths []string // "dir/filename.ext"
var tensors []*ts.Tensor
files, err := ioutil.ReadDir(dir)
if err != nil {
err = fmt.Errorf("LoadDir - Read directory error: %v\n", err)
@ -228,7 +228,7 @@ func LoadDir(dir string, outW int64, outH int64) (*ts.Tensor, error) {
err = fmt.Errorf("LoadDir - LoadAndResize method call error: %v\n", err)
return nil, err
}
tensors = append(tensors, tensor)
}
stackedTs, err := ts.Stack(tensors, 0)


@ -169,7 +169,7 @@ func (in *ImageNet) hasSuffix(path string) bool {
}
func (in *ImageNet) loadImageFromDir(dir string) (*ts.Tensor, error) {
var images []*ts.Tensor
files, err := ioutil.ReadDir(dir)
if err != nil {
@ -188,7 +188,7 @@ func (in *ImageNet) loadImageFromDir(dir string) (*ts.Tensor, error) {
return nil, err
}
images = append(images, img)
}
if len(images) == 0 {
@ -243,10 +243,10 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
// fmt.Printf("Classess: %v\n", classes)
var (
trainImages []*ts.Tensor
trainLabels []*ts.Tensor
testImages []*ts.Tensor
testLabels []*ts.Tensor
)
for labelIdx, labelDir := range classes {
@ -261,10 +261,10 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
}
ntrainTs := trainTs.MustSize()[0]
trainImages = append(trainImages, trainTs)
trainLabelOnes := ts.MustOnes([]int64{ntrainTs}, gotch.Int64, gotch.CPU)
trainLabels = append(trainLabels, trainLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
// test
testDir := fmt.Sprintf("%v/%v", validPath, labelDir)
@ -274,10 +274,10 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
return nil, err
}
ntestTs := testTs.MustSize()[0]
testImages = append(testImages, testTs)
testLabelOnes := ts.MustOnes([]int64{ntestTs}, gotch.Int64, gotch.CPU)
testLabels = append(testLabels, testLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
}
trainImageTs := ts.MustCat(trainImages, 0)
@ -298,7 +298,7 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
}, nil
}
func dropTsSlice(tensors []*ts.Tensor) {
for i := 0; i < len(tensors); i++ {
tensors[i].MustDrop()
}


@ -81,7 +81,7 @@ func inceptionA(p *nn.Path, cIn, cPool int64) ts.ModuleT {
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
bpoolTs := bpoolTmp.ApplyT(bpool, train)
res := ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
return res
})
@ -104,7 +104,7 @@ func inceptionB(p *nn.Path, cIn int64) ts.ModuleT {
bpoolTs := inMaxPool2D(xs, 3, 2)
res := ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1)
return res
})
@ -148,7 +148,7 @@ func inceptionC(p *nn.Path, cIn int64, c7 int64) ts.ModuleT {
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
bpoolTs := bpTmp1.ApplyT(bpool, train)
return ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
})
}
@ -177,7 +177,7 @@ func inceptionD(p *nn.Path, cIn int64) ts.ModuleT {
bpoolTs := inMaxPool2D(xs, 3, 2)
return ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1)
})
}
@ -202,19 +202,19 @@ func inceptionE(p *nn.Path, cIn int64) ts.ModuleT {
b2Tmp := xs.ApplyT(b21, train)
b2aTs := b2Tmp.ApplyT(b22a, train)
b2bTs := b2Tmp.ApplyT(b22b, train)
b2Ts := ts.MustCat([]*ts.Tensor{b2aTs, b2bTs}, 1)
b3Tmp1 := xs.ApplyT(b31, train)
b3Tmp2 := b3Tmp1.ApplyT(b32, train)
b3Tmp1.MustDrop()
b3aTs := b3Tmp2.ApplyT(b33a, train)
b3bTs := b3Tmp2.ApplyT(b33b, train)
b3Ts := ts.MustCat([]*ts.Tensor{b3aTs, b3bTs}, 1)
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
bpoolTs := bpTmp1.ApplyT(bpool, train)
return ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
})
}


@ -31,7 +31,7 @@ func fire(p *nn.Path, cIn int64, cSqueeze int64, cExp1 int64, cExp3 int64) ts.Mo
exp3Tmp := tmp2.Apply(exp3)
exp3Ts := exp3Tmp.MustRelu(true)
return ts.MustCat([]*ts.Tensor{exp1Ts, exp3Ts}, 1)
})
}