commit
83394ef093
|
@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
- reworked `ts.Format()`
|
||||
- Added conv2d benchmark
|
||||
- Fixed #88 memory leak at `example/char-rnn`
|
||||
- Added missing tensor `Stride()` and `MustDataPtr()`, `IsMkldnn`, `MustIsMkldnn`, `IsContiguous`, `MustIsContiguous`
|
||||
- Added ts `New()`
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
27
README.md
27
README.md
|
@ -16,16 +16,16 @@
|
|||
|
||||
`gotch` is in active development mode and may have API breaking changes. Feel free to pull request, report issues or discuss any concerns. All contributions are welcome.
|
||||
|
||||
`gotch` current version is **v0.7.0**
|
||||
`gotch` current version is **v0.8.0**
|
||||
|
||||
## Dependencies
|
||||
|
||||
- **Libtorch** C++ v1.11.0 library of [Pytorch](https://pytorch.org/)
|
||||
- **Libtorch** C++ v2.0.1 library of [Pytorch](https://pytorch.org/)
|
||||
|
||||
## Installation
|
||||
|
||||
- Default CUDA version is `11.3` if CUDA is available otherwise using CPU version.
|
||||
- Default Pytorch C++ API version is `1.11.0`
|
||||
- Default CUDA version is `11.7` if CUDA is available otherwise using CPU version.
|
||||
- Default Pytorch C++ API version is `2.0.1`
|
||||
|
||||
**NOTE**: `libtorch` will be installed at **`/usr/local/lib`**
|
||||
|
||||
|
@ -53,7 +53,7 @@
|
|||
```bash
|
||||
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-gotch.sh
|
||||
chmod +x setup-gotch.sh
|
||||
export CUDA_VER=cpu && export GOTCH_VER=v0.7.0 && bash setup-gotch.sh
|
||||
export CUDA_VER=cpu && export GOTCH_VER=v0.8.0 && bash setup-gotch.sh
|
||||
```
|
||||
|
||||
### GPU
|
||||
|
@ -66,19 +66,10 @@
|
|||
|
||||
#### Step 1: Setup libtorch (skip this step if a valid libtorch already installed in your machine!)
|
||||
|
||||
**IMPORTANT NOTE FOR CUDA 11.1**:
|
||||
- Pytorch has not provided `libtorch-1.11` for CUDA 11.1 yet
|
||||
- If you have CUDA 11.1 installed in your machine and try to install `libtorch-1.11` for CUDA 11.3, you might have [linking issue here](https://github.com/pytorch/pytorch/issues/73829)
|
||||
- Download and install [nightly libtorch 1.11 for CUDA 11.1](https://download.pytorch.org/libtorch/nightly/cu113/libtorch-cxx11-abi-shared-with-deps-latest.zip) will help `gotch` compiled successfully.
|
||||
|
||||
```bash
|
||||
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-libtorch.sh
|
||||
chmod +x setup-libtorch.sh
|
||||
|
||||
# CUDA 10.2
|
||||
export CUDA_VER=10.2 && bash setup-libtorch.sh
|
||||
# CUDA 11.3
|
||||
export CUDA_VER=11.3 && bash setup-libtorch.sh
|
||||
export CUDA_VER=11.7 && bash setup-libtorch.sh
|
||||
```
|
||||
|
||||
**Update Environment**: in Debian/Ubuntu, add/update the following lines to `.bashrc` file
|
||||
|
@ -95,10 +86,8 @@
|
|||
```bash
|
||||
wget https://raw.githubusercontent.com/sugarme/gotch/master/setup-gotch.sh
|
||||
chmod +x setup-gotch.sh
|
||||
# CUDA 10.2
|
||||
export CUDA_VER=10.2 && export GOTCH_VER=v0.7.0 && bash setup-gotch.sh
|
||||
# CUDA 11.3
|
||||
export CUDA_VER=11.3 && export GOTCH_VER=v0.7.0 && bash setup-gotch.sh
|
||||
# CUDA 11.7
|
||||
export CUDA_VER=11.7 && export GOTCH_VER=v0.8.0 && bash setup-gotch.sh
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
|
651
dtype.go
651
dtype.go
|
@ -3,8 +3,6 @@ package gotch
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
// "log"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
|
@ -12,151 +10,148 @@ import (
|
|||
type CInt = int32
|
||||
|
||||
// DType represents different kind of element that a tensor can hold.
|
||||
// It has an embedded `reflect.Type` for type reflection.
|
||||
type DType struct {
|
||||
reflect.Type
|
||||
}
|
||||
// Ref. https://github.com/pytorch/pytorch/blob/a290cbf32b0c282aa60fa521ca5c6cd19c7f779f/c10/core/ScalarType.h
|
||||
type DType int
|
||||
|
||||
/*
|
||||
* // Custom-made Float16 as not exist in Go
|
||||
* // Ref: https://github.com/golang/go/issues/32022
|
||||
* type GoFloat16 = int16 // not implemented yet
|
||||
* type GoComplexHalf = interface{} // not implemented yet!
|
||||
* */
|
||||
|
||||
// TODO: double check these Torch DType to Go type
|
||||
var (
|
||||
Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0
|
||||
Int8 DType = DType{reflect.TypeOf(int8(1))} // 1
|
||||
Int16 DType = DType{reflect.TypeOf(int16(1))} // 2
|
||||
Int DType = DType{reflect.TypeOf(int32(1))} // 3
|
||||
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
|
||||
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
|
||||
Half DType = DType{reflect.TypeOf(float32(1))} // 5
|
||||
Float DType = DType{reflect.TypeOf(float32(1))} // 6
|
||||
Double DType = DType{reflect.TypeOf(float64(1))} // 7
|
||||
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
|
||||
// ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9
|
||||
// ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10
|
||||
Bool DType = DType{reflect.TypeOf(true)} // 11
|
||||
const (
|
||||
Invalid DType = -1
|
||||
Uint8 DType = 0
|
||||
Int8 DType = 1
|
||||
Int16 DType = 2
|
||||
Int DType = 3
|
||||
Int64 DType = 4
|
||||
Half DType = 5
|
||||
Float DType = 6
|
||||
Double DType = 7
|
||||
ComplexHalf DType = 8
|
||||
ComplexFloat DType = 9
|
||||
ComplexDouble DType = 10
|
||||
Bool DType = 11
|
||||
QInt8 DType = 12
|
||||
QUInt8 DType = 13
|
||||
QInt32 DType = 14
|
||||
BFloat16 DType = 15
|
||||
// ---not implemented ---
|
||||
QUInt4x2 DType = 16
|
||||
QUInt2x4 DType = 17
|
||||
Bits1x8 DType = 18
|
||||
Bits2x4 DType = 19
|
||||
Bits4x2 DType = 20
|
||||
Bits8 DType = 21
|
||||
Bits16 DType = 22
|
||||
)
|
||||
|
||||
var dtypeGoType = map[DType]reflect.Type{
|
||||
Uint8: reflect.TypeOf(uint8(1)),
|
||||
Int8: reflect.TypeOf(int8(1)),
|
||||
Int16: reflect.TypeOf(int16(1)),
|
||||
Int: reflect.TypeOf(int32(1)),
|
||||
Int64: reflect.TypeOf(int64(1)),
|
||||
Half: reflect.TypeOf(float32(1)),
|
||||
Float: reflect.TypeOf(float32(1)),
|
||||
Double: reflect.TypeOf(float64(1)),
|
||||
Bool: reflect.TypeOf(true),
|
||||
var dtype2CKind = map[DType]CInt{
|
||||
Uint8: 0,
|
||||
Int8: 1,
|
||||
Int16: 2,
|
||||
Int: 3,
|
||||
Int64: 4,
|
||||
Half: 5,
|
||||
Float: 6,
|
||||
Double: 7,
|
||||
ComplexHalf: 8,
|
||||
ComplexFloat: 9,
|
||||
ComplexDouble: 10,
|
||||
Bool: 11,
|
||||
QInt8: 12,
|
||||
QUInt8: 13,
|
||||
QInt32: 14,
|
||||
BFloat16: 15,
|
||||
// ---not implemented ---
|
||||
QUInt4x2: 16,
|
||||
QUInt2x4: 17,
|
||||
Bits1x8: 18,
|
||||
Bits2x4: 19,
|
||||
Bits4x2: 20,
|
||||
Bits8: 21,
|
||||
Bits16: 22,
|
||||
}
|
||||
|
||||
// ToDType infers and returns supported equivalent DType from given Go type
|
||||
func ToDType(typ reflect.Type) (retVal DType, err error) {
|
||||
var found = false
|
||||
for key, val := range dtypeGoType {
|
||||
if val == typ {
|
||||
retVal = key
|
||||
found = true
|
||||
break
|
||||
}
|
||||
func (dt DType) CKind() CInt {
|
||||
if cint, ok := dtype2CKind[dt]; ok {
|
||||
return cint
|
||||
}
|
||||
|
||||
if !found {
|
||||
err = fmt.Errorf("Unsupported Go type: %v", typ)
|
||||
return DType{}, err
|
||||
if Debug {
|
||||
log.Printf("WARNING: dt.CKind() failed: no corresponding CKind to this DType %v\n", dt)
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
return -1 // invalid
|
||||
}
|
||||
|
||||
// ToGoType infers and returns supported equivalent Go type from given DType
|
||||
func ToGoType(dtype DType) (retVal reflect.Type, err error) {
|
||||
if _, ok := dtypeGoType[dtype]; !ok {
|
||||
err = fmt.Errorf("Unsupported DType %v", dtype)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retVal = dtypeGoType[dtype]
|
||||
|
||||
return retVal, nil
|
||||
// Back compat
|
||||
func (dt DType) CInt() CInt {
|
||||
return dt.CKind()
|
||||
}
|
||||
|
||||
var dtypeCInt = map[DType]CInt{
|
||||
Uint8: 0,
|
||||
Int8: 1,
|
||||
Int16: 2,
|
||||
Int: 3,
|
||||
Int64: 4,
|
||||
Half: 5,
|
||||
Float: 6,
|
||||
Double: 7,
|
||||
Bool: 11,
|
||||
var ckind2DType map[CInt]DType = map[CInt]DType{
|
||||
0: Uint8,
|
||||
1: Int8,
|
||||
2: Int16,
|
||||
3: Int,
|
||||
4: Int64,
|
||||
5: Half,
|
||||
6: Float,
|
||||
7: Double,
|
||||
8: ComplexHalf,
|
||||
9: ComplexFloat,
|
||||
10: ComplexDouble,
|
||||
11: Bool,
|
||||
12: QInt8,
|
||||
13: QUInt8,
|
||||
14: QInt32,
|
||||
15: BFloat16,
|
||||
// ---not implemented ---
|
||||
16: QUInt4x2,
|
||||
17: QUInt2x4,
|
||||
18: Bits1x8,
|
||||
19: Bits2x4,
|
||||
20: Bits4x2,
|
||||
21: Bits8,
|
||||
22: Bits16,
|
||||
}
|
||||
|
||||
func DType2CInt(dt DType) (retVal CInt, err error) {
|
||||
if _, ok := dtypeCInt[dt]; !ok {
|
||||
err = fmt.Errorf("Unsupported CInt conversion from DType: %v\n", dt)
|
||||
func CKind2DType(ckind int32) DType {
|
||||
if dtype, ok := ckind2DType[ckind]; ok {
|
||||
return dtype
|
||||
}
|
||||
|
||||
retVal = dtypeCInt[dt]
|
||||
|
||||
return retVal, nil
|
||||
if Debug {
|
||||
log.Printf("WARNING: CKind2DType() failed: no corresponding DType to input CInt %v\n", ckind)
|
||||
}
|
||||
return -1 // invalid
|
||||
}
|
||||
|
||||
func (dt DType) CInt() (retVal CInt) {
|
||||
retVal, err := DType2CInt(dt)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
var dtypeSize map[DType]uint = map[DType]uint{
|
||||
Uint8: 1,
|
||||
Int8: 1,
|
||||
Int16: 2,
|
||||
Int: 4,
|
||||
Int64: 8,
|
||||
Half: 2,
|
||||
Float: 4,
|
||||
Double: 8,
|
||||
ComplexHalf: 4,
|
||||
ComplexFloat: 8,
|
||||
ComplexDouble: 16,
|
||||
Bool: 1,
|
||||
QInt8: 1,
|
||||
QUInt8: 1,
|
||||
QInt32: 4,
|
||||
BFloat16: 2,
|
||||
QUInt4x2: 2,
|
||||
QUInt2x4: 1,
|
||||
// ---not implemented ---
|
||||
Bits1x8: 1,
|
||||
Bits2x4: 1,
|
||||
Bits4x2: 1,
|
||||
Bits8: 1,
|
||||
Bits16: 2,
|
||||
}
|
||||
|
||||
func CInt2DType(v CInt) (dtype DType, err error) {
|
||||
var found = false
|
||||
for key, val := range dtypeCInt {
|
||||
if val == v {
|
||||
dtype = key
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
err = fmt.Errorf("Unsuported DType for CInt %v\n", v)
|
||||
return DType{}, err
|
||||
}
|
||||
|
||||
return dtype, nil
|
||||
|
||||
}
|
||||
|
||||
// dtypeSize is a map of DType and its size in Bytes
|
||||
var dtypeSize = map[DType]uint{
|
||||
Uint8: 1,
|
||||
Int8: 1,
|
||||
Int16: 2,
|
||||
Int: 4,
|
||||
Int64: 8,
|
||||
Half: 4, // Should it be?
|
||||
Float: 4,
|
||||
Double: 8,
|
||||
Bool: 1,
|
||||
}
|
||||
|
||||
// DTypeSize returns DType size in Bytes
|
||||
func DTypeSize(dt DType) (retVal uint, err error) {
|
||||
if _, ok := dtypeSize[dt]; !ok {
|
||||
err = fmt.Errorf("Unsupported conversion DType size in Byte for DType: %v\n", dt)
|
||||
return 99, err
|
||||
}
|
||||
|
||||
retVal = dtypeSize[dt]
|
||||
|
||||
return retVal, nil
|
||||
// Size returns dtype size in Bytes.
|
||||
func (dt DType) Size() uint {
|
||||
return dtypeSize[dt]
|
||||
}
|
||||
|
||||
type DTypeDevice struct {
|
||||
|
@ -174,201 +169,255 @@ var (
|
|||
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
|
||||
)
|
||||
|
||||
// Type Inferring:
|
||||
// ===============
|
||||
|
||||
// DTypeFromData infers returns equavalent DType from given data
|
||||
func DTypeFromData(data interface{}) (retVal DType, err error) {
|
||||
|
||||
// NOTE: call `Interface()` to get data type back to interface{} type
|
||||
typ, _, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
if typ.Kind() == reflect.Slice {
|
||||
return ToDType(typ.Elem())
|
||||
}
|
||||
|
||||
return ToDType(typ)
|
||||
var dtype2GoKind map[DType]reflect.Kind = map[DType]reflect.Kind{
|
||||
Uint8: reflect.Uint8,
|
||||
Int8: reflect.Int8,
|
||||
Int16: reflect.Int16,
|
||||
Int: reflect.Int32,
|
||||
Int64: reflect.Int64,
|
||||
Half: reflect.Uint16, // <- just uint16
|
||||
Float: reflect.Float32,
|
||||
Double: reflect.Float64,
|
||||
ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
||||
ComplexFloat: reflect.Complex64,
|
||||
ComplexDouble: reflect.Complex128,
|
||||
Bool: reflect.Bool,
|
||||
QInt8: reflect.Int8,
|
||||
QUInt8: reflect.Uint8,
|
||||
QInt32: reflect.Int32,
|
||||
BFloat16: reflect.Uint16, // <- just uint16
|
||||
// ---not implemented ---
|
||||
QUInt4x2: reflect.Invalid,
|
||||
QUInt2x4: reflect.Invalid,
|
||||
Bits1x8: reflect.Invalid,
|
||||
Bits2x4: reflect.Invalid,
|
||||
Bits4x2: reflect.Invalid,
|
||||
Bits8: reflect.Invalid,
|
||||
Bits16: reflect.Invalid,
|
||||
}
|
||||
|
||||
// NOTE: 0 is reflect.Kind() of Invalid
|
||||
// See: https://golang.org/pkg/reflect/#Kind
|
||||
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||
v := reflect.ValueOf(data)
|
||||
var goType reflect.Type = reflect.TypeOf(data)
|
||||
var total int = count
|
||||
var round = 0
|
||||
func (dt DType) GoKind() reflect.Kind {
|
||||
if kind, ok := dtype2GoKind[dt]; ok && kind != reflect.Invalid {
|
||||
return kind
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if round == 0 {
|
||||
round = v.Len()
|
||||
if Debug {
|
||||
log.Printf("WARNING: DType.GoKind() failed: no corresponding Go reflect.Kind to given DType %v\n", dt)
|
||||
}
|
||||
|
||||
return reflect.Invalid
|
||||
}
|
||||
|
||||
const (
|
||||
// NOTE. reflect.Kind 0-26
|
||||
QUInt8Kind reflect.Kind = 27
|
||||
QInt8Kind reflect.Kind = 28
|
||||
QInt32Kind reflect.Kind = 29
|
||||
Float16Kind reflect.Kind = 30
|
||||
BFloat16Kind reflect.Kind = 31
|
||||
QUInt4x2Kind reflect.Kind = 32
|
||||
QUInt2x4Kind reflect.Kind = 33
|
||||
Bits1x8Kind reflect.Kind = 34
|
||||
Bits2x4Kind reflect.Kind = 35
|
||||
Bits4x2Kind reflect.Kind = 36
|
||||
Bits8Kind reflect.Kind = 37
|
||||
Bits16Kind reflect.Kind = 38
|
||||
ComplexHalfKind reflect.Kind = 39
|
||||
)
|
||||
|
||||
var goKind2DType map[reflect.Kind]DType = map[reflect.Kind]DType{
|
||||
reflect.Uint8: Uint8,
|
||||
reflect.Int8: Int8,
|
||||
reflect.Int16: Int16,
|
||||
reflect.Int32: Int,
|
||||
reflect.Int64: Int64,
|
||||
reflect.Float32: Float,
|
||||
reflect.Float64: Double,
|
||||
reflect.Complex64: ComplexFloat,
|
||||
reflect.Complex128: ComplexDouble,
|
||||
reflect.Bool: Bool,
|
||||
reflect.Uint16: Half,
|
||||
|
||||
// Added Kinds
|
||||
QUInt8Kind: QUInt8,
|
||||
QInt8Kind: QInt8,
|
||||
QInt32Kind: QInt32,
|
||||
// Float16Kind: Half,
|
||||
BFloat16Kind: BFloat16,
|
||||
QUInt4x2Kind: QUInt4x2,
|
||||
QUInt2x4Kind: QUInt2x4,
|
||||
Bits1x8Kind: Bits1x8,
|
||||
Bits2x4Kind: Bits2x4,
|
||||
Bits4x2Kind: Bits4x2,
|
||||
Bits8Kind: Bits8,
|
||||
Bits16Kind: Bits16,
|
||||
ComplexHalfKind: ComplexHalf,
|
||||
}
|
||||
|
||||
type DTypeOptions struct {
|
||||
HalfDTypePref DType
|
||||
Quantized bool
|
||||
}
|
||||
|
||||
type DTypeOpt func(*DTypeOptions)
|
||||
|
||||
func DefaultDTypeOptions() *DTypeOptions {
|
||||
return &DTypeOptions{
|
||||
HalfDTypePref: Half,
|
||||
Quantized: false,
|
||||
}
|
||||
}
|
||||
|
||||
func HalfDTypePref(v DType) DTypeOpt {
|
||||
if v != Half && v != BFloat16 {
|
||||
if Debug {
|
||||
log.Printf("WARNING: HalfDTypePref(): Ignoring invalid HalfDTypePref. HalfDTypePref either 'gotch.Half' or 'gotch.BFloat16'. Got %v\n", v)
|
||||
}
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
round--
|
||||
goType, total, err = dataCheck(v.Index(i).Interface(), total)
|
||||
|
||||
if err != nil {
|
||||
return reflect.TypeOf(reflect.Zero), 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return goType, total, nil
|
||||
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
total++
|
||||
if goType.String() != "invalid" {
|
||||
goType = v.Type()
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
|
||||
return reflect.TypeOf(reflect.Zero), 0, err
|
||||
}
|
||||
|
||||
return goType, total, nil
|
||||
}
|
||||
|
||||
// ElementGoType infers and returns Go type of element in given data
|
||||
func ElementGoType(data interface{}) (retVal reflect.Type, err error) {
|
||||
dataValue := reflect.ValueOf(data)
|
||||
return elementType(dataValue)
|
||||
}
|
||||
|
||||
func elementType(data reflect.Value) (dataType reflect.Type, err error) {
|
||||
dataKind := data.Kind()
|
||||
switch dataKind {
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
dataType = data.Type()
|
||||
case reflect.Slice, reflect.Array:
|
||||
data = data.Elem()
|
||||
dataType, err = elementType(data) // recursively type inferring
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
|
||||
return DType{}, err
|
||||
return func(o *DTypeOptions) {
|
||||
o.HalfDTypePref = v
|
||||
}
|
||||
|
||||
return dataType, nil
|
||||
}
|
||||
|
||||
// DataDType infers and returns data type of tensor data
|
||||
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
|
||||
// assuming that all elements in data have the same type
|
||||
switch {
|
||||
case len(shape) == 0:
|
||||
retVal, err = ElementDType(v)
|
||||
case len(shape) > 0:
|
||||
return ElementDType(v.([]interface{})[0])
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||
return DType{}, err
|
||||
func WithQuantized(v bool) DTypeOpt {
|
||||
return func(o *DTypeOptions) {
|
||||
o.Quantized = v
|
||||
}
|
||||
return DType{}, nil
|
||||
}
|
||||
|
||||
// ElementDType infers and returns its own tensor data type
|
||||
func ElementDType(v interface{}) (retVal DType, err error) {
|
||||
switch v.(type) {
|
||||
case uint8:
|
||||
retVal = Uint8
|
||||
case int8:
|
||||
retVal = Int8
|
||||
case int16:
|
||||
retVal = Int16
|
||||
case int32:
|
||||
retVal = Int
|
||||
case int64:
|
||||
retVal = Int64
|
||||
case float32:
|
||||
retVal = Float
|
||||
case float64:
|
||||
retVal = Double
|
||||
case bool:
|
||||
retVal = Bool
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// TypeOf infers and returns element Go type from given tensor DType and shape
|
||||
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
||||
var typ reflect.Type
|
||||
if typ, err = ToGoType(dt); err != nil {
|
||||
return nil, err
|
||||
func GoKind2DType(kind reflect.Kind, opts ...DTypeOpt) (DType, error) {
|
||||
o := DefaultDTypeOptions()
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(shape) == 0:
|
||||
return typ, nil
|
||||
case len(shape) > 0:
|
||||
return reflect.SliceOf(typ), nil
|
||||
case kind == reflect.Uint16 && o.HalfDTypePref == Half:
|
||||
return Half, nil
|
||||
case kind == reflect.Uint16 && o.HalfDTypePref == BFloat16:
|
||||
return BFloat16, nil
|
||||
case kind == reflect.Int8 && o.Quantized:
|
||||
return QInt8, nil
|
||||
case kind == reflect.Uint8 && o.Quantized:
|
||||
return QUInt8, nil
|
||||
case kind == reflect.Int32 && o.Quantized:
|
||||
return QInt32, nil
|
||||
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported data type.")
|
||||
return nil, err
|
||||
dtype, ok := goKind2DType[kind]
|
||||
if !ok {
|
||||
err := fmt.Errorf("GoKind2DType() failed: no corresponding DType to given Go reflect.Kind %v\n", kind)
|
||||
return Invalid, err
|
||||
}
|
||||
return dtype, nil
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* // TypeCheck checks whether data Go type matching DType
|
||||
* func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
||||
* dataValue := reflect.ValueOf(data)
|
||||
* var dataType reflect.Type
|
||||
* var err error
|
||||
* dataType, err = elementType(dataValue)
|
||||
* if err != nil {
|
||||
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||
* msg += err.Error()
|
||||
* return false, msg
|
||||
* }
|
||||
*
|
||||
* matched = dataType == dtype.Type
|
||||
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||
*
|
||||
* return matched, msg
|
||||
* }
|
||||
* */
|
||||
|
||||
var supportedTypes = map[reflect.Kind]bool{
|
||||
reflect.Uint8: true,
|
||||
reflect.Int8: true,
|
||||
reflect.Int16: true,
|
||||
reflect.Int32: true,
|
||||
reflect.Int64: true,
|
||||
reflect.Float32: true,
|
||||
reflect.Float64: true,
|
||||
reflect.Bool: true,
|
||||
var dtype2GoType map[DType]reflect.Type = map[DType]reflect.Type{
|
||||
Uint8: reflect.TypeOf(uint8(0)),
|
||||
Int8: reflect.TypeOf(int8(0)),
|
||||
Int16: reflect.TypeOf(int16(0)),
|
||||
Int: reflect.TypeOf(int(0)),
|
||||
Int64: reflect.TypeOf(int64(0)),
|
||||
Half: reflect.TypeOf(uint16(0)), // <- just uint16
|
||||
Float: reflect.TypeOf(float32(0)),
|
||||
Double: reflect.TypeOf(float64(0)),
|
||||
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
||||
ComplexFloat: reflect.TypeOf(complex64(0)),
|
||||
ComplexDouble: reflect.TypeOf(complex128(0)),
|
||||
Bool: reflect.TypeOf(true),
|
||||
QInt8: reflect.TypeOf(int8(0)),
|
||||
QUInt8: reflect.TypeOf(uint8(0)),
|
||||
QInt32: reflect.TypeOf(int32(0)),
|
||||
BFloat16: reflect.TypeOf(uint16(0)), // <- just uint16
|
||||
// ---not implemented ---
|
||||
QUInt4x2: reflect.TypeOf(int8(0)),
|
||||
QUInt2x4: reflect.TypeOf(uint8(0)),
|
||||
Bits1x8: reflect.TypeOf(uint8(0)),
|
||||
Bits2x4: reflect.TypeOf(uint8(0)),
|
||||
Bits4x2: reflect.TypeOf(uint8(0)),
|
||||
Bits8: reflect.TypeOf(uint8(0)),
|
||||
Bits16: reflect.TypeOf(uint16(0)),
|
||||
}
|
||||
|
||||
var scalarTypes = map[reflect.Kind]bool{
|
||||
reflect.Bool: true,
|
||||
reflect.Int: true,
|
||||
reflect.Int8: true,
|
||||
reflect.Int16: true,
|
||||
reflect.Int32: true,
|
||||
reflect.Int64: true,
|
||||
reflect.Uint: true,
|
||||
reflect.Uint8: true,
|
||||
reflect.Uint16: true,
|
||||
reflect.Uint32: true,
|
||||
reflect.Uint64: true,
|
||||
reflect.Uintptr: true,
|
||||
reflect.Float32: true,
|
||||
reflect.Float64: true,
|
||||
reflect.Complex64: true,
|
||||
reflect.Complex128: true,
|
||||
func (dt DType) GoType() (reflect.Type, error) {
|
||||
typ, ok := dtype2GoType[dt]
|
||||
if !ok {
|
||||
err := fmt.Errorf("DType.GoType() failed: no corresponding Go type to given DType %v\n", typ.String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return typ, nil
|
||||
}
|
||||
|
||||
// IsSupportedScalar checks whether given SCALAR type is supported
|
||||
// TODO: check input is a scalar.
|
||||
func IsSupportedScalar(k reflect.Kind) bool {
|
||||
// if _, ok := scalarTypes[k]; !ok {
|
||||
// log.Fatalf("Input type: %v is not a Go scalar type.", k)
|
||||
// }
|
||||
|
||||
_, retVal := supportedTypes[k]
|
||||
|
||||
return retVal
|
||||
var dtypeNames map[DType]string = map[DType]string{
|
||||
Uint8: "Uint8",
|
||||
Int8: "Int8",
|
||||
Int16: "Int16",
|
||||
Int: "Int",
|
||||
Int64: "Int64",
|
||||
Half: "Half", // <- just uint16
|
||||
Float: "Float",
|
||||
Double: "Double",
|
||||
// ComplexHalf: reflect.Invalid, // no equivalent in Go. Would it be reflect.Float64?
|
||||
ComplexFloat: "ComplexFloat",
|
||||
ComplexDouble: "ComplexDouble",
|
||||
Bool: "Bool",
|
||||
QInt8: "QInt8",
|
||||
QUInt8: "QUInt8",
|
||||
QInt32: "QInt32",
|
||||
BFloat16: "BFloat16", // <- just uint16
|
||||
// ---not implemented ---
|
||||
QUInt4x2: "QUInt4x2",
|
||||
QUInt2x4: "QUInt2x4",
|
||||
Bits1x8: "Bits1x8",
|
||||
Bits2x4: "Bits2x4",
|
||||
Bits4x2: "Bits4x2",
|
||||
Bits8: "Bits8",
|
||||
Bits16: "Bits16",
|
||||
}
|
||||
|
||||
func (dt DType) String() string {
|
||||
return dtypeNames[dt]
|
||||
}
|
||||
|
||||
func DTypeFromData(data interface{}) (DType, error) {
|
||||
dataKind := reflect.TypeOf(data).Kind()
|
||||
|
||||
// Data is a slice/array
|
||||
if dataKind == reflect.Slice || dataKind == reflect.Array {
|
||||
elementKind := reflect.TypeOf(data).Elem().Kind()
|
||||
return GoKind2DType(elementKind)
|
||||
}
|
||||
|
||||
// single element
|
||||
return GoKind2DType(dataKind)
|
||||
}
|
||||
|
||||
// IsFloatDType returns whether dtype is floating point data type.
|
||||
func IsFloatDType(dtype DType) bool {
|
||||
switch dtype {
|
||||
case Double, Float, Half, BFloat16:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Default DType:
|
||||
// ==============
|
||||
var DefaultDType DType = Float
|
||||
|
||||
// SetDefaultDType set DefaultDType to new value and return the previous one.
|
||||
func SetDefaultDType(dtype DType) DType {
|
||||
odtype := DefaultDType
|
||||
DefaultDType = dtype
|
||||
|
||||
if Debug {
|
||||
log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change back to previous default to avoid unexpected outcome.\n", dtype)
|
||||
}
|
||||
|
||||
return odtype
|
||||
}
|
||||
|
|
|
@ -32,27 +32,18 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De
|
|||
|
||||
state := lstm.Step(input, inState)
|
||||
|
||||
// 1. Delete inState tensors (from C land memory)
|
||||
inState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||
// 2. Then update with current state
|
||||
// Update with current state
|
||||
inState = state
|
||||
// 3. Delete intermediate tensors
|
||||
input.MustDrop()
|
||||
inputView.MustDrop()
|
||||
|
||||
forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueezeDim(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||
sampledY := forwardTs.MustMultinomial(1, false, true)
|
||||
lastLabel = sampledY.Int64Values()[0]
|
||||
sampledY.MustDrop()
|
||||
char := data.LabelForChar(lastLabel)
|
||||
|
||||
runes = append(runes, char)
|
||||
}
|
||||
|
||||
// Delete the last state
|
||||
inState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||
inState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||
ts.CleanUp(100)
|
||||
}
|
||||
|
||||
return string(runes)
|
||||
}
|
||||
|
@ -93,42 +84,31 @@ func main() {
|
|||
|
||||
batchNarrow := batchTs.MustNarrow(1, 0, SeqLen, false)
|
||||
xsOnehot := batchNarrow.Onehot(labels).MustTo(device, true) // [256, 180, 65]
|
||||
batchNarrow.MustDrop()
|
||||
|
||||
ys := batchTs.MustNarrow(1, 1, SeqLen, true).MustTotype(gotch.Int64, true).MustTo(device, true).MustView([]int64{BatchSize * SeqLen}, true)
|
||||
|
||||
lstmOut, outState := lstm.Seq(xsOnehot)
|
||||
// NOTE. Although outState will not be used. There a hidden memory usage
|
||||
// on C land memory that is needed to free up. Don't use `_`
|
||||
outState.(*nn.LSTMState).Tensor1.MustDrop()
|
||||
outState.(*nn.LSTMState).Tensor2.MustDrop()
|
||||
xsOnehot.MustDrop()
|
||||
lstmOut, _ := lstm.Seq(xsOnehot)
|
||||
|
||||
logits := linear.Forward(lstmOut)
|
||||
lstmOut.MustDrop()
|
||||
lossView := logits.MustView([]int64{BatchSize * SeqLen, labels}, true)
|
||||
|
||||
loss := lossView.CrossEntropyForLogits(ys)
|
||||
ys.MustDrop()
|
||||
lossView.MustDrop()
|
||||
|
||||
opt.BackwardStepClip(loss, 0.5)
|
||||
sumLoss += loss.Float64Values()[0]
|
||||
cntLoss += 1.0
|
||||
loss.MustDrop()
|
||||
|
||||
batchCount++
|
||||
if batchCount%500 == 0 {
|
||||
fmt.Printf("Epoch %v - Batch %v \n", epoch, batchCount)
|
||||
fmt.Printf("\nEpoch %v - Batch %v \n", epoch, batchCount)
|
||||
}
|
||||
fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
|
||||
// fmt.Printf("dataIter: progress: %v\n", dataIter.Progress())
|
||||
fmt.Print(".")
|
||||
|
||||
ts.CleanUp(100)
|
||||
} // infinite for-loop
|
||||
|
||||
sampleStr := sample(data, lstm, linear, device)
|
||||
fmt.Printf("Epoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
|
||||
fmt.Printf("\nEpoch %v - Loss: %v \n", epoch, sumLoss/cntLoss)
|
||||
fmt.Println(sampleStr)
|
||||
|
||||
dataIter.Data.MustDrop()
|
||||
dataIter.Indexes.MustDrop()
|
||||
}
|
||||
}
|
||||
|
|
79
example/mem/main.go
Normal file
79
example/mem/main.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
const (
|
||||
ImageDimNN int64 = 784
|
||||
HiddenNodesNN int64 = 128
|
||||
LabelNN int64 = 10
|
||||
|
||||
BatchSize int64 = 3000
|
||||
|
||||
epochsNN = 200
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
||||
type model struct {
|
||||
fc *nn.Linear
|
||||
act nn.Func
|
||||
}
|
||||
|
||||
func newModel(vs *nn.VarStore) *model {
|
||||
fc := nn.NewLinear(vs.Root(), ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig())
|
||||
act := nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
||||
return xs.MustRelu(false)
|
||||
})
|
||||
|
||||
return &model{
|
||||
fc: fc,
|
||||
act: act,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *model) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
fc := m.fc.Forward(x)
|
||||
act := m.act.Forward(fc)
|
||||
|
||||
return act
|
||||
}
|
||||
|
||||
func newData() []float32 {
|
||||
n := int(BatchSize * ImageDimNN)
|
||||
data := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
data[i] = rand.Float32()
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func main() {
|
||||
epochs := 4000
|
||||
|
||||
// device := gotch.CPU
|
||||
device := gotch.CudaIfAvailable()
|
||||
vs := nn.NewVarStore(device)
|
||||
m := newModel(vs)
|
||||
|
||||
for i := 0; i < epochs; i++ {
|
||||
// input := ts.MustOfSlice(newData()).MustView([]int64{BatchSize, ImageDimNN}, true).MustTo(device, true)
|
||||
input := ts.MustRandn([]int64{BatchSize, ImageDimNN}, gotch.Float, device)
|
||||
|
||||
ts.NoGrad(func() {
|
||||
_ = m.Forward(input)
|
||||
})
|
||||
|
||||
if i%10 == 0 {
|
||||
fmt.Printf("=================== Epoch %03d completed========================\n", i)
|
||||
}
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
}
|
BIN
example/mem/mem.prof
Normal file
BIN
example/mem/mem.prof
Normal file
Binary file not shown.
80
example/mnist-fp16/README.md
Normal file
80
example/mnist-fp16/README.md
Normal file
|
@ -0,0 +1,80 @@
|
|||
# CNN MNIST training Float vs BFloat16
|
||||
|
||||
## BFloat16 - 16bit floating point
|
||||
|
||||
```bash
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Start eval...Epoch: 0 Loss: 0.05 Test accuracy: 98.05%
|
||||
Start eval...Epoch: 1 Loss: 0.03 Test accuracy: 98.36%
|
||||
Start eval...Epoch: 2 Loss: 0.03 Test accuracy: 98.44%
|
||||
Start eval...Epoch: 3 Loss: 0.18 Test accuracy: 98.44%
|
||||
Start eval...Epoch: 4 Loss: 0.01 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 5 Loss: 0.06 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 6 Loss: 0.21 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 7 Loss: 0.05 Test accuracy: 98.59%
|
||||
Start eval...Epoch: 8 Loss: 0.12 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 9 Loss: 0.12 Test accuracy: 98.48%
|
||||
Start eval...Epoch: 10 Loss: 0.04 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 11 Loss: 0.03 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 12 Loss: 0.04 Test accuracy: 98.48%
|
||||
Start eval...Epoch: 13 Loss: 0.32 Test accuracy: 98.48%
|
||||
Start eval...Epoch: 14 Loss: 0.06 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 15 Loss: 0.10 Test accuracy: 98.55%
|
||||
Start eval...Epoch: 16 Loss: 0.02 Test accuracy: 98.52%
|
||||
Start eval...Epoch: 17 Loss: 0.01 Test accuracy: 98.48%
|
||||
Start eval...Epoch: 18 Loss: 0.01 Test accuracy: 98.67%
|
||||
Start eval...Epoch: 19 Loss: 0.10 Test accuracy: 98.63%
|
||||
Start eval...Epoch: 20 Loss: 0.05 Test accuracy: 98.71%
|
||||
Start eval...Epoch: 21 Loss: 0.01 Test accuracy: 98.79%
|
||||
Start eval...Epoch: 22 Loss: 0.05 Test accuracy: 98.71%
|
||||
Start eval...Epoch: 23 Loss: 0.03 Test accuracy: 98.67%
|
||||
Start eval...Epoch: 24 Loss: 0.03 Test accuracy: 98.67%
|
||||
Start eval...Epoch: 25 Loss: 0.16 Test accuracy: 98.75%
|
||||
Start eval...Epoch: 26 Loss: 0.07 Test accuracy: 98.75%
|
||||
Start eval...Epoch: 27 Loss: 0.01 Test accuracy: 98.75%
|
||||
Start eval...Epoch: 28 Loss: 0.15 Test accuracy: 98.63%
|
||||
Start eval...Epoch: 29 Loss: 0.01 Test accuracy: 98.59%
|
||||
Best test accuracy: 98.79%
|
||||
Taken time: 8.67 mins
|
||||
```
|
||||
|
||||
|
||||
## Float - 32bit floating point
|
||||
|
||||
```bash
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Start eval...Epoch: 0 Loss: 0.27 Test accuracy: 98.42%
|
||||
Start eval...Epoch: 1 Loss: 0.06 Test accuracy: 98.60%
|
||||
Start eval...Epoch: 2 Loss: 0.01 Test accuracy: 98.68%
|
||||
Start eval...Epoch: 3 Loss: 0.01 Test accuracy: 98.63%
|
||||
Start eval...Epoch: 4 Loss: 0.11 Test accuracy: 98.82%
|
||||
Start eval...Epoch: 5 Loss: 0.11 Test accuracy: 99.00%
|
||||
Start eval...Epoch: 6 Loss: 0.00 Test accuracy: 98.93%
|
||||
Start eval...Epoch: 7 Loss: 0.00 Test accuracy: 98.96%
|
||||
Start eval...Epoch: 8 Loss: 0.01 Test accuracy: 99.02%
|
||||
Start eval...Epoch: 9 Loss: 0.04 Test accuracy: 99.04%
|
||||
Start eval...Epoch: 10 Loss: 0.06 Test accuracy: 99.07%
|
||||
Start eval...Epoch: 11 Loss: 0.01 Test accuracy: 99.12%
|
||||
Start eval...Epoch: 12 Loss: 0.00 Test accuracy: 99.12%
|
||||
Start eval...Epoch: 13 Loss: 0.00 Test accuracy: 99.12%
|
||||
Start eval...Epoch: 14 Loss: 0.04 Test accuracy: 99.14%
|
||||
Start eval...Epoch: 15 Loss: 0.07 Test accuracy: 99.12%
|
||||
Start eval...Epoch: 16 Loss: 0.00 Test accuracy: 99.08%
|
||||
Start eval...Epoch: 17 Loss: 0.00 Test accuracy: 99.10%
|
||||
Start eval...Epoch: 18 Loss: 0.08 Test accuracy: 99.16%
|
||||
Start eval...Epoch: 19 Loss: 0.07 Test accuracy: 99.20%
|
||||
Start eval...Epoch: 20 Loss: 0.00 Test accuracy: 99.06%
|
||||
Start eval...Epoch: 21 Loss: 0.05 Test accuracy: 98.97%
|
||||
Start eval...Epoch: 22 Loss: 0.01 Test accuracy: 99.13%
|
||||
Start eval...Epoch: 23 Loss: 0.00 Test accuracy: 99.13%
|
||||
Start eval...Epoch: 24 Loss: 0.01 Test accuracy: 99.16%
|
||||
Start eval...Epoch: 25 Loss: 0.00 Test accuracy: 99.11%
|
||||
Start eval...Epoch: 26 Loss: 0.09 Test accuracy: 99.13%
|
||||
Start eval...Epoch: 27 Loss: 0.00 Test accuracy: 99.14%
|
||||
Start eval...Epoch: 28 Loss: 0.00 Test accuracy: 99.13%
|
||||
Start eval...Epoch: 29 Loss: 0.01 Test accuracy: 99.20%
|
||||
Best test accuracy: 99.20%
|
||||
Taken time: 3.06 mins
|
||||
```
|
149
example/mnist-fp16/main.go
Normal file
149
example/mnist-fp16/main.go
Normal file
|
@ -0,0 +1,149 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
// main delegates to runCNN, the CNN MNIST training entry point.
func main() {
	runCNN()
}
|
||||
|
||||
// Training configuration for the CNN MNIST run.
const (
	// MnistDir is the local directory holding the raw MNIST files.
	MnistDir string = "/mnt/projects/numbat/data/mnist"

	// epochsCNN is the number of training epochs.
	epochsCNN = 30
	// batchCNN — NOTE(review): not referenced in the visible code of this
	// file (runCNN uses batchSize); confirm before removing.
	batchCNN = 256
	// batchSize = 256
	batchSize = 32

	// LrCNN is the Adam learning rate.
	LrCNN = 3 * 1e-4
)
|
||||
|
||||
// mu — NOTE(review): not used in the visible code of this file; confirm
// before removing.
var mu sync.Mutex

// device selects where tensors and model variables live: GPU when
// available, otherwise CPU.
// var device gotch.Device = gotch.CPU
var device gotch.Device = gotch.CudaIfAvailable()

// dtype selects the floating-point precision for the benchmark run
// (this file compares Float against the half-precision variants below).
// var dtype gotch.DType = gotch.BFloat16

// var dtype gotch.DType = gotch.Half
var dtype gotch.DType = gotch.Float
|
||||
|
||||
// Net is a small CNN for MNIST: two convolutional layers followed by two
// fully-connected layers (see newNet for the layer shapes).
type Net struct {
	conv1 *nn.Conv2D
	conv2 *nn.Conv2D
	fc1   *nn.Linear
	fc2   *nn.Linear
}
|
||||
|
||||
func newNet(vs *nn.Path) *Net {
|
||||
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
|
||||
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
|
||||
fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())
|
||||
fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())
|
||||
|
||||
return &Net{
|
||||
conv1,
|
||||
conv2,
|
||||
fc1,
|
||||
fc2}
|
||||
}
|
||||
|
||||
func (n *Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
|
||||
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
|
||||
outC1 := outView1.Apply(n.conv1)
|
||||
|
||||
outMP1 := outC1.MaxPool2DDefault(2, true)
|
||||
outC2 := outMP1.Apply(n.conv2)
|
||||
|
||||
outMP2 := outC2.MaxPool2DDefault(2, true)
|
||||
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
|
||||
|
||||
outFC1 := outView2.Apply(n.fc1)
|
||||
outRelu := outFC1.MustRelu(false)
|
||||
outDropout := ts.MustDropout(outRelu, 0.5, train)
|
||||
return outDropout.Apply(n.fc2)
|
||||
}
|
||||
|
||||
// runCNN trains the CNN on MNIST for epochsCNN epochs, evaluating test
// accuracy after each epoch and reporting the best accuracy and wall time.
// The boolean arguments to Must* calls are tensor-deletion flags.
func runCNN() {
	var ds *vision.Dataset
	ds = vision.LoadMNISTDir(MnistDir)
	trainImages := ds.TrainImages.MustTo(device, false) // [60000, 784]
	trainLabels := ds.TrainLabels.MustTo(device, false) // [60000]
	testImages := ds.TestImages.MustTo(device, false).MustTotype(dtype, true) // [10000, 784]
	testLabels := ds.TestLabels.MustTo(device, false).MustTotype(dtype, true) // [10000]

	fmt.Printf("testImages: %v\n", testImages.MustSize())
	fmt.Printf("testLabels: %v\n", testLabels.MustSize())

	// Create the model variables under dtype, then restore the previous
	// default dtype so the rest of the run is unaffected.
	odtype := gotch.SetDefaultDType(dtype)
	vs := nn.NewVarStore(device)
	net := newNet(vs.Root())
	gotch.SetDefaultDType(odtype)

	opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
	// opt, err := nn.DefaultSGDConfig().Build(vs, LrCNN)
	if err != nil {
		log.Fatal(err)
	}

	var bestAccuracy float64 = 0.0
	startTime := time.Now()

	for epoch := 0; epoch < epochsCNN; epoch++ {
		totalSize := ds.TrainImages.MustSize()[0]
		samples := int(totalSize)
		// Shuffling
		index := ts.MustRandperm(int64(totalSize), gotch.Int64, device)
		imagesTs := trainImages.MustIndexSelect(0, index, false).MustTotype(dtype, true)
		labelsTs := trainLabels.MustIndexSelect(0, index, false)

		batches := samples / batchSize
		batchIndex := 0
		var epocLoss float64 // loss of the last batch in this epoch
		for i := 0; i < batches; i++ {
			start := batchIndex * batchSize
			size := batchSize
			// Drop the trailing partial batch.
			if samples-start < batchSize {
				break
			}
			batchIndex += 1

			// Indexing
			bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
			logits := net.ForwardT(bImages, true)

			bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
			loss := logits.CrossEntropyForLogits(bLabels)

			// NOTE(review): explicitly re-enabling requires_grad on the loss
			// looks like a workaround — confirm it is still needed.
			loss = loss.MustSetRequiresGrad(true, true)
			opt.BackwardStep(loss)
			epocLoss = loss.Float64Values()[0]

			// Force collection of intermediate tensors each batch.
			runtime.GC()
		}

		// Evaluate without gradient tracking.
		ts.NoGrad(func() {
			fmt.Printf("Start eval...")
			testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
			fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
			if testAccuracy > bestAccuracy {
				bestAccuracy = testAccuracy
			}
		})
	}

	fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())

	ts.CleanUp()
}
|
|
@ -264,3 +264,754 @@ Best test accuracy: 99.11%
|
|||
Taken time: 2.81 mins
|
||||
```
|
||||
|
||||
|
||||
## New Implement with Go `Garbage Collection`
|
||||
|
||||
### Linear
|
||||
|
||||
```bash
|
||||
go run -race . -model=linear -device=cuda
|
||||
Epoch: 0 - Loss: 2.303 - Test accuracy: 68.08%
|
||||
Epoch: 1 - Loss: 1.508 - Test accuracy: 60.77%
|
||||
Epoch: 2 - Loss: 1.388 - Test accuracy: 52.54%
|
||||
Epoch: 3 - Loss: 1.579 - Test accuracy: 64.46%
|
||||
Epoch: 4 - Loss: 1.707 - Test accuracy: 60.47%
|
||||
Epoch: 5 - Loss: 1.194 - Test accuracy: 61.55%
|
||||
Epoch: 6 - Loss: 1.395 - Test accuracy: 70.66%
|
||||
Epoch: 7 - Loss: 1.290 - Test accuracy: 70.43%
|
||||
Epoch: 8 - Loss: 0.892 - Test accuracy: 66.76%
|
||||
Epoch: 9 - Loss: 0.935 - Test accuracy: 71.84%
|
||||
Epoch: 10 - Loss: 1.108 - Test accuracy: 73.77%
|
||||
Epoch: 11 - Loss: 0.906 - Test accuracy: 78.11%
|
||||
Epoch: 12 - Loss: 0.786 - Test accuracy: 79.07%
|
||||
Epoch: 13 - Loss: 0.686 - Test accuracy: 80.85%
|
||||
Epoch: 14 - Loss: 0.666 - Test accuracy: 79.68%
|
||||
Epoch: 15 - Loss: 0.685 - Test accuracy: 82.69%
|
||||
Epoch: 16 - Loss: 0.623 - Test accuracy: 82.13%
|
||||
Epoch: 17 - Loss: 0.592 - Test accuracy: 82.36%
|
||||
Epoch: 18 - Loss: 0.604 - Test accuracy: 80.09%
|
||||
Epoch: 19 - Loss: 0.614 - Test accuracy: 80.69%
|
||||
Epoch: 20 - Loss: 0.613 - Test accuracy: 78.71%
|
||||
Epoch: 21 - Loss: 0.645 - Test accuracy: 81.83%
|
||||
Epoch: 22 - Loss: 0.582 - Test accuracy: 81.11%
|
||||
Epoch: 23 - Loss: 0.591 - Test accuracy: 83.56%
|
||||
Epoch: 24 - Loss: 0.542 - Test accuracy: 83.58%
|
||||
Epoch: 25 - Loss: 0.539 - Test accuracy: 85.27%
|
||||
Epoch: 26 - Loss: 0.507 - Test accuracy: 85.70%
|
||||
Epoch: 27 - Loss: 0.498 - Test accuracy: 86.79%
|
||||
Epoch: 28 - Loss: 0.477 - Test accuracy: 87.10%
|
||||
Epoch: 29 - Loss: 0.467 - Test accuracy: 87.83%
|
||||
Epoch: 30 - Loss: 0.453 - Test accuracy: 88.09%
|
||||
Epoch: 31 - Loss: 0.445 - Test accuracy: 88.44%
|
||||
Epoch: 32 - Loss: 0.436 - Test accuracy: 88.52%
|
||||
Epoch: 33 - Loss: 0.429 - Test accuracy: 88.78%
|
||||
Epoch: 34 - Loss: 0.424 - Test accuracy: 88.78%
|
||||
Epoch: 35 - Loss: 0.419 - Test accuracy: 89.20%
|
||||
Epoch: 36 - Loss: 0.415 - Test accuracy: 89.10%
|
||||
Epoch: 37 - Loss: 0.412 - Test accuracy: 89.49%
|
||||
Epoch: 38 - Loss: 0.409 - Test accuracy: 89.37%
|
||||
Epoch: 39 - Loss: 0.406 - Test accuracy: 89.62%
|
||||
Epoch: 40 - Loss: 0.404 - Test accuracy: 89.58%
|
||||
Epoch: 41 - Loss: 0.402 - Test accuracy: 89.71%
|
||||
Epoch: 42 - Loss: 0.399 - Test accuracy: 89.68%
|
||||
Epoch: 43 - Loss: 0.398 - Test accuracy: 89.85%
|
||||
Epoch: 44 - Loss: 0.396 - Test accuracy: 89.77%
|
||||
Epoch: 45 - Loss: 0.394 - Test accuracy: 89.87%
|
||||
Epoch: 46 - Loss: 0.392 - Test accuracy: 89.89%
|
||||
Epoch: 47 - Loss: 0.391 - Test accuracy: 89.95%
|
||||
Epoch: 48 - Loss: 0.389 - Test accuracy: 89.96%
|
||||
Epoch: 49 - Loss: 0.388 - Test accuracy: 90.05%
|
||||
Epoch: 50 - Loss: 0.387 - Test accuracy: 90.04%
|
||||
Epoch: 51 - Loss: 0.385 - Test accuracy: 90.13%
|
||||
Epoch: 52 - Loss: 0.384 - Test accuracy: 90.11%
|
||||
Epoch: 53 - Loss: 0.383 - Test accuracy: 90.24%
|
||||
Epoch: 54 - Loss: 0.381 - Test accuracy: 90.19%
|
||||
Epoch: 55 - Loss: 0.380 - Test accuracy: 90.27%
|
||||
Epoch: 56 - Loss: 0.379 - Test accuracy: 90.22%
|
||||
Epoch: 57 - Loss: 0.378 - Test accuracy: 90.30%
|
||||
Epoch: 58 - Loss: 0.377 - Test accuracy: 90.28%
|
||||
Epoch: 59 - Loss: 0.376 - Test accuracy: 90.33%
|
||||
Epoch: 60 - Loss: 0.375 - Test accuracy: 90.33%
|
||||
Epoch: 61 - Loss: 0.374 - Test accuracy: 90.37%
|
||||
Epoch: 62 - Loss: 0.372 - Test accuracy: 90.38%
|
||||
Epoch: 63 - Loss: 0.371 - Test accuracy: 90.41%
|
||||
Epoch: 64 - Loss: 0.371 - Test accuracy: 90.42%
|
||||
Epoch: 65 - Loss: 0.370 - Test accuracy: 90.43%
|
||||
Epoch: 66 - Loss: 0.369 - Test accuracy: 90.45%
|
||||
Epoch: 67 - Loss: 0.368 - Test accuracy: 90.51%
|
||||
Epoch: 68 - Loss: 0.367 - Test accuracy: 90.52%
|
||||
Epoch: 69 - Loss: 0.366 - Test accuracy: 90.53%
|
||||
Epoch: 70 - Loss: 0.365 - Test accuracy: 90.53%
|
||||
Epoch: 71 - Loss: 0.364 - Test accuracy: 90.55%
|
||||
Epoch: 72 - Loss: 0.363 - Test accuracy: 90.57%
|
||||
Epoch: 73 - Loss: 0.363 - Test accuracy: 90.57%
|
||||
Epoch: 74 - Loss: 0.362 - Test accuracy: 90.58%
|
||||
Epoch: 75 - Loss: 0.361 - Test accuracy: 90.59%
|
||||
Epoch: 76 - Loss: 0.360 - Test accuracy: 90.59%
|
||||
Epoch: 77 - Loss: 0.359 - Test accuracy: 90.62%
|
||||
Epoch: 78 - Loss: 0.359 - Test accuracy: 90.67%
|
||||
Epoch: 79 - Loss: 0.358 - Test accuracy: 90.67%
|
||||
Epoch: 80 - Loss: 0.357 - Test accuracy: 90.70%
|
||||
Epoch: 81 - Loss: 0.357 - Test accuracy: 90.71%
|
||||
Epoch: 82 - Loss: 0.356 - Test accuracy: 90.72%
|
||||
Epoch: 83 - Loss: 0.355 - Test accuracy: 90.77%
|
||||
Epoch: 84 - Loss: 0.355 - Test accuracy: 90.77%
|
||||
Epoch: 85 - Loss: 0.354 - Test accuracy: 90.78%
|
||||
Epoch: 86 - Loss: 0.353 - Test accuracy: 90.80%
|
||||
Epoch: 87 - Loss: 0.353 - Test accuracy: 90.82%
|
||||
Epoch: 88 - Loss: 0.352 - Test accuracy: 90.83%
|
||||
Epoch: 89 - Loss: 0.351 - Test accuracy: 90.82%
|
||||
Epoch: 90 - Loss: 0.351 - Test accuracy: 90.84%
|
||||
Epoch: 91 - Loss: 0.350 - Test accuracy: 90.87%
|
||||
Epoch: 92 - Loss: 0.350 - Test accuracy: 90.88%
|
||||
Epoch: 93 - Loss: 0.349 - Test accuracy: 90.87%
|
||||
Epoch: 94 - Loss: 0.348 - Test accuracy: 90.89%
|
||||
Epoch: 95 - Loss: 0.348 - Test accuracy: 90.89%
|
||||
Epoch: 96 - Loss: 0.347 - Test accuracy: 90.91%
|
||||
Epoch: 97 - Loss: 0.347 - Test accuracy: 90.94%
|
||||
Epoch: 98 - Loss: 0.346 - Test accuracy: 90.94%
|
||||
Epoch: 99 - Loss: 0.346 - Test accuracy: 90.96%
|
||||
Epoch: 100 - Loss: 0.345 - Test accuracy: 90.96%
|
||||
Epoch: 101 - Loss: 0.345 - Test accuracy: 90.96%
|
||||
Epoch: 102 - Loss: 0.344 - Test accuracy: 91.02%
|
||||
Epoch: 103 - Loss: 0.344 - Test accuracy: 91.02%
|
||||
Epoch: 104 - Loss: 0.343 - Test accuracy: 91.04%
|
||||
Epoch: 105 - Loss: 0.343 - Test accuracy: 91.05%
|
||||
Epoch: 106 - Loss: 0.342 - Test accuracy: 91.05%
|
||||
Epoch: 107 - Loss: 0.342 - Test accuracy: 91.06%
|
||||
Epoch: 108 - Loss: 0.341 - Test accuracy: 91.07%
|
||||
Epoch: 109 - Loss: 0.341 - Test accuracy: 91.08%
|
||||
Epoch: 110 - Loss: 0.340 - Test accuracy: 91.09%
|
||||
Epoch: 111 - Loss: 0.340 - Test accuracy: 91.11%
|
||||
Epoch: 112 - Loss: 0.339 - Test accuracy: 91.14%
|
||||
Epoch: 113 - Loss: 0.339 - Test accuracy: 91.15%
|
||||
Epoch: 114 - Loss: 0.339 - Test accuracy: 91.14%
|
||||
Epoch: 115 - Loss: 0.338 - Test accuracy: 91.15%
|
||||
Epoch: 116 - Loss: 0.338 - Test accuracy: 91.17%
|
||||
Epoch: 117 - Loss: 0.337 - Test accuracy: 91.18%
|
||||
Epoch: 118 - Loss: 0.337 - Test accuracy: 91.19%
|
||||
Epoch: 119 - Loss: 0.336 - Test accuracy: 91.20%
|
||||
Epoch: 120 - Loss: 0.336 - Test accuracy: 91.22%
|
||||
Epoch: 121 - Loss: 0.336 - Test accuracy: 91.23%
|
||||
Epoch: 122 - Loss: 0.335 - Test accuracy: 91.28%
|
||||
Epoch: 123 - Loss: 0.335 - Test accuracy: 91.30%
|
||||
Epoch: 124 - Loss: 0.334 - Test accuracy: 91.29%
|
||||
Epoch: 125 - Loss: 0.334 - Test accuracy: 91.32%
|
||||
Epoch: 126 - Loss: 0.334 - Test accuracy: 91.35%
|
||||
Epoch: 127 - Loss: 0.333 - Test accuracy: 91.37%
|
||||
Epoch: 128 - Loss: 0.333 - Test accuracy: 91.36%
|
||||
Epoch: 129 - Loss: 0.333 - Test accuracy: 91.36%
|
||||
Epoch: 130 - Loss: 0.332 - Test accuracy: 91.36%
|
||||
Epoch: 131 - Loss: 0.332 - Test accuracy: 91.35%
|
||||
Epoch: 132 - Loss: 0.332 - Test accuracy: 91.36%
|
||||
Epoch: 133 - Loss: 0.331 - Test accuracy: 91.37%
|
||||
Epoch: 134 - Loss: 0.331 - Test accuracy: 91.38%
|
||||
Epoch: 135 - Loss: 0.330 - Test accuracy: 91.39%
|
||||
Epoch: 136 - Loss: 0.330 - Test accuracy: 91.40%
|
||||
Epoch: 137 - Loss: 0.330 - Test accuracy: 91.43%
|
||||
Epoch: 138 - Loss: 0.329 - Test accuracy: 91.44%
|
||||
Epoch: 139 - Loss: 0.329 - Test accuracy: 91.45%
|
||||
Epoch: 140 - Loss: 0.329 - Test accuracy: 91.46%
|
||||
Epoch: 141 - Loss: 0.328 - Test accuracy: 91.47%
|
||||
Epoch: 142 - Loss: 0.328 - Test accuracy: 91.46%
|
||||
Epoch: 143 - Loss: 0.328 - Test accuracy: 91.48%
|
||||
Epoch: 144 - Loss: 0.328 - Test accuracy: 91.46%
|
||||
Epoch: 145 - Loss: 0.327 - Test accuracy: 91.46%
|
||||
Epoch: 146 - Loss: 0.327 - Test accuracy: 91.46%
|
||||
Epoch: 147 - Loss: 0.327 - Test accuracy: 91.47%
|
||||
Epoch: 148 - Loss: 0.326 - Test accuracy: 91.47%
|
||||
Epoch: 149 - Loss: 0.326 - Test accuracy: 91.48%
|
||||
Epoch: 150 - Loss: 0.326 - Test accuracy: 91.48%
|
||||
Epoch: 151 - Loss: 0.325 - Test accuracy: 91.50%
|
||||
Epoch: 152 - Loss: 0.325 - Test accuracy: 91.50%
|
||||
Epoch: 153 - Loss: 0.325 - Test accuracy: 91.52%
|
||||
Epoch: 154 - Loss: 0.325 - Test accuracy: 91.52%
|
||||
Epoch: 155 - Loss: 0.324 - Test accuracy: 91.52%
|
||||
Epoch: 156 - Loss: 0.324 - Test accuracy: 91.51%
|
||||
Epoch: 157 - Loss: 0.324 - Test accuracy: 91.51%
|
||||
Epoch: 158 - Loss: 0.323 - Test accuracy: 91.52%
|
||||
Epoch: 159 - Loss: 0.323 - Test accuracy: 91.51%
|
||||
Epoch: 160 - Loss: 0.323 - Test accuracy: 91.51%
|
||||
Epoch: 161 - Loss: 0.323 - Test accuracy: 91.50%
|
||||
Epoch: 162 - Loss: 0.322 - Test accuracy: 91.51%
|
||||
Epoch: 163 - Loss: 0.322 - Test accuracy: 91.53%
|
||||
Epoch: 164 - Loss: 0.322 - Test accuracy: 91.54%
|
||||
Epoch: 165 - Loss: 0.322 - Test accuracy: 91.54%
|
||||
Epoch: 166 - Loss: 0.321 - Test accuracy: 91.54%
|
||||
Epoch: 167 - Loss: 0.321 - Test accuracy: 91.56%
|
||||
Epoch: 168 - Loss: 0.321 - Test accuracy: 91.56%
|
||||
Epoch: 169 - Loss: 0.321 - Test accuracy: 91.56%
|
||||
Epoch: 170 - Loss: 0.320 - Test accuracy: 91.57%
|
||||
Epoch: 171 - Loss: 0.320 - Test accuracy: 91.59%
|
||||
Epoch: 172 - Loss: 0.320 - Test accuracy: 91.59%
|
||||
Epoch: 173 - Loss: 0.320 - Test accuracy: 91.60%
|
||||
Epoch: 174 - Loss: 0.319 - Test accuracy: 91.60%
|
||||
Epoch: 175 - Loss: 0.319 - Test accuracy: 91.61%
|
||||
Epoch: 176 - Loss: 0.319 - Test accuracy: 91.61%
|
||||
Epoch: 177 - Loss: 0.319 - Test accuracy: 91.61%
|
||||
Epoch: 178 - Loss: 0.318 - Test accuracy: 91.60%
|
||||
Epoch: 179 - Loss: 0.318 - Test accuracy: 91.60%
|
||||
Epoch: 180 - Loss: 0.318 - Test accuracy: 91.60%
|
||||
Epoch: 181 - Loss: 0.318 - Test accuracy: 91.60%
|
||||
Epoch: 182 - Loss: 0.318 - Test accuracy: 91.60%
|
||||
Epoch: 183 - Loss: 0.317 - Test accuracy: 91.60%
|
||||
Epoch: 184 - Loss: 0.317 - Test accuracy: 91.60%
|
||||
Epoch: 185 - Loss: 0.317 - Test accuracy: 91.62%
|
||||
Epoch: 186 - Loss: 0.317 - Test accuracy: 91.63%
|
||||
Epoch: 187 - Loss: 0.316 - Test accuracy: 91.66%
|
||||
Epoch: 188 - Loss: 0.316 - Test accuracy: 91.66%
|
||||
Epoch: 189 - Loss: 0.316 - Test accuracy: 91.65%
|
||||
Epoch: 190 - Loss: 0.316 - Test accuracy: 91.65%
|
||||
Epoch: 191 - Loss: 0.316 - Test accuracy: 91.65%
|
||||
Epoch: 192 - Loss: 0.315 - Test accuracy: 91.65%
|
||||
Epoch: 193 - Loss: 0.315 - Test accuracy: 91.67%
|
||||
Epoch: 194 - Loss: 0.315 - Test accuracy: 91.66%
|
||||
Epoch: 195 - Loss: 0.315 - Test accuracy: 91.66%
|
||||
Epoch: 196 - Loss: 0.315 - Test accuracy: 91.67%
|
||||
Epoch: 197 - Loss: 0.314 - Test accuracy: 91.68%
|
||||
Epoch: 198 - Loss: 0.314 - Test accuracy: 91.68%
|
||||
Epoch: 199 - Loss: 0.314 - Test accuracy: 91.68%
|
||||
```
|
||||
|
||||
### NN
|
||||
|
||||
```bash
|
||||
go run -race . -model=nn -device=cpu
|
||||
Epoch: 0 Loss: 2.313 Test accuracy: 23.07%
|
||||
Epoch: 1 Loss: 2.247 Test accuracy: 32.13%
|
||||
Epoch: 2 Loss: 2.182 Test accuracy: 44.69%
|
||||
Epoch: 3 Loss: 2.116 Test accuracy: 57.08%
|
||||
Epoch: 4 Loss: 2.047 Test accuracy: 64.63%
|
||||
Epoch: 5 Loss: 1.976 Test accuracy: 68.40%
|
||||
Epoch: 6 Loss: 1.903 Test accuracy: 70.94%
|
||||
Epoch: 7 Loss: 1.831 Test accuracy: 72.40%
|
||||
Epoch: 8 Loss: 1.758 Test accuracy: 74.03%
|
||||
Epoch: 9 Loss: 1.686 Test accuracy: 75.31%
|
||||
Epoch: 10 Loss: 1.614 Test accuracy: 76.54%
|
||||
Epoch: 11 Loss: 1.544 Test accuracy: 77.83%
|
||||
Epoch: 12 Loss: 1.475 Test accuracy: 78.68%
|
||||
Epoch: 13 Loss: 1.408 Test accuracy: 79.54%
|
||||
Epoch: 14 Loss: 1.343 Test accuracy: 80.20%
|
||||
Epoch: 15 Loss: 1.281 Test accuracy: 80.76%
|
||||
Epoch: 16 Loss: 1.221 Test accuracy: 81.47%
|
||||
Epoch: 17 Loss: 1.165 Test accuracy: 81.90%
|
||||
Epoch: 18 Loss: 1.110 Test accuracy: 82.19%
|
||||
Epoch: 19 Loss: 1.059 Test accuracy: 82.67%
|
||||
Epoch: 20 Loss: 1.010 Test accuracy: 83.07%
|
||||
Epoch: 21 Loss: 0.965 Test accuracy: 83.39%
|
||||
Epoch: 22 Loss: 0.922 Test accuracy: 83.74%
|
||||
Epoch: 23 Loss: 0.882 Test accuracy: 84.00%
|
||||
Epoch: 24 Loss: 0.845 Test accuracy: 84.25%
|
||||
Epoch: 25 Loss: 0.810 Test accuracy: 84.41%
|
||||
Epoch: 26 Loss: 0.778 Test accuracy: 84.77%
|
||||
Epoch: 27 Loss: 0.748 Test accuracy: 85.00%
|
||||
Epoch: 28 Loss: 0.721 Test accuracy: 85.28%
|
||||
Epoch: 29 Loss: 0.696 Test accuracy: 85.60%
|
||||
Epoch: 30 Loss: 0.672 Test accuracy: 85.85%
|
||||
Epoch: 31 Loss: 0.651 Test accuracy: 86.05%
|
||||
Epoch: 32 Loss: 0.631 Test accuracy: 86.31%
|
||||
Epoch: 33 Loss: 0.612 Test accuracy: 86.48%
|
||||
Epoch: 34 Loss: 0.595 Test accuracy: 86.74%
|
||||
Epoch: 35 Loss: 0.579 Test accuracy: 86.89%
|
||||
Epoch: 36 Loss: 0.564 Test accuracy: 87.17%
|
||||
Epoch: 37 Loss: 0.551 Test accuracy: 87.23%
|
||||
Epoch: 38 Loss: 0.538 Test accuracy: 87.34%
|
||||
Epoch: 39 Loss: 0.526 Test accuracy: 87.55%
|
||||
Epoch: 40 Loss: 0.515 Test accuracy: 87.74%
|
||||
Epoch: 41 Loss: 0.504 Test accuracy: 88.01%
|
||||
Epoch: 42 Loss: 0.495 Test accuracy: 88.23%
|
||||
Epoch: 43 Loss: 0.485 Test accuracy: 88.37%
|
||||
Epoch: 44 Loss: 0.477 Test accuracy: 88.55%
|
||||
Epoch: 45 Loss: 0.469 Test accuracy: 88.71%
|
||||
Epoch: 46 Loss: 0.461 Test accuracy: 88.89%
|
||||
Epoch: 47 Loss: 0.454 Test accuracy: 88.98%
|
||||
Epoch: 48 Loss: 0.447 Test accuracy: 89.06%
|
||||
Epoch: 49 Loss: 0.440 Test accuracy: 89.17%
|
||||
Epoch: 50 Loss: 0.434 Test accuracy: 89.30%
|
||||
Epoch: 51 Loss: 0.428 Test accuracy: 89.39%
|
||||
Epoch: 52 Loss: 0.422 Test accuracy: 89.51%
|
||||
Epoch: 53 Loss: 0.417 Test accuracy: 89.57%
|
||||
Epoch: 54 Loss: 0.412 Test accuracy: 89.69%
|
||||
Epoch: 55 Loss: 0.407 Test accuracy: 89.85%
|
||||
Epoch: 56 Loss: 0.403 Test accuracy: 89.89%
|
||||
Epoch: 57 Loss: 0.398 Test accuracy: 89.94%
|
||||
Epoch: 58 Loss: 0.394 Test accuracy: 90.02%
|
||||
Epoch: 59 Loss: 0.390 Test accuracy: 90.13%
|
||||
Epoch: 60 Loss: 0.386 Test accuracy: 90.25%
|
||||
Epoch: 61 Loss: 0.382 Test accuracy: 90.30%
|
||||
Epoch: 62 Loss: 0.379 Test accuracy: 90.36%
|
||||
Epoch: 63 Loss: 0.375 Test accuracy: 90.43%
|
||||
Epoch: 64 Loss: 0.372 Test accuracy: 90.47%
|
||||
Epoch: 65 Loss: 0.368 Test accuracy: 90.52%
|
||||
Epoch: 66 Loss: 0.365 Test accuracy: 90.59%
|
||||
Epoch: 67 Loss: 0.362 Test accuracy: 90.63%
|
||||
Epoch: 68 Loss: 0.360 Test accuracy: 90.66%
|
||||
Epoch: 69 Loss: 0.357 Test accuracy: 90.72%
|
||||
Epoch: 70 Loss: 0.354 Test accuracy: 90.76%
|
||||
Epoch: 71 Loss: 0.351 Test accuracy: 90.80%
|
||||
Epoch: 72 Loss: 0.349 Test accuracy: 90.83%
|
||||
Epoch: 73 Loss: 0.346 Test accuracy: 90.89%
|
||||
Epoch: 74 Loss: 0.344 Test accuracy: 90.93%
|
||||
Epoch: 75 Loss: 0.342 Test accuracy: 91.04%
|
||||
Epoch: 76 Loss: 0.340 Test accuracy: 91.11%
|
||||
Epoch: 77 Loss: 0.337 Test accuracy: 91.16%
|
||||
Epoch: 78 Loss: 0.335 Test accuracy: 91.20%
|
||||
Epoch: 79 Loss: 0.333 Test accuracy: 91.25%
|
||||
Epoch: 80 Loss: 0.331 Test accuracy: 91.29%
|
||||
Epoch: 81 Loss: 0.329 Test accuracy: 91.31%
|
||||
Epoch: 82 Loss: 0.327 Test accuracy: 91.34%
|
||||
Epoch: 83 Loss: 0.325 Test accuracy: 91.36%
|
||||
Epoch: 84 Loss: 0.324 Test accuracy: 91.42%
|
||||
Epoch: 85 Loss: 0.322 Test accuracy: 91.46%
|
||||
Epoch: 86 Loss: 0.320 Test accuracy: 91.49%
|
||||
Epoch: 87 Loss: 0.318 Test accuracy: 91.52%
|
||||
Epoch: 88 Loss: 0.317 Test accuracy: 91.53%
|
||||
Epoch: 89 Loss: 0.315 Test accuracy: 91.55%
|
||||
Epoch: 90 Loss: 0.313 Test accuracy: 91.62%
|
||||
Epoch: 91 Loss: 0.312 Test accuracy: 91.68%
|
||||
Epoch: 92 Loss: 0.310 Test accuracy: 91.72%
|
||||
Epoch: 93 Loss: 0.309 Test accuracy: 91.77%
|
||||
Epoch: 94 Loss: 0.307 Test accuracy: 91.82%
|
||||
Epoch: 95 Loss: 0.306 Test accuracy: 91.87%
|
||||
Epoch: 96 Loss: 0.304 Test accuracy: 91.89%
|
||||
Epoch: 97 Loss: 0.303 Test accuracy: 91.90%
|
||||
Epoch: 98 Loss: 0.302 Test accuracy: 91.92%
|
||||
Epoch: 99 Loss: 0.300 Test accuracy: 91.95%
|
||||
Epoch: 100 Loss: 0.299 Test accuracy: 91.99%
|
||||
Epoch: 101 Loss: 0.298 Test accuracy: 92.04%
|
||||
Epoch: 102 Loss: 0.296 Test accuracy: 92.07%
|
||||
Epoch: 103 Loss: 0.295 Test accuracy: 92.11%
|
||||
Epoch: 104 Loss: 0.294 Test accuracy: 92.13%
|
||||
Epoch: 105 Loss: 0.292 Test accuracy: 92.16%
|
||||
Epoch: 106 Loss: 0.291 Test accuracy: 92.18%
|
||||
Epoch: 107 Loss: 0.290 Test accuracy: 92.20%
|
||||
Epoch: 108 Loss: 0.289 Test accuracy: 92.20%
|
||||
Epoch: 109 Loss: 0.287 Test accuracy: 92.24%
|
||||
Epoch: 110 Loss: 0.286 Test accuracy: 92.25%
|
||||
Epoch: 111 Loss: 0.285 Test accuracy: 92.26%
|
||||
Epoch: 112 Loss: 0.284 Test accuracy: 92.26%
|
||||
Epoch: 113 Loss: 0.283 Test accuracy: 92.30%
|
||||
Epoch: 114 Loss: 0.282 Test accuracy: 92.32%
|
||||
Epoch: 115 Loss: 0.281 Test accuracy: 92.34%
|
||||
Epoch: 116 Loss: 0.279 Test accuracy: 92.40%
|
||||
Epoch: 117 Loss: 0.278 Test accuracy: 92.41%
|
||||
Epoch: 118 Loss: 0.277 Test accuracy: 92.43%
|
||||
Epoch: 119 Loss: 0.276 Test accuracy: 92.45%
|
||||
Epoch: 120 Loss: 0.275 Test accuracy: 92.47%
|
||||
Epoch: 121 Loss: 0.274 Test accuracy: 92.52%
|
||||
Epoch: 122 Loss: 0.273 Test accuracy: 92.55%
|
||||
Epoch: 123 Loss: 0.272 Test accuracy: 92.55%
|
||||
Epoch: 124 Loss: 0.271 Test accuracy: 92.57%
|
||||
Epoch: 125 Loss: 0.270 Test accuracy: 92.61%
|
||||
Epoch: 126 Loss: 0.269 Test accuracy: 92.61%
|
||||
Epoch: 127 Loss: 0.268 Test accuracy: 92.61%
|
||||
Epoch: 128 Loss: 0.267 Test accuracy: 92.63%
|
||||
Epoch: 129 Loss: 0.266 Test accuracy: 92.65%
|
||||
Epoch: 130 Loss: 0.265 Test accuracy: 92.66%
|
||||
Epoch: 131 Loss: 0.264 Test accuracy: 92.72%
|
||||
Epoch: 132 Loss: 0.263 Test accuracy: 92.78%
|
||||
Epoch: 133 Loss: 0.262 Test accuracy: 92.77%
|
||||
Epoch: 134 Loss: 0.261 Test accuracy: 92.77%
|
||||
Epoch: 135 Loss: 0.260 Test accuracy: 92.82%
|
||||
Epoch: 136 Loss: 0.259 Test accuracy: 92.81%
|
||||
Epoch: 137 Loss: 0.258 Test accuracy: 92.84%
|
||||
Epoch: 138 Loss: 0.257 Test accuracy: 92.86%
|
||||
Epoch: 139 Loss: 0.256 Test accuracy: 92.88%
|
||||
Epoch: 140 Loss: 0.255 Test accuracy: 92.89%
|
||||
Epoch: 141 Loss: 0.254 Test accuracy: 92.90%
|
||||
Epoch: 142 Loss: 0.253 Test accuracy: 92.90%
|
||||
Epoch: 143 Loss: 0.253 Test accuracy: 92.94%
|
||||
Epoch: 144 Loss: 0.252 Test accuracy: 92.96%
|
||||
Epoch: 145 Loss: 0.251 Test accuracy: 92.99%
|
||||
Epoch: 146 Loss: 0.250 Test accuracy: 93.01%
|
||||
Epoch: 147 Loss: 0.249 Test accuracy: 93.03%
|
||||
Epoch: 148 Loss: 0.248 Test accuracy: 93.08%
|
||||
Epoch: 149 Loss: 0.247 Test accuracy: 93.11%
|
||||
Epoch: 150 Loss: 0.246 Test accuracy: 93.12%
|
||||
Epoch: 151 Loss: 0.245 Test accuracy: 93.14%
|
||||
Epoch: 152 Loss: 0.245 Test accuracy: 93.15%
|
||||
Epoch: 153 Loss: 0.244 Test accuracy: 93.16%
|
||||
Epoch: 154 Loss: 0.243 Test accuracy: 93.16%
|
||||
Epoch: 155 Loss: 0.242 Test accuracy: 93.16%
|
||||
Epoch: 156 Loss: 0.241 Test accuracy: 93.18%
|
||||
Epoch: 157 Loss: 0.240 Test accuracy: 93.18%
|
||||
Epoch: 158 Loss: 0.240 Test accuracy: 93.22%
|
||||
Epoch: 159 Loss: 0.239 Test accuracy: 93.26%
|
||||
Epoch: 160 Loss: 0.238 Test accuracy: 93.28%
|
||||
Epoch: 161 Loss: 0.237 Test accuracy: 93.30%
|
||||
Epoch: 162 Loss: 0.236 Test accuracy: 93.30%
|
||||
Epoch: 163 Loss: 0.235 Test accuracy: 93.33%
|
||||
Epoch: 164 Loss: 0.235 Test accuracy: 93.34%
|
||||
Epoch: 165 Loss: 0.234 Test accuracy: 93.40%
|
||||
Epoch: 166 Loss: 0.233 Test accuracy: 93.42%
|
||||
Epoch: 167 Loss: 0.232 Test accuracy: 93.43%
|
||||
Epoch: 168 Loss: 0.231 Test accuracy: 93.43%
|
||||
Epoch: 169 Loss: 0.231 Test accuracy: 93.45%
|
||||
Epoch: 170 Loss: 0.230 Test accuracy: 93.46%
|
||||
Epoch: 171 Loss: 0.229 Test accuracy: 93.45%
|
||||
Epoch: 172 Loss: 0.228 Test accuracy: 93.46%
|
||||
Epoch: 173 Loss: 0.227 Test accuracy: 93.48%
|
||||
Epoch: 174 Loss: 0.227 Test accuracy: 93.49%
|
||||
Epoch: 175 Loss: 0.226 Test accuracy: 93.54%
|
||||
Epoch: 176 Loss: 0.225 Test accuracy: 93.56%
|
||||
Epoch: 177 Loss: 0.224 Test accuracy: 93.57%
|
||||
Epoch: 178 Loss: 0.224 Test accuracy: 93.57%
|
||||
Epoch: 179 Loss: 0.223 Test accuracy: 93.59%
|
||||
Epoch: 180 Loss: 0.222 Test accuracy: 93.60%
|
||||
Epoch: 181 Loss: 0.221 Test accuracy: 93.60%
|
||||
Epoch: 182 Loss: 0.221 Test accuracy: 93.62%
|
||||
Epoch: 183 Loss: 0.220 Test accuracy: 93.64%
|
||||
Epoch: 184 Loss: 0.219 Test accuracy: 93.63%
|
||||
Epoch: 185 Loss: 0.218 Test accuracy: 93.64%
|
||||
Epoch: 186 Loss: 0.218 Test accuracy: 93.67%
|
||||
Epoch: 187 Loss: 0.217 Test accuracy: 93.68%
|
||||
Epoch: 188 Loss: 0.216 Test accuracy: 93.71%
|
||||
Epoch: 189 Loss: 0.215 Test accuracy: 93.73%
|
||||
Epoch: 190 Loss: 0.215 Test accuracy: 93.75%
|
||||
Epoch: 191 Loss: 0.214 Test accuracy: 93.77%
|
||||
Epoch: 192 Loss: 0.213 Test accuracy: 93.78%
|
||||
Epoch: 193 Loss: 0.213 Test accuracy: 93.79%
|
||||
Epoch: 194 Loss: 0.212 Test accuracy: 93.83%
|
||||
Epoch: 195 Loss: 0.211 Test accuracy: 93.85%
|
||||
Epoch: 196 Loss: 0.211 Test accuracy: 93.89%
|
||||
Epoch: 197 Loss: 0.210 Test accuracy: 93.96%
|
||||
Epoch: 198 Loss: 0.209 Test accuracy: 93.99%
|
||||
Epoch: 199 Loss: 0.209 Test accuracy: 94.01%
|
||||
```
|
||||
|
||||
|
||||
### CNN
|
||||
|
||||
**BatchSize = 256 on CPU**
|
||||
|
||||
```
|
||||
go run . -model=cnn -device=cpu
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.15 Test accuracy: 96.69%
|
||||
Epoch: 1 Loss: 0.20 Test accuracy: 94.54%
|
||||
Epoch: 2 Loss: 0.17 Test accuracy: 95.38%
|
||||
Epoch: 3 Loss: 0.15 Test accuracy: 95.46%
|
||||
Epoch: 4 Loss: 0.12 Test accuracy: 97.09%
|
||||
Epoch: 5 Loss: 0.17 Test accuracy: 97.17%
|
||||
Epoch: 6 Loss: 0.09 Test accuracy: 97.17%
|
||||
Epoch: 7 Loss: 0.06 Test accuracy: 97.17%
|
||||
Epoch: 8 Loss: 0.10 Test accuracy: 97.16%
|
||||
Epoch: 9 Loss: 0.11 Test accuracy: 97.16%
|
||||
Epoch: 10 Loss: 0.14 Test accuracy: 97.16%
|
||||
Epoch: 11 Loss: 0.11 Test accuracy: 97.16%
|
||||
Epoch: 12 Loss: 0.08 Test accuracy: 97.16%
|
||||
Epoch: 13 Loss: 0.10 Test accuracy: 97.16%
|
||||
Epoch: 14 Loss: 0.13 Test accuracy: 97.16%
|
||||
Epoch: 15 Loss: 0.08 Test accuracy: 97.16%
|
||||
Epoch: 16 Loss: 0.10 Test accuracy: 97.16%
|
||||
Epoch: 17 Loss: 0.12 Test accuracy: 97.16%
|
||||
Epoch: 18 Loss: 0.13 Test accuracy: 97.16%
|
||||
Epoch: 19 Loss: 0.08 Test accuracy: 97.16%
|
||||
Epoch: 20 Loss: 0.09 Test accuracy: 97.16%
|
||||
Epoch: 21 Loss: 0.05 Test accuracy: 97.16%
|
||||
Epoch: 22 Loss: 0.10 Test accuracy: 97.16%
|
||||
Epoch: 23 Loss: 0.11 Test accuracy: 97.16%
|
||||
Epoch: 24 Loss: 0.14 Test accuracy: 97.16%
|
||||
Epoch: 25 Loss: 0.15 Test accuracy: 97.16%
|
||||
Epoch: 26 Loss: 0.07 Test accuracy: 97.16%
|
||||
Epoch: 27 Loss: 0.08 Test accuracy: 97.16%
|
||||
Epoch: 28 Loss: 0.13 Test accuracy: 97.16%
|
||||
Epoch: 29 Loss: 0.10 Test accuracy: 97.16%
|
||||
Best test accuracy: 97.17%
|
||||
Taken time: 9.04 mins
|
||||
go run . -model=cnn -device=cpu
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.13 Test accuracy: 96.77%
|
||||
Epoch: 1 Loss: 0.18 Test accuracy: 95.43%
|
||||
Epoch: 2 Loss: 0.18 Test accuracy: 95.53%
|
||||
Epoch: 3 Loss: 0.18 Test accuracy: 96.08%
|
||||
Epoch: 4 Loss: 0.14 Test accuracy: 96.37%
|
||||
Epoch: 5 Loss: 0.14 Test accuracy: 96.40%
|
||||
Epoch: 6 Loss: 0.11 Test accuracy: 96.44%
|
||||
Epoch: 7 Loss: 0.08 Test accuracy: 96.96%
|
||||
Epoch: 8 Loss: 0.16 Test accuracy: 97.09%
|
||||
Epoch: 9 Loss: 0.11 Test accuracy: 97.05%
|
||||
Epoch: 10 Loss: 0.11 Test accuracy: 97.04%
|
||||
Epoch: 11 Loss: 0.11 Test accuracy: 97.10%
|
||||
Epoch: 12 Loss: 0.12 Test accuracy: 97.13%
|
||||
Epoch: 13 Loss: 0.09 Test accuracy: 97.13%
|
||||
Epoch: 14 Loss: 0.11 Test accuracy: 97.13%
|
||||
Epoch: 15 Loss: 0.16 Test accuracy: 97.13%
|
||||
Epoch: 16 Loss: 0.14 Test accuracy: 97.13%
|
||||
Epoch: 17 Loss: 0.11 Test accuracy: 97.13%
|
||||
Epoch: 18 Loss: 0.14 Test accuracy: 97.13%
|
||||
Epoch: 19 Loss: 0.17 Test accuracy: 97.13%
|
||||
Epoch: 20 Loss: 0.16 Test accuracy: 97.13%
|
||||
Epoch: 21 Loss: 0.07 Test accuracy: 97.13%
|
||||
Epoch: 22 Loss: 0.15 Test accuracy: 97.13%
|
||||
Epoch: 23 Loss: 0.14 Test accuracy: 97.13%
|
||||
Epoch: 24 Loss: 0.07 Test accuracy: 97.13%
|
||||
Epoch: 25 Loss: 0.13 Test accuracy: 97.13%
|
||||
Epoch: 26 Loss: 0.11 Test accuracy: 97.13%
|
||||
Epoch: 27 Loss: 0.14 Test accuracy: 97.13%
|
||||
Epoch: 28 Loss: 0.14 Test accuracy: 97.13%
|
||||
Epoch: 29 Loss: 0.08 Test accuracy: 97.13%
|
||||
Best test accuracy: 97.13%
|
||||
Taken time: 9.37 mins
|
||||
go run . -model=cnn -device=cpu
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.13 Test accuracy: 97.03%
|
||||
Epoch: 1 Loss: 0.14 Test accuracy: 97.43%
|
||||
Epoch: 2 Loss: 0.10 Test accuracy: 97.37%
|
||||
Epoch: 3 Loss: 0.13 Test accuracy: 97.35%
|
||||
Epoch: 4 Loss: 0.15 Test accuracy: 97.37%
|
||||
Epoch: 5 Loss: 0.06 Test accuracy: 97.68%
|
||||
Epoch: 6 Loss: 0.12 Test accuracy: 97.19%
|
||||
Epoch: 7 Loss: 0.08 Test accuracy: 97.68%
|
||||
Epoch: 8 Loss: 0.13 Test accuracy: 97.89%
|
||||
Epoch: 9 Loss: 0.10 Test accuracy: 97.32%
|
||||
Epoch: 10 Loss: 0.10 Test accuracy: 98.25%
|
||||
Epoch: 11 Loss: 0.07 Test accuracy: 98.26%
|
||||
Epoch: 12 Loss: 0.07 Test accuracy: 98.39%
|
||||
Epoch: 13 Loss: 0.09 Test accuracy: 98.43%
|
||||
Epoch: 14 Loss: 0.07 Test accuracy: 98.44%
|
||||
Epoch: 15 Loss: 0.09 Test accuracy: 98.49%
|
||||
Epoch: 16 Loss: 0.06 Test accuracy: 98.48%
|
||||
Epoch: 17 Loss: 0.05 Test accuracy: 98.48%
|
||||
Epoch: 18 Loss: 0.06 Test accuracy: 98.48%
|
||||
Epoch: 19 Loss: 0.04 Test accuracy: 98.48%
|
||||
Epoch: 20 Loss: 0.08 Test accuracy: 98.48%
|
||||
Epoch: 21 Loss: 0.11 Test accuracy: 98.48%
|
||||
Epoch: 22 Loss: 0.09 Test accuracy: 98.48%
|
||||
Epoch: 23 Loss: 0.06 Test accuracy: 98.48%
|
||||
Epoch: 24 Loss: 0.06 Test accuracy: 98.48%
|
||||
Epoch: 25 Loss: 0.05 Test accuracy: 98.48%
|
||||
Epoch: 26 Loss: 0.05 Test accuracy: 98.48%
|
||||
Epoch: 27 Loss: 0.05 Test accuracy: 98.48%
|
||||
Epoch: 28 Loss: 0.07 Test accuracy: 98.48%
|
||||
Epoch: 29 Loss: 0.10 Test accuracy: 98.48%
|
||||
Best test accuracy: 98.49%
|
||||
Taken time: 9.39 mins
|
||||
go run . -model=cnn -device=cpu
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.14 Test accuracy: 96.88%
|
||||
Epoch: 1 Loss: 0.12 Test accuracy: 97.29%
|
||||
Epoch: 2 Loss: 0.13 Test accuracy: 97.25%
|
||||
Epoch: 3 Loss: 0.11 Test accuracy: 97.21%
|
||||
Epoch: 4 Loss: 0.12 Test accuracy: 97.22%
|
||||
Epoch: 5 Loss: 0.08 Test accuracy: 97.32%
|
||||
Epoch: 6 Loss: 0.11 Test accuracy: 97.31%
|
||||
Epoch: 7 Loss: 0.13 Test accuracy: 97.32%
|
||||
Epoch: 8 Loss: 0.10 Test accuracy: 97.44%
|
||||
Epoch: 9 Loss: 0.15 Test accuracy: 97.37%
|
||||
Epoch: 10 Loss: 0.09 Test accuracy: 97.46%
|
||||
Epoch: 11 Loss: 0.11 Test accuracy: 97.49%
|
||||
Epoch: 12 Loss: 0.11 Test accuracy: 96.21%
|
||||
Epoch: 13 Loss: 0.13 Test accuracy: 95.94%
|
||||
Epoch: 14 Loss: 0.20 Test accuracy: 95.97%
|
||||
Epoch: 15 Loss: 0.18 Test accuracy: 97.12%
|
||||
Epoch: 16 Loss: 0.04 Test accuracy: 97.50%
|
||||
Epoch: 17 Loss: 0.18 Test accuracy: 97.38%
|
||||
Epoch: 18 Loss: 0.06 Test accuracy: 97.60%
|
||||
Epoch: 19 Loss: 0.13 Test accuracy: 97.45%
|
||||
Epoch: 20 Loss: 0.06 Test accuracy: 97.57%
|
||||
Epoch: 21 Loss: 0.12 Test accuracy: 97.60%
|
||||
Epoch: 22 Loss: 0.10 Test accuracy: 97.60%
|
||||
Epoch: 23 Loss: 0.09 Test accuracy: 97.60%
|
||||
Epoch: 24 Loss: 0.11 Test accuracy: 97.60%
|
||||
Epoch: 25 Loss: 0.13 Test accuracy: 97.60%
|
||||
Epoch: 26 Loss: 0.08 Test accuracy: 97.60%
|
||||
Epoch: 27 Loss: 0.18 Test accuracy: 97.60%
|
||||
Epoch: 28 Loss: 0.09 Test accuracy: 97.60%
|
||||
Epoch: 29 Loss: 0.07 Test accuracy: 97.60%
|
||||
Best test accuracy: 97.60%
|
||||
Taken time: 9.41 mins
|
||||
|
||||
```
|
||||
|
||||
**BatchSize = 32 on CUDA**
|
||||
|
||||
```bash
|
||||
go run . -model=cnn -device=cuda
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.28 Test accuracy: 98.41%
|
||||
Epoch: 1 Loss: 0.01 Test accuracy: 98.55%
|
||||
Epoch: 2 Loss: 0.09 Test accuracy: 98.53%
|
||||
Epoch: 3 Loss: 0.01 Test accuracy: 98.64%
|
||||
Epoch: 4 Loss: 0.01 Test accuracy: 98.74%
|
||||
Epoch: 5 Loss: 0.01 Test accuracy: 98.81%
|
||||
Epoch: 6 Loss: 0.10 Test accuracy: 98.91%
|
||||
Epoch: 7 Loss: 0.02 Test accuracy: 98.86%
|
||||
Epoch: 8 Loss: 0.00 Test accuracy: 98.64%
|
||||
Epoch: 9 Loss: 0.17 Test accuracy: 98.84%
|
||||
Epoch: 10 Loss: 0.01 Test accuracy: 98.83%
|
||||
Epoch: 11 Loss: 0.00 Test accuracy: 98.88%
|
||||
Epoch: 12 Loss: 0.05 Test accuracy: 98.90%
|
||||
Epoch: 13 Loss: 0.01 Test accuracy: 99.01%
|
||||
Epoch: 14 Loss: 0.09 Test accuracy: 97.85%
|
||||
Epoch: 15 Loss: 0.10 Test accuracy: 98.24%
|
||||
Epoch: 16 Loss: 0.00 Test accuracy: 98.53%
|
||||
Epoch: 17 Loss: 0.00 Test accuracy: 98.49%
|
||||
Epoch: 18 Loss: 0.16 Test accuracy: 98.49%
|
||||
Epoch: 19 Loss: 0.13 Test accuracy: 98.49%
|
||||
Epoch: 20 Loss: 0.00 Test accuracy: 98.49%
|
||||
Epoch: 21 Loss: 0.01 Test accuracy: 98.49%
|
||||
Epoch: 22 Loss: 0.17 Test accuracy: 98.49%
|
||||
Epoch: 23 Loss: 0.06 Test accuracy: 98.49%
|
||||
Epoch: 24 Loss: 0.00 Test accuracy: 98.49%
|
||||
Epoch: 25 Loss: 0.12 Test accuracy: 98.49%
|
||||
Epoch: 26 Loss: 0.08 Test accuracy: 98.49%
|
||||
Epoch: 27 Loss: 0.19 Test accuracy: 98.49%
|
||||
Epoch: 28 Loss: 0.02 Test accuracy: 98.49%
|
||||
Epoch: 29 Loss: 0.01 Test accuracy: 98.49%
|
||||
Best test accuracy: 99.01%
|
||||
Taken time: 8.89 mins
|
||||
|
||||
|
||||
go run . -model=cnn -device=cuda
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.05 Test accuracy: 98.40%
|
||||
Epoch: 1 Loss: 0.01 Test accuracy: 98.92%
|
||||
Epoch: 2 Loss: 0.10 Test accuracy: 98.97%
|
||||
Epoch: 3 Loss: 0.03 Test accuracy: 98.79%
|
||||
Epoch: 4 Loss: 0.02 Test accuracy: 98.81%
|
||||
Epoch: 5 Loss: 0.15 Test accuracy: 98.85%
|
||||
Epoch: 6 Loss: 0.01 Test accuracy: 98.82%
|
||||
Epoch: 7 Loss: 0.03 Test accuracy: 98.83%
|
||||
Epoch: 8 Loss: 0.01 Test accuracy: 98.56%
|
||||
Epoch: 9 Loss: 0.00 Test accuracy: 98.85%
|
||||
Epoch: 10 Loss: 0.22 Test accuracy: 98.51%
|
||||
Epoch: 11 Loss: 0.78 Test accuracy: 98.37%
|
||||
Epoch: 12 Loss: 0.01 Test accuracy: 98.47%
|
||||
Epoch: 13 Loss: 0.55 Test accuracy: 98.48%
|
||||
Epoch: 14 Loss: 0.00 Test accuracy: 98.45%
|
||||
Epoch: 15 Loss: 0.13 Test accuracy: 98.47%
|
||||
Epoch: 16 Loss: 0.01 Test accuracy: 98.49%
|
||||
Epoch: 17 Loss: 0.00 Test accuracy: 98.35%
|
||||
Epoch: 18 Loss: 0.08 Test accuracy: 98.41%
|
||||
Epoch: 19 Loss: 0.63 Test accuracy: 98.58%
|
||||
Epoch: 20 Loss: 0.22 Test accuracy: 98.59%
|
||||
Epoch: 21 Loss: 0.00 Test accuracy: 98.63%
|
||||
Epoch: 22 Loss: 0.80 Test accuracy: 98.63%
|
||||
Epoch: 23 Loss: 0.19 Test accuracy: 98.63%
|
||||
Epoch: 24 Loss: 0.00 Test accuracy: 98.63%
|
||||
Epoch: 25 Loss: 0.00 Test accuracy: 98.63%
|
||||
Epoch: 26 Loss: 0.00 Test accuracy: 98.63%
|
||||
Epoch: 27 Loss: 0.00 Test accuracy: 98.63%
|
||||
Epoch: 28 Loss: 0.09 Test accuracy: 98.63%
|
||||
Epoch: 29 Loss: 0.02 Test accuracy: 98.63%
|
||||
Best test accuracy: 98.97%
|
||||
Taken time: 8.85 mins
|
||||
|
||||
|
||||
go run . -model=cnn -device=cuda
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.39 Test accuracy: 97.83%
|
||||
Epoch: 1 Loss: 0.01 Test accuracy: 97.95%
|
||||
Epoch: 2 Loss: 0.00 Test accuracy: 98.74%
|
||||
Epoch: 3 Loss: 0.00 Test accuracy: 98.64%
|
||||
Epoch: 4 Loss: 0.07 Test accuracy: 98.62%
|
||||
Epoch: 5 Loss: 0.01 Test accuracy: 98.75%
|
||||
Epoch: 6 Loss: 0.01 Test accuracy: 98.76%
|
||||
Epoch: 7 Loss: 0.26 Test accuracy: 98.33%
|
||||
Epoch: 8 Loss: 0.04 Test accuracy: 98.44%
|
||||
Epoch: 9 Loss: 0.12 Test accuracy: 98.60%
|
||||
Epoch: 10 Loss: 0.00 Test accuracy: 98.60%
|
||||
Epoch: 11 Loss: 0.51 Test accuracy: 98.60%
|
||||
Epoch: 12 Loss: 0.05 Test accuracy: 98.60%
|
||||
Epoch: 13 Loss: 0.12 Test accuracy: 98.60%
|
||||
Epoch: 14 Loss: 0.00 Test accuracy: 98.60%
|
||||
Epoch: 15 Loss: 0.03 Test accuracy: 98.60%
|
||||
Epoch: 16 Loss: 0.03 Test accuracy: 98.60%
|
||||
Epoch: 17 Loss: 0.25 Test accuracy: 98.60%
|
||||
Epoch: 18 Loss: 0.18 Test accuracy: 98.35%
|
||||
Epoch: 19 Loss: 0.18 Test accuracy: 98.42%
|
||||
Epoch: 20 Loss: 0.01 Test accuracy: 98.40%
|
||||
Epoch: 21 Loss: 0.01 Test accuracy: 98.66%
|
||||
Epoch: 22 Loss: 0.11 Test accuracy: 98.71%
|
||||
Epoch: 23 Loss: 0.17 Test accuracy: 98.72%
|
||||
Epoch: 24 Loss: 0.21 Test accuracy: 98.72%
|
||||
Epoch: 25 Loss: 0.00 Test accuracy: 98.72%
|
||||
Epoch: 26 Loss: 0.00 Test accuracy: 98.72%
|
||||
Epoch: 27 Loss: 0.00 Test accuracy: 98.72%
|
||||
Epoch: 28 Loss: 0.06 Test accuracy: 98.72%
|
||||
Epoch: 29 Loss: 0.11 Test accuracy: 98.72%
|
||||
Best test accuracy: 98.76%
|
||||
Taken time: 8.84 mins
|
||||
|
||||
|
||||
go run . -model=cnn -device=cuda
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.15 Test accuracy: 98.48%
|
||||
Epoch: 1 Loss: 0.23 Test accuracy: 98.95%
|
||||
Epoch: 2 Loss: 0.02 Test accuracy: 98.94%
|
||||
Epoch: 3 Loss: 0.01 Test accuracy: 99.06%
|
||||
Epoch: 4 Loss: 0.16 Test accuracy: 99.03%
|
||||
Epoch: 5 Loss: 0.01 Test accuracy: 99.07%
|
||||
Epoch: 6 Loss: 0.22 Test accuracy: 98.25%
|
||||
Epoch: 7 Loss: 0.06 Test accuracy: 98.23%
|
||||
Epoch: 8 Loss: 0.26 Test accuracy: 98.25%
|
||||
Epoch: 9 Loss: 0.07 Test accuracy: 98.25%
|
||||
Epoch: 10 Loss: 0.02 Test accuracy: 98.25%
|
||||
Epoch: 11 Loss: 0.04 Test accuracy: 98.35%
|
||||
Epoch: 12 Loss: 0.01 Test accuracy: 98.36%
|
||||
Epoch: 13 Loss: 0.01 Test accuracy: 98.36%
|
||||
Epoch: 14 Loss: 0.04 Test accuracy: 98.42%
|
||||
Epoch: 15 Loss: 0.04 Test accuracy: 98.54%
|
||||
Epoch: 16 Loss: 0.11 Test accuracy: 98.53%
|
||||
Epoch: 17 Loss: 0.07 Test accuracy: 98.53%
|
||||
Epoch: 18 Loss: 0.45 Test accuracy: 98.53%
|
||||
Epoch: 19 Loss: 0.07 Test accuracy: 98.53%
|
||||
Epoch: 20 Loss: 0.15 Test accuracy: 98.53%
|
||||
Epoch: 21 Loss: 0.20 Test accuracy: 98.53%
|
||||
Epoch: 22 Loss: 0.02 Test accuracy: 98.53%
|
||||
Epoch: 23 Loss: 0.02 Test accuracy: 98.53%
|
||||
Epoch: 24 Loss: 0.00 Test accuracy: 98.53%
|
||||
Epoch: 25 Loss: 0.01 Test accuracy: 98.53%
|
||||
Epoch: 26 Loss: 0.12 Test accuracy: 98.53%
|
||||
Epoch: 27 Loss: 0.01 Test accuracy: 98.53%
|
||||
Epoch: 28 Loss: 0.04 Test accuracy: 98.53%
|
||||
Epoch: 29 Loss: 0.18 Test accuracy: 98.53%
|
||||
Best test accuracy: 99.07%
|
||||
Taken time: 8.82 mins
|
||||
|
||||
|
||||
testImages: [10000 784]
|
||||
testLabels: [10000]
|
||||
Epoch: 0 Loss: 0.02 Test accuracy: 98.37%
|
||||
Epoch: 1 Loss: 0.01 Test accuracy: 98.26%
|
||||
Epoch: 2 Loss: 0.02 Test accuracy: 98.51%
|
||||
Epoch: 3 Loss: 0.17 Test accuracy: 98.56%
|
||||
Epoch: 4 Loss: 0.02 Test accuracy: 98.60%
|
||||
Epoch: 5 Loss: 0.00 Test accuracy: 98.66%
|
||||
Epoch: 6 Loss: 0.01 Test accuracy: 98.85%
|
||||
Epoch: 7 Loss: 0.02 Test accuracy: 98.86%
|
||||
Epoch: 8 Loss: 0.01 Test accuracy: 98.42%
|
||||
Epoch: 9 Loss: 0.00 Test accuracy: 98.44%
|
||||
Epoch: 10 Loss: 0.02 Test accuracy: 98.50%
|
||||
Epoch: 11 Loss: 0.00 Test accuracy: 98.50%
|
||||
Epoch: 12 Loss: 0.05 Test accuracy: 98.50%
|
||||
Epoch: 13 Loss: 0.13 Test accuracy: 98.50%
|
||||
Epoch: 14 Loss: 0.00 Test accuracy: 98.50%
|
||||
Epoch: 15 Loss: 0.12 Test accuracy: 98.50%
|
||||
Epoch: 16 Loss: 0.00 Test accuracy: 98.50%
|
||||
Epoch: 17 Loss: 0.03 Test accuracy: 98.50%
|
||||
Epoch: 18 Loss: 0.41 Test accuracy: 98.50%
|
||||
Epoch: 19 Loss: 0.17 Test accuracy: 98.50%
|
||||
Epoch: 20 Loss: 0.26 Test accuracy: 98.50%
|
||||
Epoch: 21 Loss: 0.00 Test accuracy: 98.50%
|
||||
Epoch: 22 Loss: 0.29 Test accuracy: 98.50%
|
||||
Epoch: 23 Loss: 0.00 Test accuracy: 98.50%
|
||||
Epoch: 24 Loss: 0.20 Test accuracy: 98.50%
|
||||
Epoch: 25 Loss: 0.01 Test accuracy: 98.50%
|
||||
Epoch: 26 Loss: 0.18 Test accuracy: 98.50%
|
||||
Epoch: 27 Loss: 0.01 Test accuracy: 98.50%
|
||||
Epoch: 28 Loss: 0.12 Test accuracy: 98.50%
|
||||
Epoch: 29 Loss: 0.04 Test accuracy: 98.50%
|
||||
Best test accuracy: 98.86%
|
||||
Taken time: 8.77 mins
|
||||
|
||||
```
|
||||
|
|
|
@ -3,6 +3,8 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
|
@ -16,11 +18,14 @@ const (
|
|||
|
||||
epochsCNN = 30
|
||||
batchCNN = 256
|
||||
batchSize = 256
|
||||
// batchSize = 256
|
||||
batchSize = 32
|
||||
|
||||
LrCNN = 3 * 1e-4
|
||||
)
|
||||
|
||||
var mu sync.Mutex
|
||||
|
||||
type Net struct {
|
||||
conv1 *nn.Conv2D
|
||||
conv2 *nn.Conv2D
|
||||
|
@ -43,45 +48,32 @@ func newNet(vs *nn.Path) *Net {
|
|||
|
||||
func (n *Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
|
||||
outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)
|
||||
defer outView1.MustDrop()
|
||||
|
||||
outC1 := outView1.Apply(n.conv1)
|
||||
|
||||
outMP1 := outC1.MaxPool2DDefault(2, true)
|
||||
defer outMP1.MustDrop()
|
||||
|
||||
outC2 := outMP1.Apply(n.conv2)
|
||||
|
||||
outMP2 := outC2.MaxPool2DDefault(2, true)
|
||||
|
||||
outView2 := outMP2.MustView([]int64{-1, 1024}, true)
|
||||
defer outView2.MustDrop()
|
||||
|
||||
outFC1 := outView2.Apply(n.fc1)
|
||||
|
||||
outRelu := outFC1.MustRelu(true)
|
||||
defer outRelu.MustDrop()
|
||||
outRelu := outFC1.MustRelu(false)
|
||||
outDropout := ts.MustDropout(outRelu, 0.5, train)
|
||||
defer outDropout.MustDrop()
|
||||
|
||||
return outDropout.Apply(n.fc2)
|
||||
}
|
||||
|
||||
func runCNN1() {
|
||||
|
||||
var ds *vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
// ds.TrainImages [60000, 784]
|
||||
// ds.TrainLabels [60000, 784]
|
||||
testImages := ds.TestImages // [10000, 784]
|
||||
testLabels := ds.TestLabels // [10000, 784]
|
||||
trainImages := ds.TrainImages.MustTo(device, false) //[60000, 784]
|
||||
trainLabels := ds.TrainLabels.MustTo(device, false) // [60000, 784]
|
||||
testImages := ds.TestImages.MustTo(device, false) // [10000, 784]
|
||||
testLabels := ds.TestLabels.MustTo(device, false) // [10000, 784]
|
||||
|
||||
fmt.Printf("testImages: %v\n", testImages.MustSize())
|
||||
fmt.Printf("testLabels: %v\n", testLabels.MustSize())
|
||||
|
||||
device := gotch.CudaIfAvailable()
|
||||
vs := nn.NewVarStore(device)
|
||||
|
||||
net := newNet(vs.Root())
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
|
||||
// opt, err := nn.DefaultSGDConfig().Build(vs, LrCNN)
|
||||
|
@ -96,10 +88,9 @@ func runCNN1() {
|
|||
totalSize := ds.TrainImages.MustSize()[0]
|
||||
samples := int(totalSize)
|
||||
// Shuffling
|
||||
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
|
||||
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
|
||||
index.MustDrop()
|
||||
index := ts.MustRandperm(int64(totalSize), gotch.Int64, device)
|
||||
imagesTs := trainImages.MustIndexSelect(0, index, false)
|
||||
labelsTs := trainLabels.MustIndexSelect(0, index, false)
|
||||
|
||||
batches := samples / batchSize
|
||||
batchIndex := 0
|
||||
|
@ -114,36 +105,29 @@ func runCNN1() {
|
|||
|
||||
// Indexing
|
||||
bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
|
||||
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
|
||||
|
||||
bImages = bImages.MustTo(vs.Device(), true)
|
||||
bLabels = bLabels.MustTo(vs.Device(), true)
|
||||
|
||||
logits := net.ForwardT(bImages, true)
|
||||
bImages.MustDrop()
|
||||
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
|
||||
loss := logits.CrossEntropyForLogits(bLabels)
|
||||
logits.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
|
||||
loss = loss.MustSetRequiresGrad(true, true)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
epocLoss = loss.Float64Values()[0]
|
||||
loss.MustDrop()
|
||||
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
fmt.Printf("Start eval...")
|
||||
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1000)
|
||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
|
||||
if testAccuracy > bestAccuracy {
|
||||
bestAccuracy = testAccuracy
|
||||
}
|
||||
})
|
||||
|
||||
imagesTs.MustDrop()
|
||||
labelsTs.MustDrop()
|
||||
}
|
||||
|
||||
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
|
||||
ts.CleanUp()
|
||||
}
|
||||
|
|
|
@ -9,31 +9,37 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
ImageDim int64 = 784
|
||||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
ImageDim int64 = 784
|
||||
Label int64 = 10
|
||||
|
||||
epochs = 200
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
var ds *vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDir)
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
trainImages := ds.TrainImages.MustTo(device, true)
|
||||
trainLabels := ds.TrainLabels.MustTo(device, true)
|
||||
testImages := ds.TestImages.MustTo(device, true)
|
||||
testLabels := ds.TestLabels.MustTo(device, true)
|
||||
|
||||
device := gotch.CPU
|
||||
dtype := gotch.Float
|
||||
|
||||
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true, false)
|
||||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true, false)
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
// NOTE(TT). if initiating with random float, result is worse.
|
||||
// ws := ts.MustRandn([]int64{ImageDim, Label}, dtype, device)
|
||||
// bs := ts.MustRandn([]int64{Label}, dtype, device)
|
||||
// ws.MustRequiresGrad_(true)
|
||||
// bs.MustRequiresGrad_(true)
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
weight := ts.NewTensor()
|
||||
reduction := int64(1) // Mean of loss
|
||||
ignoreIndex := int64(-100)
|
||||
|
||||
logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(ds.TrainLabels, weight, reduction, ignoreIndex, true)
|
||||
logits := trainImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype, true).MustNllLoss(trainLabels, weight, reduction, ignoreIndex, true)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
|
@ -42,13 +48,13 @@ func runLinear() {
|
|||
ts.NoGrad(func() {
|
||||
ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
|
||||
bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
|
||||
ts.CleanUp(100)
|
||||
})
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
testLogits := testImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(testLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
||||
|
||||
loss.MustDrop()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,19 +2,31 @@ package main
|
|||
|
||||
import (
|
||||
"flag"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
var model string
|
||||
var (
|
||||
model string
|
||||
deviceOpt string
|
||||
device gotch.Device
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&model, "model", "linear", "specify a model to run")
|
||||
|
||||
flag.StringVar(&deviceOpt, "device", "cpu", "specify device to run on. Eitheir 'cpu' or 'cuda'")
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if deviceOpt == "cuda" {
|
||||
device = gotch.CudaIfAvailable()
|
||||
} else {
|
||||
device = gotch.CPU
|
||||
}
|
||||
|
||||
switch model {
|
||||
case "linear":
|
||||
runLinear()
|
||||
|
|
|
@ -3,6 +3,7 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
|
@ -11,16 +12,16 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
ImageDimNN int64 = 784
|
||||
HiddenNodesNN int64 = 128
|
||||
LabelNN int64 = 10
|
||||
MnistDirNN string = "../../data/mnist"
|
||||
ImageDimNN int64 = 784
|
||||
HiddenNodesNN int64 = 128
|
||||
LabelNN int64 = 10
|
||||
|
||||
epochsNN = 200
|
||||
|
||||
LrNN = 1e-3
|
||||
)
|
||||
|
||||
var MnistDirNN string = fmt.Sprintf("%s/%s", gotch.CachedDir, "mnist")
|
||||
var l nn.Linear
|
||||
|
||||
func netInit(vs *nn.Path) ts.Module {
|
||||
|
@ -38,7 +39,6 @@ func netInit(vs *nn.Path) ts.Module {
|
|||
}
|
||||
|
||||
func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimizer, epoch int) {
|
||||
|
||||
logits := m.Forward(trainX)
|
||||
loss := logits.CrossEntropyForLogits(trainY)
|
||||
|
||||
|
@ -47,26 +47,29 @@ func train(trainX, trainY, testX, testY *ts.Tensor, m ts.Module, opt *nn.Optimiz
|
|||
testLogits := m.Forward(testX)
|
||||
testAccuracy := testLogits.AccuracyForLogits(testY)
|
||||
accuracy := testAccuracy.Float64Values()[0] * 100
|
||||
testAccuracy.MustDrop()
|
||||
lossVal := loss.Float64Values()[0]
|
||||
loss.MustDrop()
|
||||
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", epoch, lossVal, accuracy)
|
||||
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
func runNN() {
|
||||
|
||||
var ds *vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
vs := nn.NewVarStore(device)
|
||||
net := netInit(vs.Root())
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||
train(ds.TrainImages, ds.TrainLabels, ds.TestImages, ds.TestLabels, net, opt, epoch)
|
||||
}
|
||||
trainImages := ds.TrainImages.MustTo(device, true)
|
||||
trainLabels := ds.TrainLabels.MustTo(device, true)
|
||||
testImages := ds.TestImages.MustTo(device, true)
|
||||
testLabels := ds.TestLabels.MustTo(device, true)
|
||||
|
||||
for epoch := 0; epoch < epochsNN; epoch++ {
|
||||
train(trainImages, trainLabels, testImages, testLabels, net, opt, epoch)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
|
@ -25,8 +26,10 @@ func main() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
err = pickle.LoadInfo(modelFile)
|
||||
m, err := pickle.LoadModelInfo(modelFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(m)
|
||||
}
|
||||
|
|
|
@ -113,10 +113,15 @@ var ModelUrls map[string]string = map[string]string{
|
|||
// 1. Resolves input string to a fullpath cached filename candidate.
|
||||
// 2. Check it at `CachedDir`, if exists, then return the candidate. If not
|
||||
// 3. Retrieves and Caches data to `CachedDir` and returns path to cached data
|
||||
func CachedPath(filenameOrUrl string) (resolvedPath string, err error) {
|
||||
func CachedPath(filenameOrUrl string, folderOpt ...string) (resolvedPath string, err error) {
|
||||
filename := path.Base(filenameOrUrl)
|
||||
// Resolves to "candidate" filename at `CachedDir`
|
||||
cachedFileCandidate := fmt.Sprintf("%s/%s", CachedDir, filename)
|
||||
fullPath := CachedDir
|
||||
if len(folderOpt) > 0 {
|
||||
fullPath = fmt.Sprintf("%v/%v", CachedDir, folderOpt[0])
|
||||
}
|
||||
|
||||
cachedFileCandidate := fmt.Sprintf("%s/%s", fullPath, filename)
|
||||
|
||||
// 1. Cached candidate file exists
|
||||
if _, err := os.Stat(cachedFileCandidate); err == nil {
|
||||
|
|
843
gen/gen.ml
843
gen/gen.ml
File diff suppressed because it is too large
Load Diff
1357
gen/gen.ml.1.11
Normal file
1357
gen/gen.ml.1.11
Normal file
File diff suppressed because it is too large
Load Diff
198120
gen/pytorch/Declarations-v2.0.0.yaml
Normal file
198120
gen/pytorch/Declarations-v2.0.0.yaml
Normal file
File diff suppressed because it is too large
Load Diff
198402
gen/pytorch/Declarations-v2.0.0.yaml.original
Normal file
198402
gen/pytorch/Declarations-v2.0.0.yaml.original
Normal file
File diff suppressed because it is too large
Load Diff
4
go.mod
4
go.mod
|
@ -1,8 +1,8 @@
|
|||
module github.com/sugarme/gotch
|
||||
|
||||
go 1.14
|
||||
go 1.19
|
||||
|
||||
require (
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0
|
||||
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5
|
||||
golang.org/x/image v0.5.0
|
||||
)
|
||||
|
|
27
go.sum
27
go.sum
|
@ -1,5 +1,28 @@
|
|||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5 h1:QelT11PB4FXiDEXucrfNckHoFxwt8USGY1ajP1ZF5lM=
|
||||
golang.org/x/image v0.0.0-20200927104501-e162460cd6b5/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/image v0.5.0 h1:5JMiNunQeQw++mMOz48/ISeNu3Iweh/JaZU8ZLqHRrI=
|
||||
golang.org/x/image v0.5.0/go.mod h1:FVC7BI/5Ym8R25iw5OLsgshdUBbT1h5jZTpA+mvAdZ4=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
167
half/bfloat16.go
Normal file
167
half/bfloat16.go
Normal file
|
@ -0,0 +1,167 @@
|
|||
package half
|
||||
|
||||
import (
|
||||
"math"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// A 16-bit floating point type implementing the bfloat16 format.
|
||||
// Ref. https://en.wikipedia.org/wiki/Bfloat16_floating-point_format
|
||||
// https://github.com/starkat99/half-rs/tree/main/src/bfloat
|
||||
|
||||
// The bfloat16 - Google 'brain' floating point format is a truncated 16-bit version of the IEEE 754 standard binary32.
|
||||
// bfloat16 has approximately the same dynamic range as float32 (8 bits -> 3.4 × 10^38) by having a lower precision than float16.
|
||||
// While float16 has a precision of 10 bits, bfloat16 has a precision of only 7 bits.
|
||||
//
|
||||
// +------------+------------------------+----------------------------+
|
||||
// | 1-bit sign | 8-bit exponent (range) | 7-bit fraction (precision) |
|
||||
// +------------+------------------------+----------------------------+
|
||||
type BFloat16 uint16
|
||||
|
||||
// Ref.https://github.com/starkat99/half-rs/blob/cabfc74e2a48b44b4556780f9d1550dd50a708be/src/bfloat/convert.rs#L5C1-L24C1
|
||||
func Float32ToBFloat16(value float32) uint16 {
|
||||
// convert to raw bytes
|
||||
x := math.Float32bits(value)
|
||||
|
||||
// Check for NaN
|
||||
if (x & 0x7FFF_FFFF) > 0x7F80_0000 {
|
||||
// keep high part of current mantissa but also set most significant mantissa bit
|
||||
return uint16((x >> 16) | 0x0040)
|
||||
}
|
||||
|
||||
// Round and shift
|
||||
var roundBit uint32 = 0x0000_8000
|
||||
if ((x & roundBit) != 0) && ((x & (3*roundBit - 1)) != 0) {
|
||||
return uint16(x>>16) + 1
|
||||
} else {
|
||||
return uint16(x >> 16)
|
||||
}
|
||||
}
|
||||
|
||||
// Float64ToBFloat16 converts value to its bfloat16 bit representation,
// rounding to nearest with ties to even. Infinities map to bfloat16
// infinities; NaNs stay NaN (quiet bit forced); values too small for a
// bfloat16 subnormal round to signed zero.
func Float64ToBFloat16(value float64) uint16 {
	// Convert to raw bytes, truncating the last 32-bits of mantissa
	// that precision will always be lost on half-precision
	val := math.Float64bits(value)
	x := uint32(val >> 32)

	// Extract IEEE754 components from the upper 32 bits:
	// sign (bit 31), 11-bit exponent, top 20 mantissa bits.
	sign := x & 0x8000_0000
	exp := x & 0x7FF0_0000
	man := x & 0x000F_FFFF

	// Check for all exponent bit being set, which is Infinity or NaN
	if exp == 0x7FF0_0000 {
		// Set mantissa MSB for NaN and also keep shifted mantissa bits.
		// Also check the last 32 bits.
		var nanBit uint32 = 0x0040
		if man == 0 && (uint32(val) == 0) {
			// Whole mantissa is zero: this is an infinity, not a NaN.
			nanBit = 0
		}
		return uint16((sign >> 16) | 0x7F80 | nanBit | (man >> 13))
	}

	// The number is normalized, start assembling half precision version
	halfSign := sign >> 16

	// Unbias the exponent, then bias for bfloat16 precision
	// (float64 bias 1023, bfloat16 bias 127).
	unbiasedExp := (int64(exp>>20) - 1023)
	halfExp := unbiasedExp + 127

	// Check for exponent overflow, return +infinity
	if halfExp >= 0xFF {
		return uint16(halfSign | 0x7F80)
	}

	// Check for underflow
	if halfExp <= 0 {
		// Check mantissa for what we can do
		if 7-halfExp > 21 {
			// No rounding possibility, so this is a full underflow, return signed zero
			return uint16(halfSign)
		}

		// Don't forget about hidden leading mantissa bit when assembling mantissa
		man = man | 0x0010_0000
		halfMan := man >> (14 - halfExp)

		// Check for rounding: round bit set and not an exact tie
		// (round-to-nearest, ties to even).
		var roundBit uint32 = 1 << (13 - halfExp)
		if ((man & roundBit) != 0) && ((man & (3*roundBit - 1)) != 0) {
			halfMan += 1
		}

		// No exponent for subnormals
		return uint16(halfSign | halfMan)
	}

	// Rebias the exponent
	halfExp1 := uint32(halfExp) << 7
	halfMan1 := man >> 13

	// Check for rounding.
	// NOTE(review): like the upstream half-rs implementation, the
	// tie-break inspects only the upper 20 mantissa bits; the dropped
	// low 32 bits of the float64 mantissa do not influence rounding.
	var roundBit1 uint32 = 0x0000_1000
	if ((man & roundBit1) != 0) && ((man & (3*roundBit1 - 1)) != 0) {
		// Round it; a carry out of the mantissa correctly propagates
		// into the exponent bits (and to infinity at the top).
		return uint16((halfSign | halfExp1 | halfMan1) + 1)
	} else {
		return uint16(halfSign | halfExp1 | halfMan1)
	}
}
|
||||
|
||||
func BFloat16ToFloat32(i uint16) float32 {
|
||||
// If NaN, keep current mantissa but also set most significant mantissa bit
|
||||
if i&0x7FFF > 0x7F80 {
|
||||
return math.Float32frombits((uint32(i) | 0x0040) << 16)
|
||||
} else {
|
||||
return math.Float32frombits(uint32(i) << 16)
|
||||
}
|
||||
}
|
||||
|
||||
// BFloat16ToFloat64 losslessly widens a bfloat16 bit pattern to float64.
// Zeros, infinities and NaNs are mapped to their float64 counterparts;
// bfloat16 subnormals are renormalized (they are all representable as
// normal float64 values).
func BFloat16ToFloat64(i uint16) float64 {
	// Check for signed zero
	if i&0x7FFF == 0 {
		return math.Float64frombits(uint64(i) << 48)
	}

	// Split into sign (bit 15), 8-bit exponent and 7-bit mantissa,
	// each kept in its original bit position.
	halfSign := uint64(i & 0x8000)
	halfExp := uint64(i & 0x7F80)
	halfMan := uint64(i & 0x007F)

	// Check for an infinity or NaN when all exponent bits set
	if halfExp == 0x7F80 {
		// Check for signed infinity if mantissa is zero
		if halfMan == 0 {
			return math.Float64frombits((halfSign << 48) | 0x7FF0_0000_0000_0000)
		} else {
			// NaN, keep current mantissa but also set most significant mantissa bit
			return math.Float64frombits((halfSign << 48) | 0x7FF8_0000_0000_0000 | (halfMan << 45))
		}
	}

	// Calculate double-precision components with adjusted exponent
	sign := halfSign << 48

	// Unbias exponent (bfloat16 bias is 127)
	unbiasedExp := (int64(halfExp) >> 7) - 127

	// Check for subnormals, which will be normalized by adjusting exponent
	if halfExp == 0 {
		// Calculate how much to adjust the exponent by
		// leading zeros uint16; halfMan has at most 7 significant bits,
		// so e is in [0, 6] (halfMan != 0 here — the zero case returned above).
		e := bits.LeadingZeros16(uint16(halfMan)) - 9

		// Rebias and adjust exponent; mask drops the now-implicit
		// leading mantissa bit after the normalizing shift.
		exp := (uint64(1023-127-e) << 52)
		man := (halfMan << (46 + e)) & 0xF_FFFF_FFFF_FFFF

		return math.Float64frombits(sign | exp | man)
	}

	// Rebias exponent for a normalized normal (float64 bias is 1023)
	exp := uint64(unbiasedExp+1023) << 52
	man := (halfMan & 0x007F) << 45

	return math.Float64frombits(sign | exp | man)
}
|
1
half/bfloat16_test.go
Normal file
1
half/bfloat16_test.go
Normal file
|
@ -0,0 +1 @@
|
|||
package half
|
303
half/float16.go
Normal file
303
half/float16.go
Normal file
|
@ -0,0 +1,303 @@
|
|||
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
|
||||
//
|
||||
// Special thanks to Kathryn Long for her Rust implementation
|
||||
// of float16 at github.com/starkat99/half-rs (MIT license)
|
||||
|
||||
// Package half defines support for half-precision floating-point numbers.
|
||||
package half
|
||||
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// Float16 represents IEEE 754 half-precision floating-point numbers (binary16).
type Float16 uint16

// Precision indicates whether the conversion to Float16 is
// exact, subnormal without dropped bits, inexact, underflow, or overflow.
type Precision int

// The iota ordering below is part of the type's contract:
// PrecisionExact must be the zero value.
const (

	// PrecisionExact is for non-subnormals that don't drop bits during conversion.
	// All of these can round-trip. Should always convert to float16.
	PrecisionExact Precision = iota

	// PrecisionUnknown is for subnormals that don't drop bits during conversion but
	// not all of these can round-trip so precision is unknown without more effort.
	// Only 2046 of these can round-trip and the rest cannot round-trip.
	PrecisionUnknown

	// PrecisionInexact is for dropped significand bits and cannot round-trip.
	// Some of these are subnormals. Cannot round-trip float32->float16->float32.
	PrecisionInexact

	// PrecisionUnderflow is for Underflows. Cannot round-trip float32->float16->float32.
	PrecisionUnderflow

	// PrecisionOverflow is for Overflows. Cannot round-trip float32->float16->float32.
	PrecisionOverflow
)
|
||||
|
||||
// PrecisionFromfloat32 returns Precision without performing
|
||||
// the conversion. Conversions from both Infinity and NaN
|
||||
// values will always report PrecisionExact even if NaN payload
|
||||
// or NaN-Quiet-Bit is lost. This function is kept simple to
|
||||
// allow inlining and run < 0.5 ns/op, to serve as a fast filter.
|
||||
func PrecisionFromfloat32(f32 float32) Precision {
|
||||
u32 := math.Float32bits(f32)
|
||||
|
||||
if u32 == 0 || u32 == 0x80000000 {
|
||||
// +- zero will always be exact conversion
|
||||
return PrecisionExact
|
||||
}
|
||||
|
||||
const COEFMASK uint32 = 0x7fffff // 23 least significant bits
|
||||
const EXPSHIFT uint32 = 23
|
||||
const EXPBIAS uint32 = 127
|
||||
const EXPMASK uint32 = uint32(0xff) << EXPSHIFT
|
||||
const DROPMASK uint32 = COEFMASK >> 10
|
||||
|
||||
exp := int32(((u32 & EXPMASK) >> EXPSHIFT) - EXPBIAS)
|
||||
coef := u32 & COEFMASK
|
||||
|
||||
if exp == 128 {
|
||||
// +- infinity or NaN
|
||||
// apps may want to do extra checks for NaN separately
|
||||
return PrecisionExact
|
||||
}
|
||||
|
||||
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format says,
|
||||
// "Decimals between 2^−24 (minimum positive subnormal) and 2^−14 (maximum subnormal): fixed interval 2^−24"
|
||||
if exp < -24 {
|
||||
return PrecisionUnderflow
|
||||
}
|
||||
if exp > 15 {
|
||||
return PrecisionOverflow
|
||||
}
|
||||
if (coef & DROPMASK) != uint32(0) {
|
||||
// these include subnormals and non-subnormals that dropped bits
|
||||
return PrecisionInexact
|
||||
}
|
||||
|
||||
if exp < -14 {
|
||||
// Subnormals. Caller may want to test these further.
|
||||
// There are 2046 subnormals that can successfully round-trip f32->f16->f32
|
||||
// and 20 of those 2046 have 32-bit input coef == 0.
|
||||
// RFC 7049 and 7049bis Draft 12 don't precisely define "preserves value"
|
||||
// so some protocols and libraries will choose to handle subnormals differently
|
||||
// when deciding to encode them to CBOR float32 vs float16.
|
||||
return PrecisionUnknown
|
||||
}
|
||||
|
||||
return PrecisionExact
|
||||
}
|
||||
|
||||
// Frombits returns the float16 number corresponding to the IEEE 754 binary16
// representation u16, with the sign bit of u16 and the result in the same bit
// position. Frombits(Bits(x)) == x.
func Frombits(u16 uint16) Float16 {
	return Float16(u16)
}

// Fromfloat32 returns a Float16 value converted from f32. Conversion uses
// IEEE default rounding (nearest int, with ties to even).
func Fromfloat32(f32 float32) Float16 {
	return Float16(f32bitsToF16bits(math.Float32bits(f32)))
}
|
||||
|
||||
// ErrInvalidNaNValue indicates a NaN was not received.
const ErrInvalidNaNValue = float16Error("float16: invalid NaN value, expected IEEE 754 NaN")

// float16Error is a string-backed error type, allowing errors of this
// package to be declared as untyped constants.
type float16Error string

// Error implements the error interface.
func (e float16Error) Error() string {
	return string(e)
}
|
||||
|
||||
// FromNaN32ps converts nan to IEEE binary16 NaN while preserving both
|
||||
// signaling and payload. Unlike Fromfloat32(), which can only return
|
||||
// qNaN because it sets quiet bit = 1, this can return both sNaN and qNaN.
|
||||
// If the result is infinity (sNaN with empty payload), then the
|
||||
// lowest bit of payload is set to make the result a NaN.
|
||||
// Returns ErrInvalidNaNValue and 0x7c01 (sNaN) if nan isn't IEEE 754 NaN.
|
||||
// This function was kept simple to be able to inline.
|
||||
func FromNaN32ps(nan float32) (Float16, error) {
|
||||
const SNAN = Float16(uint16(0x7c01)) // signaling NaN
|
||||
|
||||
u32 := math.Float32bits(nan)
|
||||
sign := u32 & 0x80000000
|
||||
exp := u32 & 0x7f800000
|
||||
coef := u32 & 0x007fffff
|
||||
|
||||
if (exp != 0x7f800000) || (coef == 0) {
|
||||
return SNAN, ErrInvalidNaNValue
|
||||
}
|
||||
|
||||
u16 := uint16((sign >> 16) | uint32(0x7c00) | (coef >> 13))
|
||||
|
||||
if (u16 & 0x03ff) == 0 {
|
||||
// result became infinity, make it NaN by setting lowest bit in payload
|
||||
u16 |= 0x0001
|
||||
}
|
||||
|
||||
return Float16(u16), nil
|
||||
}
|
||||
|
||||
// NaN returns a Float16 of IEEE 754 binary16 not-a-number (NaN).
// Returned NaN value 0x7e01 has all exponent bits = 1 with the
// first and last bits = 1 in the significand. This is consistent
// with Go's 64-bit math.NaN(). Canonical CBOR in RFC 7049 uses 0x7e00.
func NaN() Float16 {
	return Float16(0x7e01)
}
|
||||
|
||||
// Inf returns a Float16 with an infinity value with the specified sign.
|
||||
// A sign >= returns positive infinity.
|
||||
// A sign < 0 returns negative infinity.
|
||||
func Inf(sign int) Float16 {
|
||||
if sign >= 0 {
|
||||
return Float16(0x7c00)
|
||||
}
|
||||
return Float16(0x8000 | 0x7c00)
|
||||
}
|
||||
|
||||
// Float32 returns a float32 converted from f (Float16).
|
||||
// This is a lossless conversion.
|
||||
func (f Float16) Float32() float32 {
|
||||
u32 := f16bitsToF32bits(uint16(f))
|
||||
return math.Float32frombits(u32)
|
||||
}
|
||||
|
||||
// Bits returns the IEEE 754 binary16 representation of f, with the sign bit
|
||||
// of f and the result in the same bit position. Bits(Frombits(x)) == x.
|
||||
func (f Float16) Bits() uint16 {
|
||||
return uint16(f)
|
||||
}
|
||||
|
||||
// IsNaN reports whether f is an IEEE 754 binary16 “not-a-number” value.
|
||||
func (f Float16) IsNaN() bool {
|
||||
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0)
|
||||
}
|
||||
|
||||
// IsQuietNaN reports whether f is a quiet (non-signaling) IEEE 754 binary16
|
||||
// “not-a-number” value.
|
||||
func (f Float16) IsQuietNaN() bool {
|
||||
return (f&0x7c00 == 0x7c00) && (f&0x03ff != 0) && (f&0x0200 != 0)
|
||||
}
|
||||
|
||||
// IsInf reports whether f is an infinity (inf).
|
||||
// A sign > 0 reports whether f is positive inf.
|
||||
// A sign < 0 reports whether f is negative inf.
|
||||
// A sign == 0 reports whether f is either inf.
|
||||
func (f Float16) IsInf(sign int) bool {
|
||||
return ((f == 0x7c00) && sign >= 0) ||
|
||||
(f == 0xfc00 && sign <= 0)
|
||||
}
|
||||
|
||||
// IsFinite returns true if f is neither infinite nor NaN.
|
||||
func (f Float16) IsFinite() bool {
|
||||
return (uint16(f) & uint16(0x7c00)) != uint16(0x7c00)
|
||||
}
|
||||
|
||||
// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN.
|
||||
func (f Float16) IsNormal() bool {
|
||||
exp := uint16(f) & uint16(0x7c00)
|
||||
return (exp != uint16(0x7c00)) && (exp != 0)
|
||||
}
|
||||
|
||||
// Signbit reports whether f is negative or negative zero.
|
||||
func (f Float16) Signbit() bool {
|
||||
return (uint16(f) & uint16(0x8000)) != 0
|
||||
}
|
||||
|
||||
// String satisfies the fmt.Stringer interface.
|
||||
func (f Float16) String() string {
|
||||
return strconv.FormatFloat(float64(f.Float32()), 'f', -1, 32)
|
||||
}
|
||||
|
||||
// f16bitsToF32bits returns uint32 (float32 bits) converted from specified uint16.
// The conversion is lossless: every binary16 value maps to a distinct binary32.
func f16bitsToF32bits(in uint16) uint32 {
	// All 65536 conversions with this were confirmed to be correct
	// by Montgomery Edwards⁴⁴⁸ (github.com/x448).

	sign := uint32(in&0x8000) << 16 // sign bit moved to binary32 position
	exp := uint32(in&0x7c00) >> 10  // 5-bit binary16 exponent
	coef := uint32(in&0x03ff) << 13 // significand moved to binary32 position

	switch {
	case exp == 0x1f && coef == 0:
		// infinity
		return sign | 0x7f800000
	case exp == 0x1f:
		// NaN (quiet bit forced, payload preserved)
		return sign | 0x7fc00000 | coef
	case exp == 0 && coef == 0:
		// signed zero
		return sign
	case exp == 0:
		// Normalize a subnormal: shift the significand left until its
		// leading bit reaches the implicit-one position, adjusting the
		// exponent to compensate, then drop the now-implicit bit.
		exp++
		for coef&0x7f800000 == 0 {
			coef <<= 1
			exp--
		}
		coef &= 0x007fffff
	}

	// Rebias the exponent (binary16 bias 15 -> binary32 bias 127).
	return sign | ((exp + (0x7f - 0xf)) << 23) | coef
}
|
||||
|
||||
// f32bitsToF16bits returns uint16 (Float16 bits) converted from the specified float32.
// Conversion rounds to nearest integer with ties to even.
func f32bitsToF16bits(u32 uint32) uint16 {
	// Translated from Rust to Go by Montgomery Edwards⁴⁴⁸ (github.com/x448).
	// All 4294967296 conversions with this were confirmed to be correct by x448.
	// Original Rust implementation is by Kathryn Long (github.com/starkat99) with MIT license.

	sign := u32 & 0x80000000
	exp := u32 & 0x7f800000
	coef := u32 & 0x007fffff

	// Infinity and NaN: keep the top payload bits, force the quiet bit on NaN.
	if exp == 0x7f800000 {
		var nanBit uint32
		if coef != 0 {
			nanBit = 0x0200
		}
		return uint16((sign >> 16) | uint32(0x7c00) | nanBit | (coef >> 13))
	}

	halfSign := sign >> 16

	// Rebias: binary32 bias 127 -> binary16 bias 15.
	halfExp := int32(exp>>23) - 127 + 15

	// Exponent overflow: round to signed infinity.
	if halfExp >= 0x1f {
		return uint16(halfSign | uint32(0x7c00))
	}

	// Subnormal result or total underflow.
	if halfExp <= 0 {
		if 14-halfExp > 24 {
			// Too small even for a binary16 subnormal: signed zero.
			return uint16(halfSign)
		}
		// Restore the hidden leading bit, then shift into subnormal position.
		c := coef | uint32(0x00800000)
		halfCoef := c >> uint32(14-halfExp)
		// Round to nearest, ties to even.
		round := uint32(1) << uint32(13-halfExp)
		if c&round != 0 && c&(3*round-1) != 0 {
			halfCoef++
		}
		return uint16(halfSign | halfCoef)
	}

	// Normal result: truncate the significand and round to nearest-even.
	// A carry out of the significand propagates into the exponent bits,
	// which is the correct behavior (up to and including infinity).
	uHalfExp := uint32(halfExp) << 10
	halfCoef := coef >> 13
	const roundBit uint32 = 0x00001000
	if coef&roundBit != 0 && coef&(3*roundBit-1) != 0 {
		return uint16((halfSign | uHalfExp | halfCoef) + 1)
	}
	return uint16(halfSign | uHalfExp | halfCoef)
}
|
88
half/float16_bench_test.go
Normal file
88
half/float16_bench_test.go
Normal file
|
@ -0,0 +1,88 @@
|
|||
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
|
||||
|
||||
package half_test
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
float16 "github.com/sugarme/gotch/half"
|
||||
)
|
||||
|
||||
// prevent compiler optimizing out code by assigning to these
// package-level sinks after each benchmark loop.
var resultF16 float16.Float16
var resultF32 float32
var resultStr string
var pcn float16.Precision

// BenchmarkFloat32pi measures Frombits followed by Float32 (f16 -> f32).
func BenchmarkFloat32pi(b *testing.B) {
	result := float32(0)
	pi32 := float32(math.Pi)
	pi16 := float16.Fromfloat32(pi32)
	for i := 0; i < b.N; i++ {
		f16 := float16.Frombits(uint16(pi16))
		result = f16.Float32()
	}
	resultF32 = result
}

// BenchmarkFrombits measures the (trivial) bit-cast constructor alone.
func BenchmarkFrombits(b *testing.B) {
	result := float16.Float16(0)
	pi32 := float32(math.Pi)
	pi16 := float16.Fromfloat32(pi32)
	for i := 0; i < b.N; i++ {
		result = float16.Frombits(uint16(pi16))
	}
	resultF16 = result
}

// BenchmarkFromFloat32pi measures f32 -> f16 conversion of a normal value.
func BenchmarkFromFloat32pi(b *testing.B) {
	result := float16.Float16(0)

	pi := float32(math.Pi)
	for i := 0; i < b.N; i++ {
		result = float16.Fromfloat32(pi)
	}
	resultF16 = result
}

// BenchmarkFromFloat32nan measures the NaN path of f32 -> f16 conversion.
func BenchmarkFromFloat32nan(b *testing.B) {
	result := float16.Float16(0)

	nan := float32(math.NaN())
	for i := 0; i < b.N; i++ {
		result = float16.Fromfloat32(nan)
	}
	resultF16 = result
}

// BenchmarkFromFloat32subnorm measures the subnormal path of f32 -> f16
// conversion (largest float32 subnormal as input).
func BenchmarkFromFloat32subnorm(b *testing.B) {
	result := float16.Float16(0)

	subnorm := math.Float32frombits(0x007fffff)
	for i := 0; i < b.N; i++ {
		result = float16.Fromfloat32(subnorm)
	}
	resultF16 = result
}

// BenchmarkPrecisionFromFloat32 measures the precision classifier; the
// addition inside the loop keeps the input from being constant-folded.
func BenchmarkPrecisionFromFloat32(b *testing.B) {
	var result float16.Precision

	for i := 0; i < b.N; i++ {
		f32 := float32(0.00001) + float32(0.00001)
		result = float16.PrecisionFromfloat32(f32)
	}
	pcn = result
}

// BenchmarkString measures decimal formatting of a Float16.
func BenchmarkString(b *testing.B) {
	var result string

	pi32 := float32(math.Pi)
	pi16 := float16.Fromfloat32(pi32)
	for i := 0; i < b.N; i++ {
		result = pi16.String()
	}
	resultStr = result
}
|
798
half/float16_test.go
Normal file
798
half/float16_test.go
Normal file
|
@ -0,0 +1,798 @@
|
|||
// Copyright 2019 Montgomery Edwards⁴⁴⁸ and Faye Amacker
|
||||
|
||||
package half_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha512"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
float16 "github.com/sugarme/gotch/half"
|
||||
)
|
||||
|
||||
// wantF32toF16bits is a tiny subset of expected values
|
||||
var wantF32toF16bits = []struct {
|
||||
in float32
|
||||
out uint16
|
||||
}{
|
||||
// generated to provide 100% code coverage plus additional tests for rounding, etc.
|
||||
{in: math.Float32frombits(0x00000000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x00000001), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x00001fff), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x00002000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x00003fff), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x00004000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x007fffff), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x00800000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x33000000), out: 0x0000}, // in f32=0.000000, out f16=0
|
||||
{in: math.Float32frombits(0x33000001), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
|
||||
{in: math.Float32frombits(0x33000002), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
|
||||
{in: math.Float32frombits(0x387fc000), out: 0x03ff}, // in f32=0.000061, out f16=0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
|
||||
{in: math.Float32frombits(0x387fffff), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
|
||||
{in: math.Float32frombits(0x38800000), out: 0x0400}, // in f32=0.000061, out f16=0.000061035156
|
||||
{in: math.Float32frombits(0x38801fff), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
|
||||
{in: math.Float32frombits(0x38802000), out: 0x0401}, // in f32=0.000061, out f16=0.00006109476
|
||||
{in: math.Float32frombits(0x38803fff), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
|
||||
{in: math.Float32frombits(0x38804000), out: 0x0402}, // in f32=0.000061, out f16=0.000061154366
|
||||
{in: math.Float32frombits(0x33bfffff), out: 0x0001}, // in f32=0.000000, out f16=0.000000059604645
|
||||
{in: math.Float32frombits(0x33c00000), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
|
||||
{in: math.Float32frombits(0x33c00001), out: 0x0002}, // in f32=0.000000, out f16=0.00000011920929
|
||||
{in: math.Float32frombits(0x477fffff), out: 0x7c00}, // in f32=65535.996094, out f16=+Inf
|
||||
{in: math.Float32frombits(0x47800000), out: 0x7c00}, // in f32=65536.000000, out f16=+Inf
|
||||
{in: math.Float32frombits(0x7f7fffff), out: 0x7c00}, // in f32=340282346638528859811704183484516925440.000000, out f16=+Inf
|
||||
{in: math.Float32frombits(0x7f800000), out: 0x7c00}, // in f32=+Inf, out f16=+Inf
|
||||
{in: math.Float32frombits(0x7f801fff), out: 0x7e00}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0x7f802000), out: 0x7e01}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0x7f803fff), out: 0x7e01}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0x7f804000), out: 0x7e02}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0x7fffffff), out: 0x7fff}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0x80000000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0x80001fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0x80002000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0x80003fff), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0x80004000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0x807fffff), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0x80800000), out: 0x8000}, // in f32=-0.000000, out f16=-0
|
||||
{in: math.Float32frombits(0xb87fc000), out: 0x83ff}, // in f32=-0.000061, out f16=-0.00006097555 // exp32=-15 (underflows binary16 exp) but round-trips
|
||||
{in: math.Float32frombits(0xb87fffff), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
|
||||
{in: math.Float32frombits(0xb8800000), out: 0x8400}, // in f32=-0.000061, out f16=-0.000061035156
|
||||
{in: math.Float32frombits(0xb8801fff), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
|
||||
{in: math.Float32frombits(0xb8802000), out: 0x8401}, // in f32=-0.000061, out f16=-0.00006109476
|
||||
{in: math.Float32frombits(0xb8803fff), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
|
||||
{in: math.Float32frombits(0xb8804000), out: 0x8402}, // in f32=-0.000061, out f16=-0.000061154366
|
||||
{in: math.Float32frombits(0xc77fffff), out: 0xfc00}, // in f32=-65535.996094, out f16=-Inf
|
||||
{in: math.Float32frombits(0xc7800000), out: 0xfc00}, // in f32=-65536.000000, out f16=-Inf
|
||||
{in: math.Float32frombits(0xff7fffff), out: 0xfc00}, // in f32=-340282346638528859811704183484516925440.000000, out f16=-Inf
|
||||
{in: math.Float32frombits(0xff800000), out: 0xfc00}, // in f32=-Inf, out f16=-Inf
|
||||
{in: math.Float32frombits(0xff801fff), out: 0xfe00}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0xff802000), out: 0xfe01}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0xff803fff), out: 0xfe01}, // in f32=NaN, out f16=NaN
|
||||
{in: math.Float32frombits(0xff804000), out: 0xfe02}, // in f32=NaN, out f16=NaN
|
||||
// additional tests
|
||||
{in: math.Float32frombits(0xc77ff000), out: 0xfc00}, // in f32=-65520.000000, out f16=-Inf
|
||||
{in: math.Float32frombits(0xc77fef00), out: 0xfbff}, // in f32=-65519.000000, out f16=-65504
|
||||
{in: math.Float32frombits(0xc77fee00), out: 0xfbff}, // in f32=-65518.000000, out f16=-65504
|
||||
{in: math.Float32frombits(0xc5802000), out: 0xec01}, // in f32=-4100.000000, out f16=-4100
|
||||
{in: math.Float32frombits(0xc5801800), out: 0xec01}, // in f32=-4099.000000, out f16=-4100
|
||||
{in: math.Float32frombits(0xc5801000), out: 0xec00}, // in f32=-4098.000000, out f16=-4096
|
||||
{in: math.Float32frombits(0xc5800800), out: 0xec00}, // in f32=-4097.000000, out f16=-4096
|
||||
{in: math.Float32frombits(0xc5800000), out: 0xec00}, // in f32=-4096.000000, out f16=-4096
|
||||
{in: math.Float32frombits(0xc57ff000), out: 0xec00}, // in f32=-4095.000000, out f16=-4096
|
||||
{in: math.Float32frombits(0xc57fe000), out: 0xebff}, // in f32=-4094.000000, out f16=-4094
|
||||
{in: math.Float32frombits(0xc57fd000), out: 0xebfe}, // in f32=-4093.000000, out f16=-4092
|
||||
{in: math.Float32frombits(0xc5002000), out: 0xe801}, // in f32=-2050.000000, out f16=-2050
|
||||
{in: math.Float32frombits(0xc5001000), out: 0xe800}, // in f32=-2049.000000, out f16=-2048
|
||||
{in: math.Float32frombits(0xc5000829), out: 0xe800}, // in f32=-2048.510010, out f16=-2048
|
||||
{in: math.Float32frombits(0xc5000800), out: 0xe800}, // in f32=-2048.500000, out f16=-2048
|
||||
{in: math.Float32frombits(0xc50007d7), out: 0xe800}, // in f32=-2048.489990, out f16=-2048
|
||||
{in: math.Float32frombits(0xc5000000), out: 0xe800}, // in f32=-2048.000000, out f16=-2048
|
||||
{in: math.Float32frombits(0xc4fff052), out: 0xe800}, // in f32=-2047.510010, out f16=-2048
|
||||
{in: math.Float32frombits(0xc4fff000), out: 0xe800}, // in f32=-2047.500000, out f16=-2048
|
||||
{in: math.Float32frombits(0xc4ffefae), out: 0xe7ff}, // in f32=-2047.489990, out f16=-2047
|
||||
{in: math.Float32frombits(0xc4ffe000), out: 0xe7ff}, // in f32=-2047.000000, out f16=-2047
|
||||
{in: math.Float32frombits(0xc4ffc000), out: 0xe7fe}, // in f32=-2046.000000, out f16=-2046
|
||||
{in: math.Float32frombits(0xc4ffa000), out: 0xe7fd}, // in f32=-2045.000000, out f16=-2045
|
||||
{in: math.Float32frombits(0xbf800000), out: 0xbc00}, // in f32=-1.000000, out f16=-1
|
||||
{in: math.Float32frombits(0xbf028f5c), out: 0xb814}, // in f32=-0.510000, out f16=-0.5097656
|
||||
{in: math.Float32frombits(0xbf000000), out: 0xb800}, // in f32=-0.500000, out f16=-0.5
|
||||
{in: math.Float32frombits(0xbefae148), out: 0xb7d7}, // in f32=-0.490000, out f16=-0.48999023
|
||||
{in: math.Float32frombits(0x3efae148), out: 0x37d7}, // in f32=0.490000, out f16=0.48999023
|
||||
{in: math.Float32frombits(0x3f000000), out: 0x3800}, // in f32=0.500000, out f16=0.5
|
||||
{in: math.Float32frombits(0x3f028f5c), out: 0x3814}, // in f32=0.510000, out f16=0.5097656
|
||||
{in: math.Float32frombits(0x3f800000), out: 0x3c00}, // in f32=1.000000, out f16=1
|
||||
{in: math.Float32frombits(0x3fbeb852), out: 0x3df6}, // in f32=1.490000, out f16=1.4902344
|
||||
{in: math.Float32frombits(0x3fc00000), out: 0x3e00}, // in f32=1.500000, out f16=1.5
|
||||
{in: math.Float32frombits(0x3fc147ae), out: 0x3e0a}, // in f32=1.510000, out f16=1.5097656
|
||||
{in: math.Float32frombits(0x3fcf1bbd), out: 0x3e79}, // in f32=1.618034, out f16=1.6181641
|
||||
{in: math.Float32frombits(0x401f5c29), out: 0x40fb}, // in f32=2.490000, out f16=2.4902344
|
||||
{in: math.Float32frombits(0x40200000), out: 0x4100}, // in f32=2.500000, out f16=2.5
|
||||
{in: math.Float32frombits(0x4020a3d7), out: 0x4105}, // in f32=2.510000, out f16=2.5097656
|
||||
{in: math.Float32frombits(0x402df854), out: 0x4170}, // in f32=2.718282, out f16=2.71875
|
||||
{in: math.Float32frombits(0x40490fdb), out: 0x4248}, // in f32=3.141593, out f16=3.140625
|
||||
{in: math.Float32frombits(0x40b00000), out: 0x4580}, // in f32=5.500000, out f16=5.5
|
||||
{in: math.Float32frombits(0x44ffa000), out: 0x67fd}, // in f32=2045.000000, out f16=2045
|
||||
{in: math.Float32frombits(0x44ffc000), out: 0x67fe}, // in f32=2046.000000, out f16=2046
|
||||
{in: math.Float32frombits(0x44ffe000), out: 0x67ff}, // in f32=2047.000000, out f16=2047
|
||||
{in: math.Float32frombits(0x44ffefae), out: 0x67ff}, // in f32=2047.489990, out f16=2047
|
||||
{in: math.Float32frombits(0x44fff000), out: 0x6800}, // in f32=2047.500000, out f16=2048
|
||||
{in: math.Float32frombits(0x44fff052), out: 0x6800}, // in f32=2047.510010, out f16=2048
|
||||
{in: math.Float32frombits(0x45000000), out: 0x6800}, // in f32=2048.000000, out f16=2048
|
||||
{in: math.Float32frombits(0x450007d7), out: 0x6800}, // in f32=2048.489990, out f16=2048
|
||||
{in: math.Float32frombits(0x45000800), out: 0x6800}, // in f32=2048.500000, out f16=2048
|
||||
{in: math.Float32frombits(0x45000829), out: 0x6800}, // in f32=2048.510010, out f16=2048
|
||||
{in: math.Float32frombits(0x45001000), out: 0x6800}, // in f32=2049.000000, out f16=2048
|
||||
{in: math.Float32frombits(0x450017d7), out: 0x6801}, // in f32=2049.489990, out f16=2050
|
||||
{in: math.Float32frombits(0x45001800), out: 0x6801}, // in f32=2049.500000, out f16=2050
|
||||
{in: math.Float32frombits(0x45001829), out: 0x6801}, // in f32=2049.510010, out f16=2050
|
||||
{in: math.Float32frombits(0x45002000), out: 0x6801}, // in f32=2050.000000, out f16=2050
|
||||
{in: math.Float32frombits(0x45003000), out: 0x6802}, // in f32=2051.000000, out f16=2052
|
||||
{in: math.Float32frombits(0x457fd000), out: 0x6bfe}, // in f32=4093.000000, out f16=4092
|
||||
{in: math.Float32frombits(0x457fe000), out: 0x6bff}, // in f32=4094.000000, out f16=4094
|
||||
{in: math.Float32frombits(0x457ff000), out: 0x6c00}, // in f32=4095.000000, out f16=4096
|
||||
{in: math.Float32frombits(0x45800000), out: 0x6c00}, // in f32=4096.000000, out f16=4096
|
||||
{in: math.Float32frombits(0x45800800), out: 0x6c00}, // in f32=4097.000000, out f16=4096
|
||||
{in: math.Float32frombits(0x45801000), out: 0x6c00}, // in f32=4098.000000, out f16=4096
|
||||
{in: math.Float32frombits(0x45801800), out: 0x6c01}, // in f32=4099.000000, out f16=4100
|
||||
{in: math.Float32frombits(0x45802000), out: 0x6c01}, // in f32=4100.000000, out f16=4100
|
||||
{in: math.Float32frombits(0x45ad9c00), out: 0x6d6d}, // in f32=5555.500000, out f16=5556
|
||||
{in: math.Float32frombits(0x45ffe800), out: 0x6fff}, // in f32=8189.000000, out f16=8188
|
||||
{in: math.Float32frombits(0x45fff000), out: 0x7000}, // in f32=8190.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x45fff800), out: 0x7000}, // in f32=8191.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x46000000), out: 0x7000}, // in f32=8192.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x46000400), out: 0x7000}, // in f32=8193.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x46000800), out: 0x7000}, // in f32=8194.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x46000c00), out: 0x7000}, // in f32=8195.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x46001000), out: 0x7000}, // in f32=8196.000000, out f16=8192
|
||||
{in: math.Float32frombits(0x46001400), out: 0x7001}, // in f32=8197.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46001800), out: 0x7001}, // in f32=8198.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46001c00), out: 0x7001}, // in f32=8199.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46002000), out: 0x7001}, // in f32=8200.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46002400), out: 0x7001}, // in f32=8201.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46002800), out: 0x7001}, // in f32=8202.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46002c00), out: 0x7001}, // in f32=8203.000000, out f16=8200
|
||||
{in: math.Float32frombits(0x46003000), out: 0x7002}, // in f32=8204.000000, out f16=8208
|
||||
{in: math.Float32frombits(0x467fec00), out: 0x73ff}, // in f32=16379.000000, out f16=16376
|
||||
{in: math.Float32frombits(0x467ff000), out: 0x7400}, // in f32=16380.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x467ff400), out: 0x7400}, // in f32=16381.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x467ff800), out: 0x7400}, // in f32=16382.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x467ffc00), out: 0x7400}, // in f32=16383.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800000), out: 0x7400}, // in f32=16384.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800200), out: 0x7400}, // in f32=16385.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800400), out: 0x7400}, // in f32=16386.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800600), out: 0x7400}, // in f32=16387.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800800), out: 0x7400}, // in f32=16388.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800a00), out: 0x7400}, // in f32=16389.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800c00), out: 0x7400}, // in f32=16390.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46800e00), out: 0x7400}, // in f32=16391.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46801000), out: 0x7400}, // in f32=16392.000000, out f16=16384
|
||||
{in: math.Float32frombits(0x46801200), out: 0x7401}, // in f32=16393.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46801400), out: 0x7401}, // in f32=16394.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46801600), out: 0x7401}, // in f32=16395.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46801800), out: 0x7401}, // in f32=16396.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46801a00), out: 0x7401}, // in f32=16397.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46801c00), out: 0x7401}, // in f32=16398.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46801e00), out: 0x7401}, // in f32=16399.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802000), out: 0x7401}, // in f32=16400.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802200), out: 0x7401}, // in f32=16401.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802400), out: 0x7401}, // in f32=16402.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802600), out: 0x7401}, // in f32=16403.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802800), out: 0x7401}, // in f32=16404.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802a00), out: 0x7401}, // in f32=16405.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802c00), out: 0x7401}, // in f32=16406.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46802e00), out: 0x7401}, // in f32=16407.000000, out f16=16400
|
||||
{in: math.Float32frombits(0x46803000), out: 0x7402}, // in f32=16408.000000, out f16=16416
|
||||
{in: math.Float32frombits(0x46ffee00), out: 0x77ff}, // in f32=32759.000000, out f16=32752
|
||||
{in: math.Float32frombits(0x46fff000), out: 0x7800}, // in f32=32760.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fff200), out: 0x7800}, // in f32=32761.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fff400), out: 0x7800}, // in f32=32762.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fff600), out: 0x7800}, // in f32=32763.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fff800), out: 0x7800}, // in f32=32764.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fffa00), out: 0x7800}, // in f32=32765.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fffc00), out: 0x7800}, // in f32=32766.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x46fffe00), out: 0x7800}, // in f32=32767.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000000), out: 0x7800}, // in f32=32768.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000100), out: 0x7800}, // in f32=32769.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000200), out: 0x7800}, // in f32=32770.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000300), out: 0x7800}, // in f32=32771.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000400), out: 0x7800}, // in f32=32772.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000500), out: 0x7800}, // in f32=32773.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000600), out: 0x7800}, // in f32=32774.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000700), out: 0x7800}, // in f32=32775.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000800), out: 0x7800}, // in f32=32776.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000900), out: 0x7800}, // in f32=32777.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000a00), out: 0x7800}, // in f32=32778.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000b00), out: 0x7800}, // in f32=32779.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000c00), out: 0x7800}, // in f32=32780.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000d00), out: 0x7800}, // in f32=32781.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000e00), out: 0x7800}, // in f32=32782.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47000f00), out: 0x7800}, // in f32=32783.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47001000), out: 0x7800}, // in f32=32784.000000, out f16=32768
|
||||
{in: math.Float32frombits(0x47001100), out: 0x7801}, // in f32=32785.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001200), out: 0x7801}, // in f32=32786.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001300), out: 0x7801}, // in f32=32787.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001400), out: 0x7801}, // in f32=32788.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001500), out: 0x7801}, // in f32=32789.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001600), out: 0x7801}, // in f32=32790.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001700), out: 0x7801}, // in f32=32791.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001800), out: 0x7801}, // in f32=32792.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001900), out: 0x7801}, // in f32=32793.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001a00), out: 0x7801}, // in f32=32794.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001b00), out: 0x7801}, // in f32=32795.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001c00), out: 0x7801}, // in f32=32796.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001d00), out: 0x7801}, // in f32=32797.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001e00), out: 0x7801}, // in f32=32798.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47001f00), out: 0x7801}, // in f32=32799.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002000), out: 0x7801}, // in f32=32800.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002100), out: 0x7801}, // in f32=32801.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002200), out: 0x7801}, // in f32=32802.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002300), out: 0x7801}, // in f32=32803.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002400), out: 0x7801}, // in f32=32804.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002500), out: 0x7801}, // in f32=32805.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002600), out: 0x7801}, // in f32=32806.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002700), out: 0x7801}, // in f32=32807.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002800), out: 0x7801}, // in f32=32808.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002900), out: 0x7801}, // in f32=32809.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002a00), out: 0x7801}, // in f32=32810.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002b00), out: 0x7801}, // in f32=32811.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002c00), out: 0x7801}, // in f32=32812.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002d00), out: 0x7801}, // in f32=32813.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002e00), out: 0x7801}, // in f32=32814.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47002f00), out: 0x7801}, // in f32=32815.000000, out f16=32800
|
||||
{in: math.Float32frombits(0x47003000), out: 0x7802}, // in f32=32816.000000, out f16=32832
|
||||
{in: math.Float32frombits(0x477fe500), out: 0x7bff}, // in f32=65509.000000, out f16=65504
|
||||
{in: math.Float32frombits(0x477fe100), out: 0x7bff}, // in f32=65505.000000, out f16=65504
|
||||
{in: math.Float32frombits(0x477fee00), out: 0x7bff}, // in f32=65518.000000, out f16=65504
|
||||
{in: math.Float32frombits(0x477fef00), out: 0x7bff}, // in f32=65519.000000, out f16=65504
|
||||
{in: math.Float32frombits(0x477feffd), out: 0x7bff}, // in f32=65519.988281, out f16=65504
|
||||
{in: math.Float32frombits(0x477ff000), out: 0x7c00}, // in f32=65520.000000, out f16=+Inf
|
||||
}
|
||||
|
||||
// TestPrecisionFromfloat32 verifies PrecisionFromfloat32 against the shared
// conversion table and against hand-picked boundary bit patterns covering
// exact, unknown, inexact, underflow, and overflow classifications.
func TestPrecisionFromfloat32(t *testing.T) {
	// First re-check the table conversions, then classify each entry.
	for i, v := range wantF32toF16bits {
		f16 := float16.Fromfloat32(v.in)
		u16 := uint16(f16)

		if u16 != v.out {
			t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
		}

		checkPrecision(t, v.in, f16, uint64(i))
	}

	f32 := float32(5.5) // value that doesn't drop any bits in the significand, is within normal exponent range
	pre := float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionExact {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionExact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionExact, pre)
	}

	f32 = math.Float32frombits(0x38000000) // subnormal value with coef = 0 that can round-trip float32->float16->float32
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionUnknown {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
	}

	f32 = math.Float32frombits(0x387fc000) // subnormal value with coef !=0 that can round-trip float32->float16->float32
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionUnknown {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
	}

	f32 = math.Float32frombits(0x33c00000) // subnormal value with no dropped bits that cannot round-trip float32->float16->float32
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionUnknown {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionUnknown (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnknown, pre)
	}

	f32 = math.Float32frombits(0x38000001) // subnormal value with dropped non-zero bits > 0
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionInexact {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
	}

	f32 = float32(math.Pi) // value that cannot "preserve value" because it drops bits in the significand
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionInexact {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionInexact (%d), got=%d.", math.Float32bits(f32), float16.PrecisionInexact, pre)
	}

	f32 = math.Float32frombits(0x1) // value that will underflow
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionUnderflow {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
	}

	f32 = math.Float32frombits(0x33000000) // value that will underflow
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionUnderflow {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionUnderflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionUnderflow, pre)
	}

	f32 = math.Float32frombits(0x47800000) // value that will overflow
	pre = float16.PrecisionFromfloat32(f32)
	if pre != float16.PrecisionOverflow {
		t.Errorf("f32bits=0x%08x, wanted=PrecisionOverflow (%d), got=%d.", math.Float32bits(f32), float16.PrecisionOverflow, pre)
	}
}
|
||||
|
||||
func TestFromNaN32ps(t *testing.T) {
|
||||
for i, v := range wantF32toF16bits {
|
||||
f16 := float16.Fromfloat32(v.in)
|
||||
u16 := uint16(f16)
|
||||
|
||||
if u16 != v.out {
|
||||
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
|
||||
}
|
||||
|
||||
checkFromNaN32ps(t, v.in, f16)
|
||||
}
|
||||
|
||||
// since checkFromNaN32ps rejects non-NaN input, try one here
|
||||
nan, err := float16.FromNaN32ps(float32(math.Pi))
|
||||
if err != float16.ErrInvalidNaNValue {
|
||||
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted err float16.ErrInvalidNaNValue, got err = %q", err)
|
||||
}
|
||||
if err.Error() != "float16: invalid NaN value, expected IEEE 754 NaN" {
|
||||
t.Errorf("unexpected string value returned by err.Error() for ErrInvalidNaNValue: %s", err.Error())
|
||||
}
|
||||
if uint16(nan) != 0x7c01 { // signaling NaN
|
||||
t.Errorf("FromNaN32ps: in float32(math.Pi) wanted nan = 0x7c01, got nan = 0x%04x", uint16(nan))
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Test a small subset of possible conversions from float32 to Float16.
|
||||
// TestSomeFromFloat32 runs in under 1 second while TestAllFromFloat32 takes about 45 seconds.
|
||||
func TestSomeFromFloat32(t *testing.T) {
|
||||
|
||||
for i, v := range wantF32toF16bits {
|
||||
f16 := float16.Fromfloat32(v.in)
|
||||
u16 := uint16(f16)
|
||||
|
||||
if u16 != v.out {
|
||||
t.Errorf("i=%d, in f32bits=0x%08x, wanted=0x%04x, got=0x%04x.", i, math.Float32bits(v.in), v.out, u16)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test all possible 4294967296 float32 input values and results for
// Fromfloat32(), FromNaN32ps(), and PrecisionFromfloat32().
func TestAllFromFloat32(t *testing.T) {

	if testing.Short() {
		t.Skip("skipping TestAllFromFloat32 in short mode.")
	}

	fmt.Printf("WARNING: TestAllFromFloat32 should take about 1-2 minutes to run on amd64, other platforms may take longer...\n")

	// Rather than storing 2^32 expected outputs, hash every converted result
	// and compare one digest against a known-good value.
	// Blake2b is "3f310bc5608a087462d361644fe66feeb4c68145f6f18eb6f1439cd7914888b6df9e30ae5350dce0635162cc6a2f23b31b3e4353ca132a3c552bdbd58baa54e6"
	const wantSHA512 = "08670429a475164d6c4a080969e35231c77ef7069b430b5f38af22e013796b7818bbe8f5942a6ddf26de0e1dfc67d02243f483d85729ebc3762fc2948a5ca1f8"

	const batchSize uint32 = 16384
	results := make([]uint16, batchSize)
	buf := new(bytes.Buffer)
	h := sha512.New()

	// i advances in exact multiples of batchSize, so i < 0xFFFFFFFF together
	// with the inner offset loop covers every uint32 bit pattern once.
	for i := uint64(0); i < uint64(0xFFFFFFFF); i += uint64(batchSize) {
		// fill results
		for j := uint32(0); j < batchSize; j++ {
			inF32 := math.Float32frombits(uint32(i) + j)
			f16 := float16.Fromfloat32(inF32)
			results[j] = uint16(f16)
			// Cross-check precision classification and NaN handling for the
			// same input while it is in hand.
			checkPrecision(t, inF32, f16, i)
			checkFromNaN32ps(t, inF32, f16)
		}

		// convert results to []byte
		err := binary.Write(buf, binary.LittleEndian, results)
		if err != nil {
			panic(err)
		}

		// update hash with []byte of results
		_, err = h.Write(buf.Bytes())
		if err != nil {
			panic(err)
		}

		buf.Reset()
	}

	// display hash digest in hex
	digest := h.Sum(nil)
	gotSHA512hex := hex.EncodeToString(digest)
	if gotSHA512hex != wantSHA512 {
		t.Errorf("gotSHA512hex = %s", gotSHA512hex)
	}
}
|
||||
|
||||
// Test all 65536 conversions from float16 to float32.
// TestAllToFloat32 runs in under 1 second.
func TestAllToFloat32(t *testing.T) {
	// All 2^16 results are hashed in batches and compared as one digest.
	// Blake2b is "078d8e3fac9480de1493f22c8f9bfc1eb2051537c536f00f621557d70eed1af057a487c3e252f6d593769f5288d5ab66d8e9cd1adba359838802944bdb731f4d"
	const wantSHA512 = "1a4ccec9fd7b6e83310c6b4958a25778cd95f8d4f88b19950e4b8d6932a955f7fbd96b1c9bd9b2a79c3a9d34d653f55e671f8f86e6a5a876660cd38479001aa6"
	const batchSize uint32 = 16384
	results := make([]float32, batchSize)
	buf := new(bytes.Buffer)
	h := sha512.New()

	// i advances in multiples of batchSize; with the inner offsets this
	// covers every uint16 bit pattern 0x0000..0xFFFF exactly once.
	for i := uint64(0); i < uint64(0xFFFF); i += uint64(batchSize) {
		// fill results
		for j := uint32(0); j < batchSize; j++ {
			inU16 := uint16(i) + uint16(j)
			f16 := float16.Float16(inU16)
			results[j] = f16.Float32()
		}

		// convert results to []byte
		err := binary.Write(buf, binary.LittleEndian, results)
		if err != nil {
			panic(err)
		}

		// update hash with []byte of results
		_, err = h.Write(buf.Bytes())
		if err != nil {
			panic(err)
		}

		buf.Reset()
	}

	// display hash digest in hex
	digest := h.Sum(nil)
	gotSHA512hex := hex.EncodeToString(digest)
	if gotSHA512hex != wantSHA512 {
		t.Errorf("Float16toFloat32: gotSHA512hex = %s", gotSHA512hex)
	}

}
|
||||
|
||||
func TestFrombits(t *testing.T) {
|
||||
x := uint16(0x1234)
|
||||
f16 := float16.Frombits(x)
|
||||
if uint16(f16) != f16.Bits() || uint16(f16) != x {
|
||||
t.Errorf("float16.Frombits(0x7fff) returned %04x, wanted %04x", uint16(f16), x)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNaN(t *testing.T) {
|
||||
nan := float16.NaN()
|
||||
if !nan.IsNaN() {
|
||||
t.Errorf("nan.IsNaN() returned false, wanted true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInf(t *testing.T) {
|
||||
posInf := float16.Inf(0)
|
||||
if uint16(posInf) != 0x7c00 {
|
||||
t.Errorf("float16.Inf(0) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
|
||||
}
|
||||
|
||||
posInf = float16.Inf(1)
|
||||
if uint16(posInf) != 0x7c00 {
|
||||
t.Errorf("float16.Inf(1) returned %04x, wanted %04x", uint16(posInf), 0x7c00)
|
||||
}
|
||||
|
||||
negInf := float16.Inf(-1)
|
||||
if uint16(negInf) != 0xfc00 {
|
||||
t.Errorf("float16.Inf(-1) returned %04x, wanted %04x", uint16(negInf), 0xfc00)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBits(t *testing.T) {
|
||||
x := uint16(0x1234)
|
||||
f16 := float16.Frombits(x)
|
||||
if uint16(f16) != f16.Bits() || f16.Bits() != x {
|
||||
t.Errorf("Bits() returned %04x, wanted %04x", uint16(f16), x)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsFinite(t *testing.T) {
|
||||
// IsFinite returns true if f is neither infinite nor NaN.
|
||||
|
||||
finite := float16.Fromfloat32(float32(1.5))
|
||||
if !finite.IsFinite() {
|
||||
t.Errorf("finite.Infinite() returned false, wanted true")
|
||||
}
|
||||
|
||||
posInf := float16.Inf(0)
|
||||
if posInf.IsFinite() {
|
||||
t.Errorf("posInf.Infinite() returned true, wanted false")
|
||||
}
|
||||
|
||||
negInf := float16.Inf(-1)
|
||||
if negInf.IsFinite() {
|
||||
t.Errorf("negInf.Infinite() returned true, wanted false")
|
||||
}
|
||||
|
||||
nan := float16.NaN()
|
||||
if nan.IsFinite() {
|
||||
t.Errorf("nan.Infinite() returned true, wanted false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNaN(t *testing.T) {
|
||||
|
||||
f16 := float16.Float16(0)
|
||||
if f16.IsNaN() {
|
||||
t.Errorf("Float16(0).IsNaN() returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0x7e00)
|
||||
if !f16.IsNaN() {
|
||||
t.Errorf("Float16(0x7e00).IsNaN() returned false, wanted true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsQuietNaN(t *testing.T) {
|
||||
|
||||
f16 := float16.Float16(0)
|
||||
if f16.IsQuietNaN() {
|
||||
t.Errorf("Float16(0).IsQuietNaN() returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0x7e00)
|
||||
if !f16.IsQuietNaN() {
|
||||
t.Errorf("Float16(0x7e00).IsQuietNaN() returned false, wanted true")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0x7e00 ^ 0x0200)
|
||||
if f16.IsQuietNaN() {
|
||||
t.Errorf("Float16(0x7e00 ^ 0x0200).IsQuietNaN() returned true, wanted false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNormal(t *testing.T) {
|
||||
// IsNormal returns true if f is neither zero, infinite, subnormal, or NaN.
|
||||
|
||||
zero := float16.Frombits(0)
|
||||
if zero.IsNormal() {
|
||||
t.Errorf("zero.IsNormal() returned true, wanted false")
|
||||
}
|
||||
|
||||
posInf := float16.Inf(0)
|
||||
if posInf.IsNormal() {
|
||||
t.Errorf("posInf.IsNormal() returned true, wanted false")
|
||||
}
|
||||
|
||||
negInf := float16.Inf(-1)
|
||||
if negInf.IsNormal() {
|
||||
t.Errorf("negInf.IsNormal() returned true, wanted false")
|
||||
}
|
||||
|
||||
nan := float16.NaN()
|
||||
if nan.IsNormal() {
|
||||
t.Errorf("nan.IsNormal() returned true, wanted false")
|
||||
}
|
||||
|
||||
subnormal := float16.Frombits(0x0001)
|
||||
if subnormal.IsNormal() {
|
||||
t.Errorf("subnormal.IsNormal() returned true, wanted false")
|
||||
}
|
||||
|
||||
normal := float16.Fromfloat32(float32(1.5))
|
||||
if !normal.IsNormal() {
|
||||
t.Errorf("normal.IsNormal() returned false, wanted true")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestSignbit(t *testing.T) {
|
||||
|
||||
f16 := float16.Fromfloat32(float32(0.0))
|
||||
if f16.Signbit() {
|
||||
t.Errorf("float16.Fromfloat32(float32(0)).Signbit() returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Fromfloat32(float32(2.0))
|
||||
if f16.Signbit() {
|
||||
t.Errorf("float16.Fromfloat32(float32(2)).Signbit() returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Fromfloat32(float32(-2.0))
|
||||
if !f16.Signbit() {
|
||||
t.Errorf("float16.Fromfloat32(float32(-2)).Signbit() returned false, wanted true")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestString(t *testing.T) {
|
||||
f16 := float16.Fromfloat32(1.5)
|
||||
s := f16.String()
|
||||
if s != "1.5" {
|
||||
t.Errorf("Float16(1.5).String() returned %s, wanted 1.5", s)
|
||||
}
|
||||
|
||||
f16 = float16.Fromfloat32(3.141593)
|
||||
s = f16.String()
|
||||
if s != "3.140625" {
|
||||
t.Errorf("Float16(3.141593).String() returned %s, wanted 3.140625", s)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestIsInf(t *testing.T) {
|
||||
|
||||
f16 := float16.Float16(0)
|
||||
if f16.IsInf(0) {
|
||||
t.Errorf("Float16(0).IsInf(0) returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0x7c00)
|
||||
if !f16.IsInf(0) {
|
||||
t.Errorf("Float16(0x7c00).IsInf(0) returned false, wanted true")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0x7c00)
|
||||
if !f16.IsInf(1) {
|
||||
t.Errorf("Float16(0x7c00).IsInf(1) returned false, wanted true")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0x7c00)
|
||||
if f16.IsInf(-1) {
|
||||
t.Errorf("Float16(0x7c00).IsInf(-1) returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0xfc00)
|
||||
if !f16.IsInf(0) {
|
||||
t.Errorf("Float16(0xfc00).IsInf(0) returned false, wanted true")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0xfc00)
|
||||
if f16.IsInf(1) {
|
||||
t.Errorf("Float16(0xfc00).IsInf(1) returned true, wanted false")
|
||||
}
|
||||
|
||||
f16 = float16.Float16(0xfc00)
|
||||
if !f16.IsInf(-1) {
|
||||
t.Errorf("Float16(0xfc00).IsInf(-1) returned false, wanted true")
|
||||
}
|
||||
}
|
||||
|
||||
// float32parts splits a float32 into its unbiased exponent, its full 23-bit
// coefficient, and the low 13 coefficient bits ("dropped") that do not fit
// in a float16's 10-bit coefficient.
func float32parts(f32 float32) (exp int32, coef uint32, dropped uint32) {
	const (
		coefMask uint32 = 0x7fffff // 23 least significant bits
		expShift uint32 = 23
		expBias  uint32 = 127
		expMask  uint32 = uint32(0xff) << expShift
		dropMask uint32 = coefMask >> 10
	)
	bits := math.Float32bits(f32)
	exp = int32(((bits & expMask) >> expShift) - expBias)
	coef = bits & coefMask
	dropped = coef & dropMask
	return exp, coef, dropped
}
|
||||
|
||||
func isNaN32(f32 float32) bool {
|
||||
exp, coef, _ := float32parts(f32)
|
||||
return (exp == 128) && (coef != 0)
|
||||
}
|
||||
|
||||
func isQuietNaN32(f32 float32) bool {
|
||||
exp, coef, _ := float32parts(f32)
|
||||
return (exp == 128) && (coef != 0) && ((coef & 0x00400000) != 0)
|
||||
}
|
||||
|
||||
// checkFromNaN32ps verifies FromNaN32ps(f32) against the plain conversion
// result f16 whenever f32 is a NaN; non-NaN inputs are ignored. For a quiet
// NaN the two results must be identical; for a signaling NaN they must
// differ only in the quiet bit (plus the lowest payload bit when the payload
// would otherwise be empty).
func checkFromNaN32ps(t *testing.T, f32 float32, f16 float16.Float16) {

	if !isNaN32(f32) {
		return
	}

	u32 := math.Float32bits(f32)
	nan16, err := float16.FromNaN32ps(f32)

	if isQuietNaN32(f32) {
		// result should be the same
		if err != nil {
			t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
		}
		if uint16(nan16) != uint16(f16) {
			t.Errorf("FromNaN32ps: qnan = 0x%08x (%f) wanted nan16 = %v, got nan16 = %v", u32, f32, f16, nan16)
		}
	} else {
		// result should differ only by the signaling/quiet bit unless payload is empty
		if err != nil {
			t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted err = nil, got err = %q", u32, f32, err)
		}

		coef := uint16(f16) & uint16(0x03ff)    // full 10-bit f16 coefficient
		payload := uint16(f16) & uint16(0x01ff) // coefficient without the quiet bit
		diff := uint16(nan16 ^ f16)             // bits that FromNaN32ps changed

		if payload == 0 {
			// the lowest bit needed to be set to prevent turning sNaN into infinity, so 2 bits differ
			if diff != 0x0201 {
				t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0201, got 0x%04x", u32, f32, diff)
			}
		} else {
			// only the quiet bit was restored, so 1 bit differs
			if diff != 0x0200 {
				t.Errorf("FromNaN32ps: snan = 0x%08x (%f) wanted diff == 0x0200, got 0x%04x. f16=0x%04x n16=0x%04x coef=0x%04x", u32, f32, diff, uint16(f16), uint16(nan16), coef)
			}
		}
	}
}
|
||||
|
||||
// checkPrecision cross-checks PrecisionFromfloat32(f32) against the actual
// outcome of converting f32 to f16 and back to float32. i is the table/loop
// index and appears only in failure messages.
func checkPrecision(t *testing.T, f32 float32, f16 float16.Float16, i uint64) {
	// TODO: rewrite this test when time allows

	u32 := math.Float32bits(f32)
	u16 := f16.Bits()
	f32bis := f16.Float32()
	u32bis := math.Float32bits(f32bis)
	pre := float16.PrecisionFromfloat32(f32)
	roundtripped := u32 == u32bis
	exp32, coef32, dropped32 := float32parts(f32)

	// A lossless round trip has its own, stricter set of expectations.
	if roundtripped {
		checkRoundTrippedPrecision(t, u32, u16, u32bis, exp32, coef32, dropped32)
		return
	}

	if pre == float16.PrecisionExact {
		// this should only happen if both input and output are NaN
		if !(f16.IsNaN() && isNaN32(f32)) {
			t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionExact when roundtrip failed with non-special value", i, u32, f32, u16, u32bis, f32bis)
		}

	} else if pre == float16.PrecisionUnknown {
		// Unknown must not be reported where a more specific class applies.
		if exp32 < -24 {
			t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionUnderflow", i, u32, f32, u16, u32bis, f32bis)
		}
		if dropped32 != 0 {
			t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnknown, wanted PrecisionInexact", i, u32, f32, u16, u32bis, f32bis)
		}
	} else if pre == float16.PrecisionInexact {
		checkPrecisionInexact(t, u32, u16, u32bis, exp32, coef32, dropped32)
	} else if pre == float16.PrecisionUnderflow {
		// Underflow requires an exponent below float16's normal range.
		if exp32 >= -14 {
			t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionUnderflow when exp32 is >= -14", i, u32, f32, u16, u32bis, f32bis)
		}
	} else if pre == float16.PrecisionOverflow {
		// Overflow requires an exponent above float16's normal range.
		if exp32 <= 15 {
			t.Errorf("i=%d, PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionOverflow when exp32 is <= 15", i, u32, f32, u16, u32bis, f32bis)
		}
	}
}
|
||||
|
||||
// checkPrecisionInexact asserts the invariants that must hold whenever
// PrecisionFromfloat32 reported PrecisionInexact: the exponent must be in
// float16's representable range and bits must actually have been dropped.
func checkPrecisionInexact(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
	f32 := math.Float32frombits(u32)
	f32bis := math.Float32frombits(u32bis)

	if exp32 < -24 {
		t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionUnderflow", u32, f32, u16, u32bis, f32bis)
	}
	if exp32 > 15 {
		t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact, wanted PrecisionOverflow", u32, f32, u16, u32bis, f32bis)
	}
	if coef32 == 0 {
		t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when coef32 is 0", u32, f32, u16, u32bis, f32bis)
	}
	if dropped32 == 0 {
		t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), got PrecisionInexact when dropped32 is 0", u32, f32, u16, u32bis, f32bis)
	}
}
|
||||
|
||||
// checkRoundTrippedPrecision asserts the invariants for a value that
// survived float32->float16->float32 unchanged: no coefficient bits may have
// been dropped, and the precision class must be Exact (or Unknown, which is
// tolerated for the subnormal values that happen to round-trip).
func checkRoundTrippedPrecision(t *testing.T, u32 uint32, u16 uint16, u32bis uint32, exp32 int32, coef32 uint32, dropped32 uint32) {
	f32 := math.Float32frombits(u32)
	f32bis := math.Float32frombits(u32bis)
	pre := float16.PrecisionFromfloat32(f32)
	f16 := float16.Frombits(u16)

	if dropped32 != 0 {
		t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%f), out f16bits=0x%04x, back=0x%08x (%f), dropped32 != 0 with successful roundtrip", u32, f32, u16, u32bis, f32bis)
	}

	if pre != float16.PrecisionExact {
		// there are 2046 values that are subnormal and can round-trip float32->float16->float32
		if pre != float16.PrecisionUnknown {
			t.Errorf("PrecisionFromfloat32 in f32bits=0x%08x (%032b) (%f), out f16bits=0x%04x (%v), back=0x%08x (%f), got %v, wanted PrecisionExact, exp=%d, coef=%d, drpd=%d", u32, u32, f32, u16, f16, u32bis, f32bis, pre, exp32, coef32, dropped32)
		}
	}

}
|
12
init.go
12
init.go
|
@ -4,11 +4,14 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
var (
|
||||
CachedDir string = "NOT_SETTING"
|
||||
gotchEnvKey string = "GOTCH_CACHE"
|
||||
CachedDir string = "NOT_SETTING"
|
||||
gotchEnvKey string = "GOTCH_CACHE"
|
||||
gotchDebugKey string = "GOTCH_DEBUG"
|
||||
Debug bool = false
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -16,10 +19,13 @@ func init() {
|
|||
CachedDir = fmt.Sprintf("%s/.cache/gotch", homeDir) // default dir: "{$HOME}/.cache/gotch"
|
||||
|
||||
initEnv()
|
||||
// log.Printf("INFO: CacheDir=%q\n", CacheDir)
|
||||
}
|
||||
|
||||
func initEnv() {
|
||||
if v, err := strconv.ParseBool(os.Getenv(gotchDebugKey)); err == nil {
|
||||
Debug = v
|
||||
}
|
||||
|
||||
val := os.Getenv(gotchEnvKey)
|
||||
if val != "" {
|
||||
CachedDir = val
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -4,10 +4,25 @@ package libtch
|
|||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
/*
|
||||
bool is_null(int* pointer) {
|
||||
if (NULL == pointer) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import "unsafe"
|
||||
|
||||
// IsNull reports whether ctensor is a NULL pointer.
// The handle is reinterpreted as *C.int only so it can be passed to the C
// helper is_null(), which compares the pointer itself against NULL; the
// pointee is never dereferenced.
func IsNull(ctensor Ctensor) bool {
	// return C.is_null(ctensor)
	ret := C.is_null((*C.int)(unsafe.Pointer(ctensor)))

	return (bool)(ret)
}
|
||||
|
||||
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
|
||||
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||
|
|
|
@ -59,12 +59,6 @@ func NewTensor() Ctensor {
|
|||
return C.at_new_tensor()
|
||||
}
|
||||
|
||||
// int at_device(tensor);
|
||||
func AtDevice(ts Ctensor) int {
|
||||
cint := C.at_device(ts)
|
||||
return *(*int)(unsafe.Pointer(&cint))
|
||||
}
|
||||
|
||||
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
|
||||
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
|
||||
|
||||
|
@ -88,6 +82,30 @@ func AtDataPtr(t Ctensor) unsafe.Pointer {
|
|||
return C.at_data_ptr(t)
|
||||
}
|
||||
|
||||
// int at_defined(tensor);
|
||||
func AtDefined(ts Ctensor) bool {
|
||||
retVal := C.at_defined(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_mkldnn(tensor);
|
||||
func AtIsMkldnn(ts Ctensor) bool {
|
||||
retVal := C.at_is_mkldnn(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_sparse(tensor);
|
||||
func AtIsSparse(ts Ctensor) bool {
|
||||
retVal := C.at_is_sparse(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_device(tensor);
|
||||
func AtDevice(ts Ctensor) int {
|
||||
cint := C.at_device(ts)
|
||||
return *(*int)(unsafe.Pointer(&cint))
|
||||
}
|
||||
|
||||
// size_t at_dim(tensor);
|
||||
func AtDim(t Ctensor) uint64 {
|
||||
result := C.at_dim(t)
|
||||
|
@ -100,12 +118,24 @@ func AtShape(t Ctensor, ptr unsafe.Pointer) {
|
|||
C.at_shape(t, c_ptr)
|
||||
}
|
||||
|
||||
// void at_stride(tensor, int64_t *);
|
||||
func AtStride(t Ctensor, ptr unsafe.Pointer) {
|
||||
c_ptr := (*C.int64_t)(ptr)
|
||||
C.at_stride(t, c_ptr)
|
||||
}
|
||||
|
||||
// int at_scalar_type(tensor);
|
||||
func AtScalarType(t Ctensor) int32 {
|
||||
result := C.at_scalar_type(t)
|
||||
return *(*int32)(unsafe.Pointer(&result))
|
||||
}
|
||||
|
||||
// int at_is_contiguous(tensor);
|
||||
func AtIsContiguous(ts Ctensor) bool {
|
||||
retVal := C.at_is_contiguous(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
func GetAndResetLastErr() *C.char {
|
||||
return C.get_and_reset_last_err()
|
||||
}
|
||||
|
@ -134,6 +164,24 @@ func AtcSetBenchmarkCudnn(b int) {
|
|||
C.atc_set_benchmark_cudnn(cb)
|
||||
}
|
||||
|
||||
// void atc_synchronize(int64_t device_index);
|
||||
func AtcSynchronize(deviceIndex int64) {
|
||||
cDeviceIndex := *(*C.int64_t)(unsafe.Pointer(&deviceIndex))
|
||||
C.atc_synchronize(cDeviceIndex)
|
||||
}
|
||||
|
||||
// int atc_get_device();
|
||||
func AtcGetDevice() int {
|
||||
cDeviceIndex := C.atc_get_device()
|
||||
return int(cDeviceIndex)
|
||||
}
|
||||
|
||||
// int atc_set_device(int device_index);
|
||||
func AtcSetDevice(deviceIndex int) int {
|
||||
cDeviceIndex := C.int(deviceIndex)
|
||||
return int(cDeviceIndex)
|
||||
}
|
||||
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
|
||||
ctensor := (C.tensor)(ts)
|
||||
|
@ -158,18 +206,6 @@ func AtRequiresGrad(ts Ctensor) bool {
|
|||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_defined(tensor);
|
||||
func AtDefined(ts Ctensor) bool {
|
||||
retVal := C.at_defined(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_is_sparse(tensor);
|
||||
func AtIsSparse(ts Ctensor) bool {
|
||||
retVal := C.at_is_sparse(ts)
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// void at_backward(tensor, int, int);
|
||||
func AtBackward(ts Ctensor, keepGraph int, createGraph int) {
|
||||
ckeepGraph := *(*C.int)(unsafe.Pointer(&keepGraph))
|
||||
|
@ -364,7 +400,7 @@ func AtFree(ts Ctensor) {
|
|||
C.at_free(ts)
|
||||
}
|
||||
|
||||
//int at_grad_set_enabled(int b);
|
||||
// int at_grad_set_enabled(int b);
|
||||
func AtGradSetEnabled(b int) int {
|
||||
cbool := *(*C.int)(unsafe.Pointer(&b))
|
||||
cretVal := C.at_grad_set_enabled(cbool)
|
||||
|
@ -870,3 +906,13 @@ func AtoConstantPadNd(ptr *Ctensor, self Ctensor, padData []int64, padLen int, v
|
|||
cpadLen := *(*C.int)(unsafe.Pointer(&padLen))
|
||||
C.ato_constant_pad_nd(ptr, self, cpadDataPtr, cpadLen, value)
|
||||
}
|
||||
|
||||
// // NOTE. TT. added to test new API generated
|
||||
// func AtgRandn1(sizeData []int64, sizeLen int, optionsKind int32, optionsDevice int32) Ctensor {
|
||||
// csizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
||||
// csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
||||
// coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||
// coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
//
|
||||
// return C.atg_randn1(csizeDataPtr, csizeLen, coptionsKind, coptionsDevice)
|
||||
// }
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
#include<torch/csrc/autograd/engine.h>
|
||||
#include<torch/csrc/jit/runtime/graph_executor.h>
|
||||
#include "torch_api.h"
|
||||
#include "ATen/core/interned_strings.h"
|
||||
#include <ATen/autocast_mode.h>
|
||||
#include <stdexcept>
|
||||
#include <torch/csrc/autograd/engine.h>
|
||||
#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
|
||||
#include <torch/csrc/jit/passes/normalize_ops.h>
|
||||
#include<torch/torch.h>
|
||||
#include<ATen/autocast_mode.h>
|
||||
#include<torch/script.h>
|
||||
#include<stdexcept>
|
||||
#include<vector>
|
||||
#include "torch_api.h"
|
||||
|
||||
#include <torch/csrc/jit/runtime/graph_executor.h>
|
||||
#include <torch/script.h>
|
||||
#include <torch/torch.h>
|
||||
#include <vector>
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
@ -36,15 +36,20 @@ vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len) {
|
|||
return result;
|
||||
}
|
||||
|
||||
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs, int len) {
|
||||
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
|
||||
int len) {
|
||||
vector<c10::optional<torch::Tensor>> result;
|
||||
for (int i = 0; i < len; ++i) {
|
||||
result.push_back(vs[i] != nullptr ? c10::optional<torch::Tensor>(*(vs[i])) : c10::nullopt);
|
||||
result.push_back(vs[i] != nullptr ? c10::optional<torch::Tensor>(*(vs[i]))
|
||||
: c10::nullopt);
|
||||
}
|
||||
return c10::List<c10::optional<torch::Tensor>>(result);
|
||||
}
|
||||
|
||||
at::Device device_of_int(int d) {
|
||||
if (d == -3)
|
||||
return at::Device(at::kVulkan);
|
||||
// if (d == -2) return at::Device(at::kMPS);
|
||||
if (d < 0)
|
||||
return at::Device(at::kCPU);
|
||||
return at::Device(at::kCUDA, /*index=*/d);
|
||||
|
@ -81,19 +86,20 @@ tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
|
|||
|
||||
void at_copy_data(tensor tensor, void *vs, size_t numel,
|
||||
size_t elt_size_in_bytes) {
|
||||
PROTECT(if (elt_size_in_bytes != tensor->element_size()) throw std::
|
||||
invalid_argument("incoherent element sizes in bytes");
|
||||
if (numel > tensor->numel()) throw std::invalid_argument(
|
||||
"target numel is larger than tensor numel");
|
||||
if (tensor->device().type() != at::kCPU) {
|
||||
torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
|
||||
} else {
|
||||
auto tmp_tensor = tensor->contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
|
||||
})
|
||||
PROTECT(
|
||||
if (elt_size_in_bytes != tensor->element_size()) throw std::
|
||||
invalid_argument("incoherent element sizes in bytes");
|
||||
if (numel > tensor->numel()) throw std::invalid_argument(
|
||||
"target numel is larger than tensor numel");
|
||||
if (tensor->device().type() != at::kCPU) {
|
||||
torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
|
||||
} else {
|
||||
auto tmp_tensor = tensor->contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
|
||||
})
|
||||
}
|
||||
|
||||
tensor at_shallow_clone(tensor t) {
|
||||
|
@ -139,15 +145,20 @@ int at_scalar_type(tensor t) {
|
|||
return -1;
|
||||
}
|
||||
|
||||
int at_is_contiguous(tensor t) {
|
||||
PROTECT(return t->is_contiguous();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
// void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor
|
||||
// inf_scale) { PROTECT( at::_amp_non_finite_check_and_unscale_(*t, *found_inf,
|
||||
// *inf_scale);
|
||||
// )
|
||||
// }
|
||||
void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor inf_scale) {
|
||||
PROTECT(
|
||||
at::_amp_foreach_non_finite_check_and_unscale_(*t, *found_inf, *inf_scale);
|
||||
)
|
||||
void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf,
|
||||
tensor inf_scale) {
|
||||
PROTECT(at::_amp_foreach_non_finite_check_and_unscale_(*t, *found_inf,
|
||||
*inf_scale);)
|
||||
}
|
||||
|
||||
void at_autocast_clear_cache() { at::autocast::clear_cache(); }
|
||||
|
@ -176,7 +187,7 @@ bool at_autocast_set_enabled(bool b) {
|
|||
int at_device(tensor t) {
|
||||
PROTECT(auto device = t->device(); if (device.type() == at::kCPU) return -1;
|
||||
if (device.type() == at::kCUDA) return device.index();)
|
||||
return -2;
|
||||
return -99; // error
|
||||
}
|
||||
|
||||
void at_backward(tensor t, int keep_graph, int create_graph) {
|
||||
|
@ -407,13 +418,13 @@ void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
|
|||
for (int i = 0; i < ntensors; ++i)
|
||||
grads.push_back(torch::ones_like(*tensors[i]));
|
||||
|
||||
auto vl = torch::autograd::Engine::get_default_engine().execute(roots, grads, keep_graph, create_graph, false, inputs_);
|
||||
auto vl = torch::autograd::Engine::get_default_engine().execute(
|
||||
roots, grads, keep_graph, create_graph, false, inputs_);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
outputs[i] = static_cast<tensor>(new torch::autograd::Variable(vl[i]));
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
optimizer ato_adam(double learning_rate, double beta1, double beta2,
|
||||
double weight_decay) {
|
||||
PROTECT(auto options = torch::optim::AdamOptions(learning_rate)
|
||||
|
@ -510,11 +521,11 @@ void ato_set_learning_rate_group(optimizer t, size_t group,
|
|||
set_lr_group<torch::optim::SGDOptions>(t, group, learning_rate);)
|
||||
}
|
||||
|
||||
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data, int pad_len, scalar value) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::constant_pad_nd(*self, torch::IntArrayRef(pad_data, pad_len), *value);
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data,
|
||||
int pad_len, scalar value) {
|
||||
PROTECT(auto outputs__ = torch::constant_pad_nd(
|
||||
*self, torch::IntArrayRef(pad_data, pad_len), *value);
|
||||
out__[0] = new torch::Tensor(outputs__);)
|
||||
}
|
||||
|
||||
// ============ set/get learning rates ==============================
|
||||
|
@ -536,15 +547,17 @@ template <class T> void set_lrs(optimizer t, double *learning_rates) {
|
|||
}
|
||||
|
||||
void ato_set_learning_rates(optimizer t, double *lrs, int lrs_num) {
|
||||
PROTECT(int ngroup = t->param_groups().size(); if (lrs == nullptr) {
|
||||
throw std::invalid_argument("Input learning rates should not be null");
|
||||
} if (ngroup != lrs_num) {
|
||||
throw std::invalid_argument("Size of input learning rates is unequal to "
|
||||
"number of parameter groups.");
|
||||
} set_lrs<torch::optim::AdamOptions>(t, lrs);
|
||||
set_lrs<torch::optim::AdamWOptions>(t, lrs);
|
||||
set_lrs<torch::optim::RMSpropOptions>(t, lrs);
|
||||
set_lrs<torch::optim::SGDOptions>(t, lrs);)
|
||||
PROTECT(
|
||||
int ngroup = t->param_groups().size(); if (lrs == nullptr) {
|
||||
throw std::invalid_argument("Input learning rates should not be null");
|
||||
} if (ngroup != lrs_num) {
|
||||
throw std::invalid_argument(
|
||||
"Size of input learning rates is unequal to "
|
||||
"number of parameter groups.");
|
||||
} set_lrs<torch::optim::AdamOptions>(t, lrs);
|
||||
set_lrs<torch::optim::AdamWOptions>(t, lrs);
|
||||
set_lrs<torch::optim::RMSpropOptions>(t, lrs);
|
||||
set_lrs<torch::optim::SGDOptions>(t, lrs);)
|
||||
}
|
||||
|
||||
template <class T> void get_lrs(optimizer t, vector<double> &lrs) {
|
||||
|
@ -753,6 +766,26 @@ void atc_set_benchmark_cudnn(int b) {
|
|||
at::globalContext().setBenchmarkCuDNN(b);
|
||||
}
|
||||
|
||||
void atc_synchronize(int64_t device_index) {
|
||||
PROTECT(return torch::cuda::synchronize(device_index);)
|
||||
}
|
||||
|
||||
// returns current CUDA device index.
|
||||
int atc_get_device() {
|
||||
PROTECT(at::Device d(at::kCUDA);
|
||||
auto *g = c10::impl::getDeviceGuardImpl(d.type()); d = g->getDevice();
|
||||
return d.index();)
|
||||
return -99; // error
|
||||
}
|
||||
|
||||
// set new cuda device with input device index.
|
||||
void atc_set_device(int device_index) {
|
||||
PROTECT(at::Device new_device(at::kCUDA);
|
||||
new_device = device_of_int(device_index);
|
||||
auto *g = c10::impl::getDeviceGuardImpl(new_device.type());
|
||||
g->setDevice(new_device);)
|
||||
}
|
||||
|
||||
module atm_load(char *filename) {
|
||||
PROTECT(return new torch::jit::script::Module(torch::jit::load(filename));)
|
||||
return nullptr;
|
||||
|
@ -1007,13 +1040,14 @@ void ati_to_generic_list(ivalue i, ivalue *outputs, int noutputs) {
|
|||
}
|
||||
|
||||
void ati_to_generic_dict(ivalue i, ivalue *outputs, int noutputs) {
|
||||
PROTECT(auto dict = i->toGenericDict(); if (dict.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected dict size");
|
||||
} int k = 0;
|
||||
for (auto it = dict.begin(); it != dict.end(); ++it) {
|
||||
outputs[k++] = new torch::jit::IValue(it->key());
|
||||
outputs[k++] = new torch::jit::IValue(it->value());
|
||||
})
|
||||
PROTECT(
|
||||
auto dict = i->toGenericDict(); if (dict.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected dict size");
|
||||
} int k = 0;
|
||||
for (auto it = dict.begin(); it != dict.end(); ++it) {
|
||||
outputs[k++] = new torch::jit::IValue(it->key());
|
||||
outputs[k++] = new torch::jit::IValue(it->value());
|
||||
})
|
||||
}
|
||||
|
||||
void ati_to_int_list(ivalue i, int64_t *outputs, int noutputs) {
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
#include <stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
#include <stdexcept>
|
||||
#include <torch/torch.h>
|
||||
using namespace std;
|
||||
thread_local char *torch_last_err = nullptr;
|
||||
|
||||
extern "C" {
|
||||
|
@ -46,6 +50,7 @@ size_t at_dim(tensor);
|
|||
void at_shape(tensor, int64_t *);
|
||||
void at_stride(tensor, int64_t *);
|
||||
int at_scalar_type(tensor);
|
||||
int at_is_contiguous(tensor);
|
||||
|
||||
void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);
|
||||
|
||||
|
@ -105,13 +110,8 @@ void at_set_num_threads(int n_threads);
|
|||
|
||||
void at_free(tensor);
|
||||
|
||||
void at_run_backward(tensor *tensors,
|
||||
int ntensors,
|
||||
tensor *inputs,
|
||||
int ninputs,
|
||||
tensor *outputs,
|
||||
int keep_graph,
|
||||
int create_graph);
|
||||
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
|
||||
tensor *outputs, int keep_graph, int create_graph);
|
||||
|
||||
optimizer ato_adam(double learning_rate, double beta1, double beta2,
|
||||
double weight_decay);
|
||||
|
@ -141,8 +141,10 @@ int64_t ato_param_group_num(optimizer);
|
|||
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
|
||||
void ato_add_param_group(optimizer, tensor *params, int param_num);
|
||||
|
||||
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no option of adding pad value.
|
||||
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len, scalar value);
|
||||
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no
|
||||
// option of adding pad value.
|
||||
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len,
|
||||
scalar value);
|
||||
|
||||
scalar ats_int(int64_t);
|
||||
scalar ats_float(double);
|
||||
|
@ -155,6 +157,12 @@ int atc_cuda_device_count();
|
|||
int atc_cuda_is_available();
|
||||
int atc_cudnn_is_available();
|
||||
void atc_set_benchmark_cudnn(int b);
|
||||
void atc_synchronize(int64_t device_index);
|
||||
|
||||
// TT. added for testing qt
|
||||
// ref. https://github.com/pytorch/pytorch/issues/14959
|
||||
int atc_get_device();
|
||||
void atc_set_device(int device_index);
|
||||
|
||||
module atm_load(char *);
|
||||
module atm_load_on_device(char *, int device);
|
||||
|
@ -212,6 +220,11 @@ void ati_free(ivalue);
|
|||
|
||||
#ifdef __cplusplus
|
||||
}; // extern "C"
|
||||
#endif
|
||||
|
||||
std::vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len);
|
||||
at::Device device_of_int(int d);
|
||||
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
|
||||
int len);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
146
mem-util.go
Normal file
146
mem-util.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
package gotch
|
||||
|
||||
// helper to debug memory blow-up
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
)
|
||||
|
||||
func PrintMemStats(messageOpt ...string) {
|
||||
message := "Memory Stats"
|
||||
if len(messageOpt) > 0 {
|
||||
message = fmt.Sprintf("%s: %s", message, messageOpt[0])
|
||||
}
|
||||
|
||||
var rtm runtime.MemStats
|
||||
runtime.ReadMemStats(&rtm)
|
||||
|
||||
tp := newTablePrinter()
|
||||
tp.title = message
|
||||
|
||||
tp.AddRecord("|", "Allocated heap objects", padRight(fmt.Sprintf("%v", rtm.Mallocs), 10), "|")
|
||||
tp.AddRecord("|", "Released heap objects", padRight(fmt.Sprintf("%v", rtm.Frees), 10), "|")
|
||||
tp.AddRecord("|", "Living heap objects", padRight(fmt.Sprintf("%v", rtm.HeapObjects), 10), "|")
|
||||
tp.AddRecord("|", "Memory in use by heap objects (bytes)", padRight(fmt.Sprintf("%v", rtm.HeapAlloc), 10), "|")
|
||||
tp.AddRecord("|", "Reserved memory (by Go runtime for heap, stack,...) (bytes)", padRight(fmt.Sprintf("%v", rtm.Sys), 10), "|")
|
||||
tp.AddRecord("|", "Total pause time by GC (nanoseconds)", padRight(fmt.Sprintf("%v", rtm.PauseTotalNs), 10), "|")
|
||||
tp.AddRecord("|", "Number of GC called", padRight(fmt.Sprintf("%v", rtm.NumGC), 10), "|")
|
||||
// tp.AddRecord("Last GC called", fmt.Sprintf("%v", time.UnixMilli(int64(rtm.LastGC/1_000_000))))
|
||||
|
||||
tp.Print()
|
||||
|
||||
}
|
||||
|
||||
type tablePrinter struct {
|
||||
w *tabwriter.Writer
|
||||
maxLength int
|
||||
title string
|
||||
}
|
||||
|
||||
type printItem struct {
|
||||
val string
|
||||
alignRight bool
|
||||
}
|
||||
|
||||
func item(val string, alignRightOpt ...bool) printItem {
|
||||
alignRight := false
|
||||
if len(alignRightOpt) > 0 {
|
||||
alignRight = alignRightOpt[0]
|
||||
}
|
||||
return printItem{
|
||||
val: val,
|
||||
alignRight: alignRight,
|
||||
}
|
||||
}
|
||||
|
||||
func newTablePrinter() *tablePrinter {
|
||||
w := tabwriter.NewWriter(
|
||||
os.Stdout, //output
|
||||
0, // min width
|
||||
1, // tabwidth
|
||||
2, // padding
|
||||
' ', // padding character
|
||||
0, // align left
|
||||
)
|
||||
|
||||
return &tablePrinter{
|
||||
w: w,
|
||||
maxLength: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) AddRecord(items ...string) {
|
||||
tp.printRecord(items...)
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) AlignRight() {
|
||||
tp.w.Init(
|
||||
os.Stdout, //output
|
||||
0, // min width
|
||||
1, // tabwidth
|
||||
2, // padding
|
||||
' ', // padding character
|
||||
tabwriter.AlignRight,
|
||||
) // flags
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) AlignLeft() {
|
||||
tp.w.Init(
|
||||
os.Stdout, //output
|
||||
0, // min width
|
||||
1, // tabwidth
|
||||
2, // padding
|
||||
' ', // padding character
|
||||
0, // align left
|
||||
) // flags
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) printRecord(rec ...string) {
|
||||
var val string
|
||||
for i, item := range rec {
|
||||
switch i {
|
||||
case 0:
|
||||
val = item
|
||||
case len(rec) - 1:
|
||||
val += fmt.Sprintf("\t%s\n", item)
|
||||
default:
|
||||
val += fmt.Sprintf("\t%s", item)
|
||||
}
|
||||
}
|
||||
|
||||
nbytes, err := tp.w.Write([]byte(val))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if nbytes > tp.maxLength {
|
||||
tp.maxLength = nbytes
|
||||
}
|
||||
}
|
||||
|
||||
func (tp *tablePrinter) Print() {
|
||||
printBorder(tp.maxLength)
|
||||
printLine(tp.maxLength, tp.title)
|
||||
printBorder(tp.maxLength)
|
||||
tp.w.Flush()
|
||||
printBorder(tp.maxLength)
|
||||
}
|
||||
|
||||
func padRight(val interface{}, rightEnd int) string {
|
||||
value := fmt.Sprintf("%v", val)
|
||||
pad := fmt.Sprintf("%s", strings.Repeat(" ", rightEnd-len(value)))
|
||||
return fmt.Sprintf("%s%s", pad, value)
|
||||
}
|
||||
|
||||
func printLine(lineLength int, value string) {
|
||||
fmt.Printf("| %s %s\n", value, padRight("|", lineLength-len(value)-1))
|
||||
}
|
||||
|
||||
func printBorder(length int) {
|
||||
line := fmt.Sprintf("%s", strings.Repeat("-", length))
|
||||
fmt.Printf("+%s+\n", line)
|
||||
}
|
145
nn/init.go
145
nn/init.go
|
@ -12,7 +12,7 @@ import (
|
|||
|
||||
type Init interface {
|
||||
// creates a new tensor with specified initiation
|
||||
InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
|
||||
InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)
|
||||
|
||||
// re-initializes (in-place) an existing tensor with the specified initiation
|
||||
Set(tensor *ts.Tensor)
|
||||
|
@ -25,18 +25,24 @@ type constInit struct {
|
|||
value float64
|
||||
}
|
||||
|
||||
var _ Init = new(constInit)
|
||||
|
||||
func NewConstInit(v float64) constInit {
|
||||
return constInit{v}
|
||||
}
|
||||
|
||||
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
||||
func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||
dtype := gotch.DefaultDType
|
||||
if len(dtypeOpt) > 0 {
|
||||
dtype = dtypeOpt[0]
|
||||
}
|
||||
|
||||
var err error
|
||||
kind := gotch.Float
|
||||
switch {
|
||||
case c.value == 0.0:
|
||||
retVal = ts.MustZeros(dims, kind, device)
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
case c.value == 1.0:
|
||||
retVal = ts.MustOnes(dims, kind, device)
|
||||
retVal = ts.MustOnes(dims, dtype, device)
|
||||
default:
|
||||
data := make([]float64, ts.FlattenDim(dims))
|
||||
for i := range data {
|
||||
|
@ -68,18 +74,29 @@ type randnInit struct {
|
|||
stdev float64
|
||||
}
|
||||
|
||||
var _ Init = new(randnInit)
|
||||
|
||||
func NewRandnInit(mean, stdev float64) randnInit {
|
||||
return randnInit{mean, stdev}
|
||||
}
|
||||
|
||||
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
||||
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
||||
if r.mean == 0 {
|
||||
return ts.MustRandn(dims, gotch.Float, device)
|
||||
func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||
dtype := gotch.DefaultDType
|
||||
if len(dtypeOpt) > 0 {
|
||||
dtype = dtypeOpt[0]
|
||||
}
|
||||
|
||||
initTs := ts.MustRandn(dims, gotch.Float, device)
|
||||
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
||||
ts.NoGrad(func() {
|
||||
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
||||
if r.mean == 0 {
|
||||
retVal = ts.MustRandn(dims, dtype, device)
|
||||
}
|
||||
|
||||
initTs := ts.MustRandn(dims, dtype, device)
|
||||
retVal = initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
||||
})
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (r randnInit) Set(tensor *ts.Tensor) {
|
||||
|
@ -88,9 +105,11 @@ func (r randnInit) Set(tensor *ts.Tensor) {
|
|||
log.Fatalf("randInit - Set method call error: %v\n", err)
|
||||
}
|
||||
|
||||
initTs := r.InitTensor(dims, tensor.MustDevice())
|
||||
tensor.Copy_(initTs)
|
||||
initTs.MustDrop()
|
||||
ts.NoGrad(func() {
|
||||
initTs := r.InitTensor(dims, tensor.MustDevice())
|
||||
tensor.Copy_(initTs)
|
||||
initTs.MustDrop()
|
||||
})
|
||||
}
|
||||
|
||||
// uniformInit :
|
||||
|
@ -101,18 +120,26 @@ type uniformInit struct {
|
|||
up float64
|
||||
}
|
||||
|
||||
var _ Init = new(uniformInit)
|
||||
|
||||
func NewUniformInit(lo, up float64) uniformInit {
|
||||
return uniformInit{lo, up}
|
||||
}
|
||||
|
||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
||||
var err error
|
||||
kind := gotch.Float
|
||||
retVal = ts.MustZeros(dims, kind, device)
|
||||
retVal.Uniform_(u.lo, u.up)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||
dtype := gotch.DefaultDType
|
||||
if len(dtypeOpt) > 0 {
|
||||
dtype = dtypeOpt[0]
|
||||
}
|
||||
|
||||
var err error
|
||||
ts.NoGrad(func() {
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(u.lo, u.up)
|
||||
if err != nil {
|
||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||
}
|
||||
})
|
||||
return retVal
|
||||
}
|
||||
|
||||
|
@ -174,6 +201,8 @@ type kaimingUniformInit struct {
|
|||
NonLinearity string
|
||||
}
|
||||
|
||||
var _ Init = new(kaimingUniformInit)
|
||||
|
||||
func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
|
||||
o := DefaultKaimingOptions()
|
||||
for _, opt := range opts {
|
||||
|
@ -187,26 +216,37 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
|
|||
}
|
||||
}
|
||||
|
||||
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
||||
fanIn, _, err := CalculateFans(dims)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||
dtype := gotch.DefaultDType
|
||||
if len(dtypeOpt) > 0 {
|
||||
dtype = dtypeOpt[0]
|
||||
}
|
||||
|
||||
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
|
||||
if err != nil {
|
||||
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
/*
|
||||
fanIn, _, err := CalculateFans(dims)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
|
||||
gain, err := calculateGain(k.NonLinearity, k.NegativeSlope) // default non-linearity="leaky_relu", negative_slope=0.01
|
||||
if err != nil {
|
||||
err = fmt.Errorf("kaimingUniformInit.InitTensor() failed: %v\n", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Calculate uniform bounds from standard deviation
|
||||
bound := math.Sqrt(3.0) * std
|
||||
std := gain / math.Sqrt(float64(fanIn)) // default using fanIn
|
||||
|
||||
kind := gotch.Float
|
||||
retVal = ts.MustZeros(dims, kind, device)
|
||||
retVal.Uniform_(-bound, bound)
|
||||
// Calculate uniform bounds from standard deviation
|
||||
bound := math.Sqrt(3.0) * std
|
||||
|
||||
// NOTE. This is a well-known memory leak!!!
|
||||
// Avoid to use it for now!!!
|
||||
retVal = ts.MustZeros(dims, dtype, device)
|
||||
retVal.Uniform_(-bound, bound)
|
||||
*/
|
||||
|
||||
// NOTE. For now, just make a random norm
|
||||
retVal = ts.MustRandn(dims, dtype, device)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
@ -347,3 +387,36 @@ func contains(items []string, item string) bool {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// XavierUniform fills the input tensor with values according to the method
|
||||
// described in the paper `Understanding the difficulty of training deep feedforward neural networks`
|
||||
// using a uniform distribution
|
||||
//
|
||||
// Also known as Glorot initialization.
|
||||
//
|
||||
// Paper: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
|
||||
// Pytorch implementation: https://github.com/pytorch/pytorch/blob/df50f91571891ec3f87977a2bdd4a2b609d70afc/torch/nn/init.py#L310
|
||||
func XavierUniform_(x *ts.Tensor, gainOpt ...float64) {
|
||||
gain := 1.0
|
||||
if len(gainOpt) > 0 {
|
||||
gain = gainOpt[0]
|
||||
}
|
||||
|
||||
size := x.MustSize()
|
||||
dtype := x.DType()
|
||||
device := x.MustDevice()
|
||||
fanIn, fanOut, err := CalculateFans(size)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
std := gain * math.Sqrt(2.0/float64(fanIn+fanOut))
|
||||
|
||||
// calculate uniform bounds from standard deviation
|
||||
a := math.Sqrt(3.0) * std
|
||||
uniformInit := NewUniformInit(-a, a)
|
||||
src := uniformInit.InitTensor(size, device, dtype)
|
||||
x.Copy_(src)
|
||||
|
||||
src.MustDrop()
|
||||
}
|
||||
|
|
44
nn/init_test.go
Normal file
44
nn/init_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
// Test whether InitTensor() can cause memory blow-up due to accumulate gradient.
|
||||
func TestInitTensor_Memcheck(t *testing.T) {
|
||||
gotch.PrintMemStats("Start")
|
||||
device := gotch.CPU
|
||||
vs := NewVarStore(device)
|
||||
params := 500
|
||||
|
||||
path := vs.Root()
|
||||
dims := []int64{1024, 1024}
|
||||
for i := 0; i < params; i++ {
|
||||
ts.NoGrad(func() {
|
||||
name := fmt.Sprintf("param_%v", i)
|
||||
x := ts.MustRandn(dims, gotch.DefaultDType, device)
|
||||
path.MustAdd(name, x, false)
|
||||
x.MustDrop()
|
||||
})
|
||||
}
|
||||
|
||||
// vs.Summary()
|
||||
|
||||
fmt.Printf("vs created...\n")
|
||||
// printMemStats("After varstore created")
|
||||
|
||||
vs.Destroy()
|
||||
ts.CleanUp()
|
||||
|
||||
fmt.Printf("vs deleted...\n")
|
||||
|
||||
// printMemStats("After varstore deleted")
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
gotch.PrintMemStats("Final")
|
||||
}
|
|
@ -79,7 +79,7 @@ func (m *TrainableCModule) Save(file string) error {
|
|||
// ForwardT implements ModuleT for TrainableCModule.
|
||||
// NOTE: train parameter will not be used.
|
||||
func (m *TrainableCModule) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
|
||||
retVal, err := m.Inner.ForwardTs([]ts.Tensor{*x})
|
||||
retVal, err := m.Inner.ForwardTs([]*ts.Tensor{x})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
43
nn/linear.go
43
nn/linear.go
|
@ -6,7 +6,6 @@ import (
|
|||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
|
@ -22,6 +21,8 @@ type LinearConfig struct {
|
|||
func DefaultLinearConfig() *LinearConfig {
|
||||
negSlope := math.Sqrt(5)
|
||||
return &LinearConfig{
|
||||
// NOTE. KaimingUniform cause mem leak due to ts.Uniform()!!!
|
||||
// Avoid using it now.
|
||||
WsInit: NewKaimingUniformInit(WithKaimingNegativeSlope(negSlope)),
|
||||
BsInit: nil,
|
||||
Bias: true,
|
||||
|
@ -41,11 +42,7 @@ type Linear struct {
|
|||
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
||||
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||
var bs *ts.Tensor
|
||||
// bs has size of output dimension
|
||||
switch c.Bias {
|
||||
case false:
|
||||
bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
|
||||
case true:
|
||||
if c.Bias {
|
||||
switch {
|
||||
case c.BsInit == nil:
|
||||
shape := []int64{inDim, outDim}
|
||||
|
@ -65,8 +62,10 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|||
}
|
||||
}
|
||||
|
||||
ws := vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false)
|
||||
|
||||
return &Linear{
|
||||
Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
||||
Ws: ws,
|
||||
Bs: bs,
|
||||
}
|
||||
}
|
||||
|
@ -87,29 +86,31 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|||
//
|
||||
// Example:
|
||||
//
|
||||
// inDim := 3
|
||||
// outDim := 2
|
||||
// batchSize := 4
|
||||
// weights: 2x3
|
||||
// [ 1 1 1
|
||||
// 1 1 1 ]
|
||||
// inDim := 3
|
||||
// outDim := 2
|
||||
// batchSize := 4
|
||||
// weights: 2x3
|
||||
// [ 1 1 1
|
||||
// 1 1 1 ]
|
||||
//
|
||||
// input node: 3x4
|
||||
// [ 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1 ]
|
||||
// input node: 3x4
|
||||
// [ 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1
|
||||
// 1 1 1 ]
|
||||
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
||||
|
||||
mul := xs.MustMatmul(l.Ws, false)
|
||||
return mul.MustAdd(l.Bs, true)
|
||||
if l.Bs != nil {
|
||||
return mul.MustAdd(l.Bs, true)
|
||||
} else {
|
||||
return mul
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardT implements ModuleT interface for Linear layer.
|
||||
//
|
||||
// NOTE: train param will not be used.
|
||||
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
|
||||
|
||||
mul := xs.MustMatmul(l.Ws, false)
|
||||
return mul.MustAdd(l.Bs, true)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
|
@ -68,7 +67,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tenso
|
|||
reduction := options.Reduction
|
||||
ignoreIndex := options.IgnoreIndex
|
||||
|
||||
logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
|
||||
logSm := logits.MustLogSoftmax(-1, dtype, false)
|
||||
loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
|
||||
ws.MustDrop()
|
||||
|
||||
|
|
|
@ -7,7 +7,6 @@ import (
|
|||
"log"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
||||
|
@ -387,12 +386,12 @@ func WithErrorIfNonFinite(v bool) ClipOpt {
|
|||
}
|
||||
}
|
||||
|
||||
/// Clips gradient L2 norm over all trainable parameters.
|
||||
// / Clips gradient L2 norm over all trainable parameters.
|
||||
//
|
||||
// The norm is computed over all gradients together, as if they were
|
||||
// concatenated into a single vector.
|
||||
//
|
||||
/// Args:
|
||||
// / Args:
|
||||
// - max: max norm of the gradient
|
||||
// - o.NormType. Type of the used p-norm, can be "inf" for infinity norm. Default= 2.0
|
||||
// - o.ErrorIfNonFinite bool. If true, throw error if total norm of the gradients from paramters is "nan", "inf" or "-inf". Default=false
|
||||
|
@ -413,15 +412,19 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
|
|||
}
|
||||
|
||||
var (
|
||||
norms []ts.Tensor
|
||||
norms []*ts.Tensor
|
||||
totalNorm *ts.Tensor
|
||||
)
|
||||
|
||||
device := opt.varstore.device
|
||||
|
||||
// FIXME. What about mixed-precision?
|
||||
dtype := parameters[0].DType()
|
||||
|
||||
if o.NormType == math.Inf(1) {
|
||||
for _, v := range opt.varstore.vars {
|
||||
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
|
||||
norms = append(norms, *n)
|
||||
norms = append(norms, n)
|
||||
}
|
||||
// total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
|
||||
totalNorm = ts.MustStack(norms, 0).MustMax(true)
|
||||
|
@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
|
|||
|
||||
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
|
||||
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
|
||||
norms = append(norms, *x)
|
||||
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
|
||||
norms = append(norms, x)
|
||||
}
|
||||
}
|
||||
|
||||
// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
|
||||
// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
||||
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
|
||||
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
|
||||
for _, x := range norms {
|
||||
x.MustDrop()
|
||||
}
|
||||
|
@ -556,7 +559,7 @@ func (opt *Optimizer) ParamGroupNum() int {
|
|||
return int(ngroup)
|
||||
}
|
||||
|
||||
func (opt *Optimizer) AddParamGroup(tensors []ts.Tensor) {
|
||||
func (opt *Optimizer) AddParamGroup(tensors []*ts.Tensor) {
|
||||
err := opt.opt.AddParamGroup(tensors)
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)
|
||||
|
|
25
nn/rnn.go
25
nn/rnn.go
|
@ -74,7 +74,7 @@ func DefaultRNNConfig() *RNNConfig {
|
|||
//
|
||||
// https://en.wikipedia.org/wiki/Long_short-term_memory
|
||||
type LSTM struct {
|
||||
flatWeights []ts.Tensor
|
||||
flatWeights []*ts.Tensor
|
||||
hiddenDim int64
|
||||
config *RNNConfig
|
||||
device gotch.Device
|
||||
|
@ -89,7 +89,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
}
|
||||
|
||||
gateDim := 4 * hiddenDim
|
||||
flatWeights := make([]ts.Tensor, 0)
|
||||
flatWeights := make([]*ts.Tensor, 0)
|
||||
|
||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||
if i != 0 {
|
||||
|
@ -102,7 +102,7 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
|
||||
|
||||
case 2: // bi-directional
|
||||
// forward
|
||||
|
@ -110,14 +110,14 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
|
||||
|
||||
// reverse
|
||||
wIhR := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
|
||||
wHhR := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
|
||||
bIhR := vs.MustZeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
|
||||
bHhR := vs.MustZeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
|
||||
flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
|
||||
flatWeights = append(flatWeights, wIhR, wHhR, bIhR, bHhR)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
|
|||
|
||||
layerDim := l.config.NumLayers * numDirections
|
||||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
||||
|
||||
dtype := l.flatWeights[0].DType()
|
||||
zeros := ts.MustZeros(shape, dtype, l.device)
|
||||
|
||||
retVal := &LSTMState{
|
||||
Tensor1: zeros.MustShallowClone(),
|
||||
|
@ -188,7 +190,7 @@ func (l *LSTM) Seq(input *ts.Tensor) (*ts.Tensor, State) {
|
|||
|
||||
func (l *LSTM) SeqInit(input *ts.Tensor, inState State) (*ts.Tensor, State) {
|
||||
|
||||
output, h, c := input.MustLstm([]ts.Tensor{*inState.(*LSTMState).Tensor1, *inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
output, h, c := input.MustLstm([]*ts.Tensor{inState.(*LSTMState).Tensor1, inState.(*LSTMState).Tensor2}, l.flatWeights, l.config.HasBiases, l.config.NumLayers, l.config.Dropout, l.config.Train, l.config.Bidirectional, l.config.BatchFirst)
|
||||
|
||||
return output, &LSTMState{
|
||||
Tensor1: h,
|
||||
|
@ -209,7 +211,7 @@ func (gs *GRUState) Value() *ts.Tensor {
|
|||
//
|
||||
// https://en.wikipedia.org/wiki/Gated_recurrent_unit
|
||||
type GRU struct {
|
||||
flatWeights []ts.Tensor
|
||||
flatWeights []*ts.Tensor
|
||||
hiddenDim int64
|
||||
config *RNNConfig
|
||||
device gotch.Device
|
||||
|
@ -223,7 +225,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
|
|||
}
|
||||
|
||||
gateDim := 3 * hiddenDim
|
||||
flatWeights := make([]ts.Tensor, 0)
|
||||
flatWeights := make([]*ts.Tensor, 0)
|
||||
|
||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||
for n := 0; n < int(numDirections); n++ {
|
||||
|
@ -239,7 +241,7 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
|
|||
bIh := vs.MustZeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.MustZeros("b_hh", []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
flatWeights = append(flatWeights, wIh, wHh, bIh, bHh)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
|
|||
layerDim := g.config.NumLayers * numDirections
|
||||
shape := []int64{layerDim, batchDim, g.hiddenDim}
|
||||
|
||||
tensor := ts.MustZeros(shape, gotch.Float, g.device)
|
||||
dtype := g.flatWeights[0].DType()
|
||||
tensor := ts.MustZeros(shape, dtype, g.device)
|
||||
|
||||
return &GRUState{Tensor: tensor}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func (s *Sequential) AddFn(fn ts.Module) {
|
|||
}
|
||||
|
||||
// ForwardAll applies the forward pass and returns the output for each layer.
|
||||
func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []ts.Tensor) {
|
||||
func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []*ts.Tensor) {
|
||||
|
||||
var n uint8 = uint8(len(s.layers))
|
||||
if len(opts) > 0 {
|
||||
|
@ -54,11 +54,11 @@ func (s *Sequential) ForwardAll(xs *ts.Tensor, opts ...uint8) (retVal []ts.Tenso
|
|||
}
|
||||
|
||||
if s.IsEmpty() {
|
||||
return []ts.Tensor{*xs.MustShallowClone()}
|
||||
return []*ts.Tensor{xs.MustShallowClone()}
|
||||
}
|
||||
|
||||
for i := 0; i < int(n); i++ {
|
||||
retVal = append(retVal, *s.layers[i].Forward(xs))
|
||||
retVal = append(retVal, s.layers[i].Forward(xs))
|
||||
}
|
||||
|
||||
return retVal
|
||||
|
@ -85,15 +85,15 @@ func (s *Sequential) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
|||
}
|
||||
|
||||
// forward sequentially
|
||||
outs := make([]ts.Tensor, len(s.layers))
|
||||
outs := make([]*ts.Tensor, len(s.layers))
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
if i == 0 {
|
||||
outs[0] = *s.layers[i].Forward(xs)
|
||||
outs[0] = s.layers[i].Forward(xs)
|
||||
defer outs[0].MustDrop()
|
||||
} else if i == len(s.layers)-1 {
|
||||
return s.layers[i].Forward(&outs[i-1])
|
||||
return s.layers[i].Forward(outs[i-1])
|
||||
} else {
|
||||
outs[i] = *s.layers[i].Forward(&outs[i-1])
|
||||
outs[i] = s.layers[i].Forward(outs[i-1])
|
||||
defer outs[i].MustDrop()
|
||||
}
|
||||
}
|
||||
|
@ -106,7 +106,7 @@ type SequentialT struct {
|
|||
layers []ts.ModuleT
|
||||
}
|
||||
|
||||
/// SeqT creates a new empty sequential layer.
|
||||
// / SeqT creates a new empty sequential layer.
|
||||
func SeqT() *SequentialT {
|
||||
return &SequentialT{
|
||||
layers: make([]ts.ModuleT, 0),
|
||||
|
@ -139,15 +139,15 @@ func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
|
|||
}
|
||||
|
||||
// forward sequentially
|
||||
outs := make([]ts.Tensor, len(s.layers))
|
||||
outs := make([]*ts.Tensor, len(s.layers))
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
if i == 0 {
|
||||
outs[0] = *s.layers[i].ForwardT(xs, train)
|
||||
outs[0] = s.layers[i].ForwardT(xs, train)
|
||||
defer outs[0].MustDrop()
|
||||
} else if i == len(s.layers)-1 {
|
||||
return s.layers[i].ForwardT(&outs[i-1], train)
|
||||
return s.layers[i].ForwardT(outs[i-1], train)
|
||||
} else {
|
||||
outs[i] = *s.layers[i].ForwardT(&outs[i-1], train)
|
||||
outs[i] = s.layers[i].ForwardT(outs[i-1], train)
|
||||
defer outs[i].MustDrop()
|
||||
}
|
||||
}
|
||||
|
@ -179,7 +179,7 @@ func (s *SequentialT) AddFnT(fn ts.ModuleT) {
|
|||
}
|
||||
|
||||
// ForwardAll applies the forward pass and returns the output for each layer.
|
||||
func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []ts.Tensor) {
|
||||
func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (retVal []*ts.Tensor) {
|
||||
|
||||
var n uint8 = uint8(len(s.layers))
|
||||
if len(opts) > 0 {
|
||||
|
@ -187,13 +187,13 @@ func (s *SequentialT) ForwardAllT(xs *ts.Tensor, train bool, opts ...uint8) (ret
|
|||
}
|
||||
|
||||
if s.IsEmpty() {
|
||||
return []ts.Tensor{*xs.MustShallowClone()}
|
||||
return []*ts.Tensor{xs.MustShallowClone()}
|
||||
}
|
||||
|
||||
currTs := xs
|
||||
for i := 0; i < int(n); i++ {
|
||||
res := s.layers[i].ForwardT(currTs, train)
|
||||
retVal = append(retVal, *res)
|
||||
retVal = append(retVal, res)
|
||||
currTs = res
|
||||
}
|
||||
|
||||
|
|
146
nn/varstore.go
146
nn/varstore.go
|
@ -78,15 +78,15 @@ func (vs *VarStore) IsEmpty() bool {
|
|||
}
|
||||
|
||||
// TrainableVariabless returns reference to all trainable variables kept in VarStore.
|
||||
func (vs *VarStore) TrainableVariables() []ts.Tensor {
|
||||
func (vs *VarStore) TrainableVariables() []*ts.Tensor {
|
||||
vs.Lock()
|
||||
defer vs.Unlock()
|
||||
|
||||
var trainables []ts.Tensor
|
||||
var trainables []*ts.Tensor
|
||||
for _, v := range vs.vars {
|
||||
x := v.Tensor
|
||||
if x.MustRequiresGrad() {
|
||||
trainables = append(trainables, *x)
|
||||
trainables = append(trainables, x)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -192,6 +192,8 @@ func (vs *VarStore) Load(filepath string) error {
|
|||
x.Tensor.MustDrop()
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -227,6 +229,9 @@ func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
|
|||
v.Tensor.Copy_(currTs)
|
||||
})
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -286,6 +291,8 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
|
|||
x.Tensor.MustDrop()
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
|
@ -336,6 +343,8 @@ func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string,
|
|||
})
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
|
@ -406,6 +415,8 @@ func (vs *VarStore) Copy(src *VarStore) error {
|
|||
srcDevTs.MustDrop()
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -417,12 +428,21 @@ func (vs *VarStore) Summary() {
|
|||
layers = append(layers, name)
|
||||
}
|
||||
sort.Strings(layers)
|
||||
var dtype gotch.DType
|
||||
isFirst := true
|
||||
for _, l := range layers {
|
||||
var x *ts.Tensor
|
||||
var isBuffer bool
|
||||
for name, v := range vars {
|
||||
if name == l {
|
||||
x = v.Tensor
|
||||
|
||||
// Get DType of first tensor for representation only
|
||||
if isFirst {
|
||||
dtype = x.DType()
|
||||
}
|
||||
isFirst = false
|
||||
|
||||
isBuffer = v.Type == "buffer"
|
||||
break
|
||||
}
|
||||
|
@ -435,25 +455,33 @@ func (vs *VarStore) Summary() {
|
|||
}
|
||||
|
||||
fmt.Printf("Num of layers: %v\n", len(vars))
|
||||
fmt.Printf("DType: %v\n", dtype)
|
||||
}
|
||||
|
||||
// Destroy deletes all tensors in varstore and set it to nil.
|
||||
func (vs *VarStore) Destroy() {
|
||||
vs.Lock()
|
||||
for n, v := range vs.vars {
|
||||
v.Tensor.MustDrop()
|
||||
|
||||
delete(vs.vars, n)
|
||||
}
|
||||
|
||||
vs.Unlock()
|
||||
|
||||
vs = nil
|
||||
}
|
||||
|
||||
// ToDType casts all variables in VarStore to specified DType.
|
||||
//
|
||||
// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
|
||||
// NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible.
|
||||
func (vs *VarStore) ToDType(dtype gotch.DType) {
|
||||
vs.Root().ToDType(dtype)
|
||||
}
|
||||
|
||||
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToHalf() {
|
||||
vs.Root().ToHalf()
|
||||
}
|
||||
|
||||
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
// NOTE. float-like includes `Half`,`BFloat16`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToFloat() {
|
||||
vs.Root().ToFloat()
|
||||
}
|
||||
|
@ -465,6 +493,25 @@ func (vs *VarStore) ToDouble() {
|
|||
vs.Root().ToDouble()
|
||||
}
|
||||
|
||||
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToHalf() {
|
||||
vs.Root().ToHalf()
|
||||
}
|
||||
|
||||
// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToBFloat16() {
|
||||
vs.Root().ToBFloat16()
|
||||
}
|
||||
|
||||
func (vs *VarStore) ToDevice(device gotch.Device) {
|
||||
p := vs.Root()
|
||||
p.ToDevice(device)
|
||||
}
|
||||
|
||||
// Path methods:
|
||||
// =============
|
||||
|
||||
|
@ -520,6 +567,7 @@ func (p *Path) add(name string, newTs *ts.Tensor, trainable bool, varType string
|
|||
tensor *ts.Tensor
|
||||
err error
|
||||
)
|
||||
|
||||
if trainable {
|
||||
tensor, err = newTs.SetRequiresGrad(true, false)
|
||||
if err != nil {
|
||||
|
@ -664,7 +712,7 @@ func (p *Path) SetGroup(g uint) {
|
|||
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
|
||||
//
|
||||
// NOTE. this method should be used for floating-point conversion, i.e.,
|
||||
// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
|
||||
// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
|
||||
func (p *Path) ToDType(dtype gotch.DType) {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
|
@ -686,28 +734,63 @@ func (p *Path) toFloat(dtype gotch.DType) {
|
|||
for name, v := range p.varstore.vars {
|
||||
if strings.Contains(name, path) {
|
||||
dtype := v.Tensor.DType()
|
||||
if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
|
||||
if gotch.IsFloatDType(dtype) {
|
||||
newVar := v
|
||||
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
|
||||
p.varstore.vars[name] = newVar
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
}
|
||||
|
||||
// ToHalf casts all variables in current path and subpaths to `Half` precision.
|
||||
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
||||
func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
|
||||
dtype := gotch.Float
|
||||
if len(floatDTypeOpt) > 0 {
|
||||
dt := floatDTypeOpt[0]
|
||||
if !gotch.IsFloatDType(dt) {
|
||||
// Ingore the option
|
||||
if gotch.Debug {
|
||||
log.Printf("WARNING: nn.Path.ToFloat() input dtype is invalid float DType %v. Just ignoring...\n", dt)
|
||||
}
|
||||
} else {
|
||||
dtype = dt
|
||||
}
|
||||
}
|
||||
|
||||
p.toFloat(dtype)
|
||||
}
|
||||
|
||||
// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
|
||||
func (p *Path) ToDouble() {
|
||||
p.toFloat(gotch.Double)
|
||||
}
|
||||
|
||||
// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
|
||||
func (p *Path) ToHalf() {
|
||||
p.toFloat(gotch.Half)
|
||||
}
|
||||
|
||||
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
||||
func (p *Path) ToFloat() {
|
||||
p.toFloat(gotch.Float)
|
||||
// ToBFloat16() converts all variables in current path and subpaths to `BFloat16` dtype.
|
||||
func (p *Path) ToBFloat16() {
|
||||
p.toFloat(gotch.BFloat16)
|
||||
}
|
||||
|
||||
// ToDouble casts all variables in current path and subpaths to `Double` precision.
|
||||
func (p *Path) ToDouble() {
|
||||
p.toFloat(gotch.Double)
|
||||
func (p *Path) ToDevice(device gotch.Device) {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
path := strings.Join(p.path, SEP)
|
||||
for name, v := range p.varstore.vars {
|
||||
if strings.Contains(name, path) {
|
||||
newVar := v
|
||||
newVar.Tensor = v.Tensor.MustTo(device, true)
|
||||
p.varstore.vars[name] = newVar
|
||||
}
|
||||
}
|
||||
|
||||
ts.CleanUp()
|
||||
}
|
||||
|
||||
// ZerosNoTrain creates a new variable initialized with zeros.
|
||||
|
@ -718,7 +801,8 @@ func (p *Path) ToDouble() {
|
|||
// The variable uses a float tensor initialized with zeros.
|
||||
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
device := p.Device()
|
||||
z, err := ts.Zeros(dims, gotch.Float, device)
|
||||
dtype := gotch.DefaultDType
|
||||
z, err := ts.Zeros(dims, dtype, device)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
|
||||
return nil, err
|
||||
|
@ -755,7 +839,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.T
|
|||
// The variable uses a float tensor initialized with ones.
|
||||
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
device := p.Device()
|
||||
z, err := ts.Ones(dims, gotch.Float, device)
|
||||
dtype := gotch.DefaultDType
|
||||
z, err := ts.Ones(dims, dtype, device)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
|
||||
return nil, err
|
||||
|
@ -792,12 +877,19 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
|
|||
// The variable uses a float tensor initialized as per the
|
||||
// related argument.
|
||||
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
v := ini.InitTensor(dims, p.varstore.device)
|
||||
dtype := gotch.DefaultDType
|
||||
// v := ini.InitTensor(dims, p.varstore.device, dtype)
|
||||
var v *ts.Tensor
|
||||
|
||||
v = ini.InitTensor(dims, p.varstore.device, dtype)
|
||||
|
||||
out, err := p.Add(name, v, true, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v.MustDrop()
|
||||
|
||||
return out, err
|
||||
}
|
||||
|
||||
|
@ -1098,7 +1190,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
|
|||
|
||||
// OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
|
||||
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
o := ts.MustOnes(dims, gotch.Float, e.path.Device())
|
||||
dtype := gotch.DefaultDType
|
||||
o := ts.MustOnes(dims, dtype, e.path.Device())
|
||||
out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -1161,7 +1254,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.
|
|||
|
||||
// OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
|
||||
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
|
||||
dtype := gotch.DefaultDType
|
||||
z := ts.MustZeros(dims, dtype, e.path.Device())
|
||||
out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
package nn_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
|
@ -133,3 +135,43 @@ func TestSaveLoad(t *testing.T) {
|
|||
t.Errorf("Failed deleting varstore saved file: %v\n", filenameAbs)
|
||||
}
|
||||
}
|
||||
|
||||
// Test whether create params in varstore can cause memory blow-up due to accumulate gradient.
|
||||
func TestVarstore_Memcheck(t *testing.T) {
|
||||
gotch.PrintMemStats("Start")
|
||||
device := gotch.CPU
|
||||
vs := nn.NewVarStore(device)
|
||||
params := 1000
|
||||
|
||||
path := vs.Root()
|
||||
// dims := []int64{1024, 1024}
|
||||
config := nn.DefaultLinearConfig()
|
||||
inDim := int64(1024)
|
||||
outDim := int64(1024)
|
||||
var layers []nn.Linear
|
||||
for i := 0; i < params; i++ {
|
||||
ts.NoGrad(func() {
|
||||
name := fmt.Sprintf("param_%v", i)
|
||||
l := nn.NewLinear(path.Sub(name), inDim, outDim, config)
|
||||
layers = append(layers, *l)
|
||||
// x := ts.MustRandn(dims, gotch.DefaultDType, device)
|
||||
// path.MustAdd(name, x, false)
|
||||
// x.MustDrop()
|
||||
})
|
||||
}
|
||||
|
||||
// vs.Summary()
|
||||
|
||||
fmt.Printf("vs created...\n")
|
||||
// printMemStats("After varstore created")
|
||||
|
||||
vs.Destroy()
|
||||
ts.CleanUp()
|
||||
|
||||
fmt.Printf("vs deleted...\n")
|
||||
|
||||
// printMemStats("After varstore deleted")
|
||||
|
||||
time.Sleep(time.Second * 10)
|
||||
gotch.PrintMemStats("Final")
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pickle_test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
|
@ -18,11 +19,13 @@ func ExampleLoadInfo() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
err = pickle.LoadInfo(modelFile)
|
||||
m, err := pickle.LoadModelInfo(modelFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(m)
|
||||
|
||||
// Output:
|
||||
// classifier.0.bias - [4096]
|
||||
// classifier.0.weight - [4096 25088]
|
||||
|
@ -57,4 +60,25 @@ func ExampleLoadInfo() {
|
|||
// features.7.bias - [128]
|
||||
// features.7.weight - [128 128 3 3]
|
||||
// Num of variables: 32
|
||||
// Tensor DType: Float
|
||||
}
|
||||
|
||||
func ExampleModelFloat16() {
|
||||
modelName := "HuggingFaceH4/tiny-random-LlamaForCausalLM"
|
||||
url := "https://huggingface.co/HuggingFaceH4/tiny-random-LlamaForCausalLM/resolve/main/pytorch_model.bin"
|
||||
|
||||
modelFile, err := gotch.CachedPath(url, modelName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
m, err := pickle.LoadModelInfo(modelFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Model DType: %v\n", m.DType())
|
||||
|
||||
// Output:
|
||||
// Model DType: Half
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"reflect"
|
||||
"sort"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/ts"
|
||||
)
|
||||
|
@ -60,83 +61,41 @@ func Decode(filename string) (map[string]*ts.Tensor, error) {
|
|||
dictResult := *result.(*Dict)
|
||||
for _, item := range dictResult {
|
||||
name := item.Key
|
||||
itemTyp := reflect.TypeOf(item.Value).String()
|
||||
switch itemTyp {
|
||||
case "*pickle.Dict": // Nested *pickle.Dict case
|
||||
subResult := *item.Value.(*Dict)
|
||||
for _, subItem := range subResult {
|
||||
subName := subItem.Key
|
||||
x, ok := subItem.Value.(*StorageTensor)
|
||||
if !ok {
|
||||
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v. Skip decoding parameter %q ...\n", reflect.TypeOf(subItem.Value).String(), subName)
|
||||
continue
|
||||
}
|
||||
data := x.Source.GetData()
|
||||
size := x.Size
|
||||
dtype := x.Source.DType()
|
||||
device := x.Source.Device()
|
||||
stride := x.Stride
|
||||
storageOffset := x.StorageOffset
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x1 := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if x.RequiresGrad {
|
||||
x1.MustRequiresGrad_(x.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x1
|
||||
}
|
||||
|
||||
default:
|
||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||
|
||||
// if !isStorageTensor {
|
||||
// err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||
// return nil, err
|
||||
// }
|
||||
if !isStorageTensor {
|
||||
log.Printf("INFO: Decode() failed: expected 'StorageTensor' type, got %v, with value of %v. Skip decoding parameter %q ...\n", reflect.TypeOf(item.Value).String(), item.Value, name)
|
||||
continue
|
||||
}
|
||||
|
||||
data := sx.Source.GetData()
|
||||
size := sx.Size
|
||||
dtype := sx.Source.DType()
|
||||
device := sx.Source.Device()
|
||||
stride := sx.Stride
|
||||
storageOffset := sx.StorageOffset
|
||||
|
||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("data: %v\n", data)
|
||||
|
||||
// Dealing with Pytorch `..._tracked` variables.
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x := ts.MustOfSlice(data).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if sx.RequiresGrad {
|
||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x
|
||||
sx, isStorageTensor := item.Value.(*StorageTensor)
|
||||
if !isStorageTensor {
|
||||
err := fmt.Errorf("Decode() failed: expected 'StorageTensor' type, got %v\n", reflect.TypeOf(item.Value).String())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data := sx.Source.GetData()
|
||||
size := sx.Size
|
||||
dtype := sx.Source.DType()
|
||||
|
||||
device := sx.Source.Device()
|
||||
stride := sx.Stride
|
||||
storageOffset := sx.StorageOffset
|
||||
|
||||
// log.Printf("%q - %q - shape: %v - stride: %v - storageOffset: %v\n", name, sx.Source.Device().Name, sx.Size, sx.Stride, storageOffset)
|
||||
// log.Printf("data: %v\n", data)
|
||||
|
||||
// Dealing with Pytorch `..._tracked` variables.
|
||||
if reflect.ValueOf(data).Len() == 0 {
|
||||
log.Printf("INFO: skip weight %q with zero data length.\n", name.(string))
|
||||
continue
|
||||
}
|
||||
|
||||
// TODO. should we just skip them?
|
||||
if reflect.ValueOf(data).Len() == 1 && len(size) == 0 {
|
||||
size = []int64{1}
|
||||
stride = []int64{1}
|
||||
}
|
||||
|
||||
x := ts.MustOfSlice(data, ts.WithDType(dtype)).MustAsStrided(size, stride, []int64{storageOffset}, true).MustTotype(dtype, true).MustTo(device, true)
|
||||
if sx.RequiresGrad {
|
||||
x.MustRequiresGrad_(sx.RequiresGrad)
|
||||
}
|
||||
|
||||
namedTensors[name.(string)] = x
|
||||
}
|
||||
case "*pickle.OrderedDict":
|
||||
dictResult := result.(*OrderedDict)
|
||||
|
@ -581,35 +540,74 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
|
|||
return missingVariables, nil
|
||||
}
|
||||
|
||||
// LoadInfo loads pretrained weights and prints out name and shape of weights.
|
||||
func LoadInfo(modelFile string) error {
|
||||
weights, err := Decode(modelFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadInfo() failed: %w", err)
|
||||
return err
|
||||
}
|
||||
type ModelInfor struct {
|
||||
weights map[string][]int64
|
||||
dtype gotch.DType
|
||||
}
|
||||
|
||||
layers := make([]string, 0, len(weights))
|
||||
for tsName := range weights {
|
||||
func NewModelInfor(weights map[string][]int64, dtype gotch.DType) *ModelInfor {
|
||||
return &ModelInfor{
|
||||
weights: weights,
|
||||
dtype: dtype,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelInfor) String() string {
|
||||
var summary string
|
||||
layers := make([]string, 0, len(m.weights))
|
||||
for tsName := range m.weights {
|
||||
layers = append(layers, tsName)
|
||||
}
|
||||
sort.Strings(layers)
|
||||
for _, l := range layers {
|
||||
var x *ts.Tensor
|
||||
for tsName, tsVal := range weights {
|
||||
var x []int64
|
||||
for tsName, shape := range m.weights {
|
||||
if tsName == l {
|
||||
x = tsVal
|
||||
x = shape
|
||||
break
|
||||
}
|
||||
}
|
||||
fmt.Printf("%s - %+v\n", l, x.MustSize())
|
||||
|
||||
summary += fmt.Sprintf("%s - %+v\n", l, x)
|
||||
}
|
||||
|
||||
fmt.Printf("Num of variables: %v\n", len(weights))
|
||||
summary += fmt.Sprintf("Num of variables: %v\n", len(m.weights))
|
||||
summary += fmt.Sprintf("Tensor DType: %v\n", m.dtype)
|
||||
|
||||
for _, x := range weights {
|
||||
x.MustDrop()
|
||||
}
|
||||
|
||||
return nil
|
||||
return summary
|
||||
}
|
||||
|
||||
func (m *ModelInfor) DType() gotch.DType {
|
||||
return m.dtype
|
||||
}
|
||||
|
||||
func (m *ModelInfor) Parameters() int {
|
||||
return len(m.weights)
|
||||
}
|
||||
|
||||
// LoadInfo loads pretrained weights and prints out name and shape of weights.
|
||||
func LoadModelInfo(modelFile string) (*ModelInfor, error) {
|
||||
weights, err := Decode(modelFile)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadInfo() failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w := make(map[string][]int64)
|
||||
var dtype gotch.DType
|
||||
isFirst := true
|
||||
for n, x := range weights {
|
||||
w[n] = x.MustSize()
|
||||
|
||||
if isFirst {
|
||||
dtype = x.DType()
|
||||
isFirst = false
|
||||
}
|
||||
}
|
||||
|
||||
m := NewModelInfor(w, dtype)
|
||||
|
||||
ts.CleanUp()
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/half"
|
||||
)
|
||||
|
||||
// This file implements Pytorch storage data types.
|
||||
|
@ -67,7 +68,8 @@ func (s *HalfStorageClass) New(size int, location string) Storage {
|
|||
|
||||
type HalfStorage struct {
|
||||
BaseStorage
|
||||
Data []float32
|
||||
// Data []float32
|
||||
Data []half.Float16
|
||||
}
|
||||
|
||||
var _ Storage = &HalfStorage{}
|
||||
|
@ -77,7 +79,7 @@ func (s *HalfStorage) SetFromFile(r io.Reader) error {
|
|||
}
|
||||
|
||||
func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]float32, size)
|
||||
data := make([]half.Float16, size)
|
||||
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
|
@ -85,7 +87,7 @@ func (s *HalfStorage) SetFromFileWithSize(r io.Reader, size int) error {
|
|||
return err
|
||||
}
|
||||
u16 := binary.LittleEndian.Uint16(bytes)
|
||||
data[i] = math.Float32frombits(FloatBits16to32(u16))
|
||||
data[i] = half.Float16(u16)
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
|
@ -96,7 +98,7 @@ func (s *HalfStorage) GetData() interface{} {
|
|||
}
|
||||
|
||||
func (s *HalfStorage) DType() gotch.DType {
|
||||
return gotch.Float
|
||||
return gotch.Half
|
||||
}
|
||||
|
||||
func (s *HalfStorage) Device() gotch.Device {
|
||||
|
@ -108,6 +110,62 @@ func (s *HalfStorage) Device() gotch.Device {
|
|||
}
|
||||
}
|
||||
|
||||
// BFloat16Storage:
|
||||
// ================
|
||||
type BFloat16StorageClass struct{}
|
||||
|
||||
var _ StorageClass = &BFloat16StorageClass{}
|
||||
|
||||
func (s *BFloat16StorageClass) New(size int, location string) Storage {
|
||||
return &BFloat16Storage{
|
||||
BaseStorage: BaseStorage{Size: size, Location: location},
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
type BFloat16Storage struct {
|
||||
BaseStorage
|
||||
Data []half.BFloat16
|
||||
}
|
||||
|
||||
var _ Storage = &BFloat16Storage{}
|
||||
|
||||
func (s *BFloat16Storage) SetFromFile(r io.Reader) error {
|
||||
return setFromFile(s, r)
|
||||
}
|
||||
|
||||
func (s *BFloat16Storage) SetFromFileWithSize(r io.Reader, size int) error {
|
||||
data := make([]half.BFloat16, size)
|
||||
br := NewLimitedBufferReader(r, size, 2, 512)
|
||||
for i := 0; i < size; i++ {
|
||||
bytes, err := br.ReadNext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u16 := binary.LittleEndian.Uint16(bytes)
|
||||
data[i] = half.BFloat16(u16)
|
||||
}
|
||||
s.Data = data
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BFloat16Storage) GetData() interface{} {
|
||||
return s.Data
|
||||
}
|
||||
|
||||
func (s *BFloat16Storage) DType() gotch.DType {
|
||||
return gotch.BFloat16
|
||||
}
|
||||
|
||||
func (s *BFloat16Storage) Device() gotch.Device {
|
||||
switch s.Location {
|
||||
case "cuda":
|
||||
return gotch.CudaIfAvailable()
|
||||
default:
|
||||
return gotch.CPU
|
||||
}
|
||||
}
|
||||
|
||||
// FloatStorage:
|
||||
// =============
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/bin/bash
|
||||
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-1.11.0}"
|
||||
CUDA_VERSION="${CUDA_VER:-11.3}"
|
||||
LIBTORCH_VERSION="${LIBTORCH_VER:-2.0.1}"
|
||||
CUDA_VERSION="${CUDA_VER:-11.7}"
|
||||
|
||||
if [ "${CUDA_VERSION}" == "cpu" ]; then
|
||||
CU_VERSION="cpu"
|
||||
|
|
|
@ -232,6 +232,7 @@ func (tdi *TextDataIter) Progress() float32 {
|
|||
progress := float32(startIndex) / float32(availableIndices)
|
||||
return progress
|
||||
}
|
||||
|
||||
// Labels returns the number of different `character` (rune) used by the dataset.
|
||||
func (td *TextData) Labels() (retVal int64) {
|
||||
return int64(len(td.CharForLabel))
|
||||
|
@ -281,12 +282,12 @@ func (tdi *TextDataIter) Next() (*Tensor, bool) {
|
|||
indexes := indexesTs.Int64Values()
|
||||
indexesTs.MustDrop()
|
||||
|
||||
var batch []Tensor
|
||||
var batch []*Tensor
|
||||
|
||||
for _, idx := range indexes {
|
||||
narrowIdx := NewNarrow(idx, idx+tdi.SeqLen)
|
||||
idxTs := tdi.Data.Idx(narrowIdx)
|
||||
batch = append(batch, *idxTs)
|
||||
batch = append(batch, idxTs)
|
||||
}
|
||||
|
||||
retVal := MustStack(batch, 0)
|
||||
|
|
|
@ -20,8 +20,8 @@ func TestTextData_NewTextData(t *testing.T) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
txt := `héllo`
|
||||
// txt := "h\xC3\xA9llo"
|
||||
// txt := `héllo`
|
||||
txt := "h\xC3\xA9llo"
|
||||
err = ioutil.WriteFile(filePath, []byte(txt), 0644)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -33,7 +33,7 @@ func TestTextData_NewTextData(t *testing.T) {
|
|||
}
|
||||
|
||||
wantData := []float64{0, 1, 2, 3, 3, 4}
|
||||
gotData := textData.CloneData().Float64Values()
|
||||
gotData := textData.Data.Float64Values()
|
||||
|
||||
if !reflect.DeepEqual(wantData, gotData) {
|
||||
t.Errorf("Want data: %v\n", wantData)
|
||||
|
@ -111,5 +111,4 @@ func TestTextDataIter(t *testing.T) {
|
|||
vals := sum.Int64Values()
|
||||
t.Logf("sum: %v\n", vals)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import "C"
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"unsafe"
|
||||
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
|
@ -41,7 +42,8 @@ func TorchErr() error {
|
|||
cptr := (*C.char)(lib.GetAndResetLastErr())
|
||||
errStr := ptrToString(cptr)
|
||||
if errStr != "" {
|
||||
return fmt.Errorf("Libtorch API Error: %v\n", errStr)
|
||||
trace := string(debug.Stack())
|
||||
return fmt.Errorf("Libtorch API Error: %v\n%v\n", errStr, trace)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -17,7 +17,7 @@ func LoadHwc(path string) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor), nil
|
||||
}
|
||||
|
||||
// SaveHwc save an image from tensor. It expects a tensor of shape [height,
|
||||
|
@ -38,5 +38,5 @@ func ResizeHwc(ts *Tensor, outWidth, outHeight int64) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor), nil
|
||||
}
|
||||
|
|
58
ts/index.go
58
ts/index.go
|
@ -85,26 +85,26 @@ type InsertNewAxis struct{}
|
|||
// NewSelect creates an tensor indexer with given index.
|
||||
// `index` must be in range of tensor dimension. E.g. tensor shape [2,8]
|
||||
// will have size = 2, hence `index` should be in range from [0,2)
|
||||
func NewSelect(index int64) Select {
|
||||
return Select{index}
|
||||
func NewSelect(index int64) *Select {
|
||||
return &Select{index}
|
||||
}
|
||||
|
||||
func NewNarrow(start, end int64) Narrow {
|
||||
return Narrow{Start: start, End: end}
|
||||
func NewNarrow(start, end int64) *Narrow {
|
||||
return &Narrow{Start: start, End: end}
|
||||
}
|
||||
|
||||
func NewIndexSelect(ts *Tensor) IndexSelect {
|
||||
return IndexSelect{Index: ts}
|
||||
func NewIndexSelect(ts *Tensor) *IndexSelect {
|
||||
return &IndexSelect{Index: ts}
|
||||
}
|
||||
|
||||
func NewInsertNewAxis() InsertNewAxis {
|
||||
return InsertNewAxis{}
|
||||
func NewInsertNewAxis() *InsertNewAxis {
|
||||
return &InsertNewAxis{}
|
||||
}
|
||||
|
||||
func NewSliceIndex(sl []int64) IndexSelect {
|
||||
func NewSliceIndex(sl []int64) *IndexSelect {
|
||||
ts := MustOfSlice(sl)
|
||||
|
||||
return IndexSelect{Index: ts}
|
||||
return &IndexSelect{Index: ts}
|
||||
}
|
||||
|
||||
// type SelectFn func(int64)
|
||||
|
@ -120,7 +120,7 @@ func NewSliceIndex(sl []int64) IndexSelect {
|
|||
// )
|
||||
|
||||
type IndexOp interface {
|
||||
Idx(index interface{}) Tensor
|
||||
Idx(index interface{}) *Tensor
|
||||
}
|
||||
|
||||
// implement IndexOp for Tensor:
|
||||
|
@ -129,7 +129,7 @@ type IndexOp interface {
|
|||
// Idx implements `IndexOp` interface for Tensor
|
||||
//
|
||||
// NOTE:
|
||||
// - `index`: expects type `TensorIndexer` or `[]TensorIndexer`
|
||||
// - `index`: expects type `TensorIndexer` or `[]*TensorIndexer`
|
||||
func (ts *Tensor) Idx(index interface{}) (retVal *Tensor) {
|
||||
|
||||
// indexTyp := reflect.TypeOf(index)
|
||||
|
@ -137,8 +137,9 @@ func (ts *Tensor) Idx(index interface{}) (retVal *Tensor) {
|
|||
|
||||
var indexes []TensorIndexer
|
||||
|
||||
switch indexVal.Kind().String() {
|
||||
case "struct": // T: A
|
||||
typ := indexVal.Kind().String()
|
||||
switch typ {
|
||||
case "ptr": // T: A
|
||||
indexes = append(indexes, index.(TensorIndexer))
|
||||
case "slice": // T: []TensorIndexer
|
||||
switch len(index.([]TensorIndexer)) {
|
||||
|
@ -201,7 +202,7 @@ func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error)
|
|||
// Make sure number of non-newaxis is not exceed number of dimensions
|
||||
var numNewAxis int = 0
|
||||
for _, ti := range indexSpec {
|
||||
if reflect.TypeOf(ti).Name() == "InsertNewAxis" {
|
||||
if reflect.TypeOf(ti).String() == "*ts.InsertNewAxis" {
|
||||
numNewAxis += 1
|
||||
}
|
||||
}
|
||||
|
@ -218,10 +219,10 @@ func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error)
|
|||
|
||||
// Make sure tensor conforms the format
|
||||
for _, spec := range indexSpec {
|
||||
// If `spec` is `IndexSelect` type and
|
||||
if reflect.TypeOf(spec).Name() == "IndexSelect" {
|
||||
if reflect.ValueOf(spec).Kind() == reflect.Struct {
|
||||
inputTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(*Tensor)
|
||||
// If `spec` is `*IndexSelect` type and
|
||||
if reflect.TypeOf(spec).String() == "*ts.IndexSelect" {
|
||||
if reflect.ValueOf(spec).Kind() == reflect.Ptr {
|
||||
inputTensor := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Index").Interface().(*Tensor)
|
||||
|
||||
// 1. Either its input tensor has dimension > 1, throw error.
|
||||
inputTensorShape, err := inputTensor.Size()
|
||||
|
@ -257,33 +258,32 @@ func (ts *Tensor) indexer(indexSpec []TensorIndexer) (retVal *Tensor, err error)
|
|||
|
||||
// `spec` is a function type implements `TensorIndexer`
|
||||
for _, spec := range indexSpec {
|
||||
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
switch reflect.TypeOf(spec).String() {
|
||||
case "*ts.InsertNewAxis":
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx + 1
|
||||
case "Select": // 1 field: `Index`
|
||||
index := reflect.ValueOf(spec).FieldByName("Index").Interface().(int64)
|
||||
case "*ts.Select": // 1 field: `Index`
|
||||
index := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Index").Interface().(int64)
|
||||
nextTensor, err = currTensor.Select(currIdx, index, true) // TODO: double-check is `*index` or `index`
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx // not advanced because select() squeezes dimension
|
||||
case "Narrow": // 2 fields: `(Start, End int64)`
|
||||
case "*ts.Narrow": // 2 fields: `(Start, End int64)`
|
||||
// TODO: implement for `Unbounded`, `Included`, `Excluded` ranges
|
||||
// NOTE: for now, just implement (Included(start), Excluded(end))` case
|
||||
start := reflect.ValueOf(spec).FieldByName("Start").Interface().(int64)
|
||||
end := reflect.ValueOf(spec).FieldByName("End").Interface().(int64)
|
||||
start := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Start").Interface().(int64)
|
||||
end := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("End").Interface().(int64)
|
||||
nextTensor, err = currTensor.Narrow(currIdx, start, end-start, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
nextIdx = currIdx + 1
|
||||
case "IndexSelect": // 1 field `(Index *Tensor)`
|
||||
indexTensor := reflect.ValueOf(spec).FieldByName("Index").Interface().(*Tensor)
|
||||
case "*ts.IndexSelect": // 1 field `(Index *Tensor)`
|
||||
indexTensor := reflect.Indirect(reflect.ValueOf(spec)).FieldByName("Index").Interface().(*Tensor)
|
||||
device, err := currTensor.Device()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
|
|
10
ts/init.go
Normal file
10
ts/init.go
Normal file
|
@ -0,0 +1,10 @@
|
|||
package ts
|
||||
|
||||
import (
|
||||
// "runtime/debug"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// debug.SetMemoryLimit()
|
||||
// debug.SetGCPercent(100) // ratio freshly allocated data to live data remaining after previous collection.
|
||||
}
|
|
@ -26,7 +26,7 @@ func (it *Iterable) Next() (item interface{}, ok bool) {
|
|||
}
|
||||
|
||||
var err error
|
||||
switch it.ItemKind.Kind().String() {
|
||||
switch it.ItemKind.GoKind().String() {
|
||||
case "int64":
|
||||
item, err = it.Content.Int64Value([]int64{it.Index})
|
||||
if err != nil {
|
||||
|
|
106
ts/jit.go
106
ts/jit.go
|
@ -22,25 +22,23 @@ type CIValue struct {
|
|||
civalue lib.Civalue
|
||||
}
|
||||
|
||||
type IValueKind struct {
|
||||
reflect.Type
|
||||
}
|
||||
type IValueKind int
|
||||
|
||||
var (
|
||||
NoneVal IValueKind = IValueKind{reflect.TypeOf(nil)}
|
||||
TensorVal IValueKind = IValueKind{reflect.TypeOf(Tensor{})}
|
||||
DoubleVal IValueKind = IValueKind{reflect.TypeOf(float64(1))}
|
||||
IntVal IValueKind = IValueKind{reflect.TypeOf(int64(1))}
|
||||
BoolVal IValueKind = IValueKind{reflect.TypeOf(true)}
|
||||
TupleVal IValueKind = IValueKind{reflect.TypeOf([]IValue{})}
|
||||
IntListVal IValueKind = IValueKind{reflect.TypeOf([]int64{})}
|
||||
DoubleListVal IValueKind = IValueKind{reflect.TypeOf([]float64{})}
|
||||
BoolListVal IValueKind = IValueKind{reflect.TypeOf([]bool{})}
|
||||
StringVal IValueKind = IValueKind{reflect.TypeOf("")}
|
||||
TensorListVal IValueKind = IValueKind{reflect.TypeOf([]Tensor{})}
|
||||
GenericListVal IValueKind = IValueKind{reflect.TypeOf([]IValue{})}
|
||||
GenericDictVal IValueKind = IValueKind{reflect.TypeOf(map[IValue]IValue{})} // 2 elements. ? map[IValue]IValue
|
||||
GenericVal IValueKind = IValueKind{reflect.TypeOf(IValue{})}
|
||||
const (
|
||||
NoneVal IValueKind = iota
|
||||
TensorVal // *Tensor
|
||||
DoubleVal // float64
|
||||
IntVal // int64
|
||||
BoolVal // bool
|
||||
TupleVal // []*IValue
|
||||
IntListVal // []int64
|
||||
DoubleListVal // []float64
|
||||
BoolListVal // []bool
|
||||
StringVal // string
|
||||
TensorListVal // []*Tensor
|
||||
GenericListVal // []*IValue
|
||||
GenericDictVal // map[IValue]IValue - 2 elements
|
||||
GenericVal // *IValue
|
||||
)
|
||||
|
||||
type IValue struct {
|
||||
|
@ -51,7 +49,6 @@ type IValue struct {
|
|||
|
||||
// NewIValue creates a new IValue from given value of various types.
|
||||
func NewIValue(v interface{}) *IValue {
|
||||
|
||||
retVal := &IValue{value: v}
|
||||
if v == nil {
|
||||
retVal.kind = NoneVal
|
||||
|
@ -62,7 +59,7 @@ func NewIValue(v interface{}) *IValue {
|
|||
inputTypeStr := reflect.TypeOf(v).Kind().String()
|
||||
|
||||
switch inputTypeStr {
|
||||
case "Tensor":
|
||||
case "*Tensor":
|
||||
retVal.kind = TensorVal
|
||||
retVal.name = "Tensor"
|
||||
case "float64":
|
||||
|
@ -87,17 +84,7 @@ func NewIValue(v interface{}) *IValue {
|
|||
retVal.kind = StringVal
|
||||
retVal.name = "String"
|
||||
case "slice":
|
||||
fmt.Printf("slice elem type: %q\n", reflect.TypeOf(v).Elem().Kind().String())
|
||||
switch reflect.TypeOf(v).Elem().Kind().String() {
|
||||
case "IValue":
|
||||
switch len(v.([]IValue)) {
|
||||
case 2:
|
||||
retVal.kind = TupleVal
|
||||
retVal.name = "Tuple"
|
||||
default:
|
||||
retVal.kind = GenericListVal
|
||||
retVal.name = "GenericList"
|
||||
}
|
||||
case "int64":
|
||||
retVal.kind = IntListVal
|
||||
retVal.name = "IntList"
|
||||
|
@ -119,31 +106,38 @@ func NewIValue(v interface{}) *IValue {
|
|||
case "bool":
|
||||
retVal.kind = BoolListVal
|
||||
retVal.name = "BoolList"
|
||||
case "struct": // NOTE: only supported `Tensor` type
|
||||
case "ptr": // NOTE: only supported `*Tensor` type
|
||||
val := reflect.Indirect(reflect.ValueOf(v))
|
||||
switch {
|
||||
// 1. Tuple (Tensor, Tensor)
|
||||
case val.Type() == reflect.TypeOf([]Tensor{}) && val.Len() == 2:
|
||||
// 1. Tuple (*Tensor, *Tensor)
|
||||
case val.Type().String() == "[]*ts.Tensor" && val.Len() == 2:
|
||||
retVal.kind = TensorListVal
|
||||
retVal.name = "Tuple"
|
||||
retVal.value = v.([]Tensor)
|
||||
retVal.value = v.([]*Tensor)
|
||||
|
||||
// 2. List (Tensor, Tensor, ...)
|
||||
case val.Type() == reflect.TypeOf([]Tensor{}) && val.Len() > 2:
|
||||
// 2. List (*Tensor, *Tensor, ...)
|
||||
case val.Type().String() == "[]*ts.Tensor" && val.Len() > 2:
|
||||
retVal.kind = TensorListVal
|
||||
retVal.name = "TensorList"
|
||||
retVal.value = v.([]Tensor)
|
||||
retVal.value = v.([]*Tensor)
|
||||
case val.Type().String() == "[]*ts.IValue" && val.Len() == 2:
|
||||
retVal.kind = TupleVal
|
||||
retVal.name = "Tuple"
|
||||
retVal.value = v.([]*IValue)
|
||||
case val.Type().String() == "[]*ts.IValue" && val.Len() > 2, val.Type().String() == "[]*ts.IValue" && val.Len() == 1:
|
||||
retVal.kind = GenericListVal
|
||||
retVal.name = "GenericList"
|
||||
default:
|
||||
log.Fatalf("NewIValue method call - 'slice -> struct' case - Unsupported type (%v)\n", reflect.TypeOf(v).Kind().String())
|
||||
log.Fatalf("NewIValue method call - 'slice -> struct' case - Unsupported type (%v)\n", val.Type().String())
|
||||
}
|
||||
}
|
||||
case "map":
|
||||
// TODO: exclude map of type other than IValue type
|
||||
retVal.kind = GenericDictVal
|
||||
retVal.name = "GenericDict"
|
||||
case "struct":
|
||||
case "ptr":
|
||||
val := reflect.Indirect(reflect.ValueOf(v))
|
||||
fieldName := val.Type().Field(0).Name
|
||||
fieldName := val.Type().Field(2).Name
|
||||
switch fieldName {
|
||||
case "ctensor":
|
||||
retVal.kind = TensorVal
|
||||
|
@ -172,7 +166,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
|
|||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "Tensor":
|
||||
cval := lib.AtiTensor(iv.value.(Tensor).ctensor)
|
||||
cval := lib.AtiTensor(iv.value.(*Tensor).ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -203,7 +197,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
|
|||
case "Tuple":
|
||||
val := reflect.Indirect(reflect.ValueOf(iv.value))
|
||||
switch {
|
||||
// 1. Tuple is (Tensor, Tensor)
|
||||
// 1. Tuple is (*Tensor, *Tensor)
|
||||
case val.Type() == reflect.TypeOf([]Tensor{}):
|
||||
var v []Tensor = iv.value.([]Tensor)
|
||||
var cvals []lib.Civalue
|
||||
|
@ -223,9 +217,9 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
|
|||
}
|
||||
return &CIValue{civalue: tuple}, nil
|
||||
|
||||
// 2. Tuple is (IValue, IValue)
|
||||
// 2. Tuple is (*IValue, *IValue)
|
||||
default:
|
||||
var v []IValue = iv.value.([]IValue)
|
||||
var v []*IValue = iv.value.([]*IValue)
|
||||
var cvals []lib.Civalue
|
||||
for _, i := range v {
|
||||
cval, err := i.ToCIValue()
|
||||
|
@ -328,7 +322,7 @@ func (iv *IValue) ToCIValue() (*CIValue, error) {
|
|||
return &CIValue{civalue: cval}, nil
|
||||
|
||||
case "TensorList":
|
||||
var vals []Tensor = iv.value.([]Tensor)
|
||||
var vals []*Tensor = iv.value.([]*Tensor)
|
||||
var cvals []lib.Ctensor
|
||||
for _, i := range vals {
|
||||
cvals = append(cvals, i.ctensor)
|
||||
|
@ -450,7 +444,7 @@ func IValueFromC(cval *CIValue) (*IValue, error) {
|
|||
return nil, err
|
||||
}
|
||||
return &IValue{
|
||||
value: Tensor{tensor},
|
||||
value: newTensor(tensor),
|
||||
kind: TensorVal,
|
||||
name: "Tensor",
|
||||
}, nil
|
||||
|
@ -518,14 +512,14 @@ func IValueFromC(cval *CIValue) (*IValue, error) {
|
|||
elemName := v.Name()
|
||||
switch elemName {
|
||||
case "Tensor":
|
||||
var vals []Tensor
|
||||
var vals []*Tensor
|
||||
for _, civalue := range civalues {
|
||||
v, err := IValueFromC(&civalue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vals = append(vals, v.Value().(Tensor))
|
||||
vals = append(vals, v.Value().(*Tensor))
|
||||
}
|
||||
if len == 2 {
|
||||
return &IValue{
|
||||
|
@ -725,12 +719,12 @@ func IValueFromC(cval *CIValue) (*IValue, error) {
|
|||
}
|
||||
|
||||
// 3. Get values
|
||||
var tensors []Tensor
|
||||
tensors = append(tensors, Tensor{ctensor: *ptr1})
|
||||
var tensors []*Tensor
|
||||
tensors = append(tensors, newTensor(*ptr1))
|
||||
currPtr := ptr1
|
||||
for i := 1; i < int(len); i++ {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(ptr1)))
|
||||
tensors = append(tensors, Tensor{ctensor: *nextPtr})
|
||||
tensors = append(tensors, newTensor(*nextPtr))
|
||||
currPtr = nextPtr
|
||||
}
|
||||
|
||||
|
@ -1017,7 +1011,7 @@ func ModuleLoadDataOnDevice(stream io.Reader, device gotch.Device) (*CModule, er
|
|||
}
|
||||
|
||||
// ForwardTs performs the forward pass for a model on some specified tensor inputs.
|
||||
func (cm *CModule) ForwardTs(tensors []Tensor) (*Tensor, error) {
|
||||
func (cm *CModule) ForwardTs(tensors []*Tensor) (*Tensor, error) {
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
|
@ -1061,11 +1055,11 @@ func (cm *CModule) ForwardTs(tensors []Tensor) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor), nil
|
||||
}
|
||||
|
||||
// ForwardIs performs the forward pass for a model on some specified ivalue input.
|
||||
func (cm *CModule) ForwardIs(ivalues []IValue) (*IValue, error) {
|
||||
func (cm *CModule) ForwardIs(ivalues []*IValue) (*IValue, error) {
|
||||
|
||||
var civalues []lib.Civalue
|
||||
for _, i := range ivalues {
|
||||
|
@ -1145,7 +1139,7 @@ func (cm *CModule) NamedParameters() ([]NamedTensor, error) {
|
|||
for _, v := range data.NamedCtensors {
|
||||
namedTensor := NamedTensor{
|
||||
Name: v.Name,
|
||||
Tensor: &Tensor{v.Ctensor},
|
||||
Tensor: newTensor(v.Ctensor),
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, namedTensor)
|
||||
|
@ -1194,7 +1188,7 @@ func (cm *CModule) SetEval() {
|
|||
// Forwad implements Module interface for CModule.
|
||||
func (cm *CModule) Forward(tensor *Tensor) (*Tensor, error) {
|
||||
|
||||
var tensors []Tensor = []Tensor{*tensor}
|
||||
var tensors []*Tensor = []*Tensor{tensor}
|
||||
return cm.ForwardTs(tensors)
|
||||
}
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ func TestModuleForwardTs(t *testing.T) {
|
|||
ts1 := ts.TensorFrom([]int64{42})
|
||||
ts2 := ts.TensorFrom([]int64{1337})
|
||||
|
||||
res, err := foo.ForwardTs([]ts.Tensor{*ts1, *ts2})
|
||||
res, err := foo.ForwardTs([]*ts.Tensor{ts1, ts2})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
@ -83,17 +83,17 @@ func TestModuleForwardIValue(t *testing.T) {
|
|||
ts1 := ts.TensorFrom([]int64{42})
|
||||
ts2 := ts.TensorFrom([]int64{1337})
|
||||
|
||||
iv1 := ts.NewIValue(*ts1)
|
||||
iv2 := ts.NewIValue(*ts2)
|
||||
iv1 := ts.NewIValue(ts1)
|
||||
iv2 := ts.NewIValue(ts2)
|
||||
|
||||
got, err := foo.ForwardIs([]ts.IValue{*iv1, *iv2})
|
||||
got, err := foo.ForwardIs([]*ts.IValue{iv1, iv2})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
expectedTs1 := ts.TensorFrom([]int64{1421})
|
||||
expectedTs2 := ts.TensorFrom([]int64{-1295})
|
||||
want := ts.NewIValue([]ts.Tensor{*expectedTs1, *expectedTs2})
|
||||
want := ts.NewIValue([]*ts.Tensor{expectedTs1, expectedTs2})
|
||||
|
||||
if !reflect.DeepEqual(want.Name(), got.Name()) {
|
||||
t.Errorf("Expected Ivalue Name: %v\n", want.Name())
|
||||
|
|
15
ts/layout.go
Normal file
15
ts/layout.go
Normal file
|
@ -0,0 +1,15 @@
|
|||
package ts
|
||||
|
||||
// include/c10/core/Layout.h
|
||||
type Layout int8
|
||||
|
||||
const (
|
||||
Strided Layout = iota // 0
|
||||
Sparse // 1
|
||||
SparseCsr // 2
|
||||
Mkldnn // 3
|
||||
SparseCsc // 4
|
||||
SparseBsr // 5
|
||||
SparseBsc // 6
|
||||
NumOptions // 7
|
||||
)
|
File diff suppressed because it is too large
Load Diff
30
ts/npy.go
30
ts/npy.go
|
@ -101,22 +101,22 @@ func (h *NpyHeader) ToString() (string, error) {
|
|||
shape := strings.Join(shapeStr, ",")
|
||||
|
||||
var descr string
|
||||
switch h.descr.Kind().String() {
|
||||
// case "float16": // NOTE. No float16 in Go primary types. TODO. implement
|
||||
// descr = "f2"
|
||||
case "float32":
|
||||
switch h.descr {
|
||||
case gotch.Half:
|
||||
descr = "f2"
|
||||
case gotch.Float:
|
||||
descr = "f4"
|
||||
case "float64":
|
||||
case gotch.Double:
|
||||
descr = "f8"
|
||||
case "int":
|
||||
case gotch.Int:
|
||||
descr = "i4"
|
||||
case "int64":
|
||||
case gotch.Int64:
|
||||
descr = "i8"
|
||||
case "int16":
|
||||
case gotch.Int16:
|
||||
descr = "i2"
|
||||
case "int8":
|
||||
case gotch.Int8:
|
||||
descr = "i1"
|
||||
case "uint8":
|
||||
case gotch.Uint8:
|
||||
descr = "u1"
|
||||
default:
|
||||
err := fmt.Errorf("Unsupported kind: %v\n", h.descr)
|
||||
|
@ -306,6 +306,11 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// NOTE(TT.). case tensor 1 element with shape = []
|
||||
if len(data) > 0 && len(header.shape) == 0 {
|
||||
header.shape = []int64{1}
|
||||
}
|
||||
|
||||
return OfDataSize(data, header.shape, header.descr)
|
||||
}
|
||||
|
||||
|
@ -348,6 +353,11 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// NOTE(TT.). case tensor 1 element with shape = []
|
||||
if len(data) > 0 && len(header.shape) == 0 {
|
||||
header.shape = []int64{1}
|
||||
}
|
||||
|
||||
tensor, err := OfDataSize(data, header.shape, header.descr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -69,7 +69,7 @@ func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (*COptimizer, error
|
|||
}
|
||||
|
||||
// AddParameters adds parameters as a slice of tensors to optimizer
|
||||
func (co *COptimizer) AddParameters(tensors []Tensor) error {
|
||||
func (co *COptimizer) AddParameters(tensors []*Tensor) error {
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
|
@ -127,7 +127,7 @@ func (co *COptimizer) ParamGroupNum() (int64, error) {
|
|||
return ngroup, nil
|
||||
}
|
||||
|
||||
func (co *COptimizer) AddParamGroup(tensors []Tensor) error {
|
||||
func (co *COptimizer) AddParamGroup(tensors []*Tensor) error {
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
|
|
10
ts/other.go
10
ts/other.go
|
@ -2,17 +2,14 @@ package ts
|
|||
|
||||
// Other tensor methods
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
|
||||
func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
|
||||
weight := NewTensor()
|
||||
reduction := int64(1) // Mean of loss
|
||||
ignoreIndex := int64(-100)
|
||||
|
||||
logSm := ts.MustLogSoftmax(-1, gotch.Float, false)
|
||||
dtype := ts.DType()
|
||||
logSm := ts.MustLogSoftmax(-1, dtype, false)
|
||||
return logSm.MustNllLoss(targets, weight, reduction, ignoreIndex, true)
|
||||
}
|
||||
|
||||
|
@ -21,7 +18,8 @@ func (ts *Tensor) CrossEntropyForLogits(targets *Tensor) (retVal *Tensor) {
|
|||
func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
|
||||
argmax := ts.MustArgmax([]int64{-1}, false, false)
|
||||
eq1 := argmax.MustEqTensor(targets, true)
|
||||
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float, true)
|
||||
dtype := ts.DType()
|
||||
return eq1.MustTotype(dtype, true).MustMean(dtype, true)
|
||||
}
|
||||
|
||||
func (ts *Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal *Tensor) {
|
||||
|
|
|
@ -43,3 +43,20 @@ func ExampleTensorSplitWithSizes(t *testing.T) {
|
|||
// 8 9
|
||||
// [ CPUFloatType{4,2} ]
|
||||
}
|
||||
|
||||
// Test Unbind op specific for BFloat16/Half
|
||||
func TestTensorUnbind(t *testing.T) {
|
||||
// device := gotch.CudaIfAvailable()
|
||||
device := gotch.CPU
|
||||
|
||||
dtype := gotch.BFloat16
|
||||
// dtype := gotch.Half // <- NOTE. Libtorch API Error: "arange_cpu" not implemented for 'Half'
|
||||
|
||||
x := ts.MustArange(ts.IntScalar(60), dtype, device).MustView([]int64{3, 4, 5}, true)
|
||||
|
||||
out := x.MustUnbind(0, true)
|
||||
|
||||
if len(out) != 3 {
|
||||
t.Errorf("Want 3, got %v\n", len(out))
|
||||
}
|
||||
}
|
||||
|
|
125
ts/patch.go
125
ts/patch.go
|
@ -11,9 +11,9 @@ import (
|
|||
)
|
||||
|
||||
// NOTE. This is a temporarily patched to make it run.
|
||||
// TODO. make change at generator for []Tensor input
|
||||
// TODO. make change at generator for []*Tensor input
|
||||
|
||||
func (ts *Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
|
||||
func (ts *Tensor) Lstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
|
||||
|
||||
// NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first
|
||||
// Ctensor will have address given by `ctensorPtr1` here.
|
||||
|
@ -55,11 +55,11 @@ func (ts *Tensor) Lstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, num
|
|||
return output, h, c, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, &Tensor{ctensor: *ctensorPtr3}, nil
|
||||
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), newTensor(*ctensorPtr3), nil
|
||||
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
|
||||
func (ts *Tensor) MustLstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
|
||||
output, h, c, err := ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
||||
|
||||
if err != nil {
|
||||
|
@ -69,7 +69,7 @@ func (ts *Tensor) MustLstm(hxData []Tensor, paramsData []Tensor, hasBiases bool,
|
|||
return output, h, c
|
||||
}
|
||||
|
||||
func (ts *Tensor) Gru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
|
||||
func (ts *Tensor) Gru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
|
||||
|
||||
// NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land.
|
||||
// The first Ctensor will have address given by `ctensorPtr1` here.
|
||||
|
@ -105,11 +105,11 @@ func (ts *Tensor) Gru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers
|
|||
return output, h, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, nil
|
||||
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
|
||||
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustGru(hx *Tensor, paramsData []Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
|
||||
func (ts *Tensor) MustGru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
|
||||
output, h, err := ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -139,7 +139,7 @@ func (ts *Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2
|
|||
return ts1, ts2, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor: *ctensorPtr1}, &Tensor{ctensor: *ctensorPtr2}, nil
|
||||
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor) {
|
||||
|
@ -169,7 +169,7 @@ func (ts *Tensor) NLLLoss(target *Tensor, del bool) (retVal *Tensor, err error)
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
retVal = newTensor(*ptr)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
@ -200,7 +200,7 @@ func (ts *Tensor) MustNLLLoss(target *Tensor, del bool) (retVal *Tensor) {
|
|||
// tensor *atg_where(tensor condition);
|
||||
|
||||
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||
func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
|
@ -213,21 +213,21 @@ func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
|
|||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
|
||||
func MustAlignTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
for _, t := range tensors {
|
||||
defer t.MustDrop()
|
||||
|
@ -242,7 +242,7 @@ func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
|
|||
}
|
||||
|
||||
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||
func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
|
@ -255,21 +255,21 @@ func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
|
|||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
|
||||
func MustBroadcastTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
for _, t := range tensors {
|
||||
defer t.MustDrop()
|
||||
|
@ -285,29 +285,28 @@ func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
|
|||
}
|
||||
|
||||
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
||||
func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
|
||||
func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
|
||||
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) {
|
||||
func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
@ -321,7 +320,7 @@ func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor)
|
|||
}
|
||||
|
||||
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
||||
func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
|
@ -334,21 +333,21 @@ func Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
|
|||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
|
||||
func MustMeshgrid(tensors []*Tensor) (retVal []*Tensor) {
|
||||
retVal, err := Meshgrid(tensors)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -358,29 +357,28 @@ func MustMeshgrid(tensors []Tensor) (retVal []Tensor) {
|
|||
}
|
||||
|
||||
// tensor *atg_nonzero_numpy(tensor self);
|
||||
func (ts *Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
|
||||
|
||||
func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
|
||||
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
|
||||
func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
@ -396,10 +394,11 @@ func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
|
|||
// Split splits tensor into chunks
|
||||
//
|
||||
// Parameters:
|
||||
// - splitSize – size of a single chunk
|
||||
// - dim – dimension along which to split the tensor.
|
||||
// - splitSize – size of a single chunk
|
||||
// - dim – dimension along which to split the tensor.
|
||||
//
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
||||
func (ts *Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
|
||||
func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -411,22 +410,21 @@ func (ts *Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
|
|||
// calculated from there. The vector of tensors will end if the calculated
|
||||
// pointer value is `null`.
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
|
||||
func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
@ -442,10 +440,11 @@ func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
|
|||
// SplitWithSizes splits tensor into chunks
|
||||
//
|
||||
// Parameters:
|
||||
// - splitSizes – slice of sizes for each chunk
|
||||
// - dim – dimension along which to split the tensor.
|
||||
// - splitSizes – slice of sizes for each chunk
|
||||
// - dim – dimension along which to split the tensor.
|
||||
//
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
||||
func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) {
|
||||
func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -457,22 +456,21 @@ func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor
|
|||
// calculated from there. The vector of tensors will end if the calculated
|
||||
// pointer value is `null`.
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) {
|
||||
func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
@ -486,29 +484,29 @@ func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (r
|
|||
}
|
||||
|
||||
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||
func (ts *Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
|
||||
|
||||
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
||||
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
|
||||
func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
@ -522,29 +520,28 @@ func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
|
|||
}
|
||||
|
||||
// tensor *atg_where(tensor condition);
|
||||
func Where(condition Tensor) (retVal []Tensor, err error) {
|
||||
|
||||
func Where(condition Tensor) (retVal []*Tensor, err error) {
|
||||
ctensorsPtr := lib.AtgWhere(condition.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
retVal = append(retVal, newTensor(*currentPtr))
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Add(unsafe.Pointer(currentPtr), unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
retVal = append(retVal, newTensor(*nextPtr))
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
|
||||
func MustWhere(condition Tensor, del bool) (retVal []*Tensor) {
|
||||
if del {
|
||||
defer condition.MustDrop()
|
||||
}
|
||||
|
|
46
ts/print.go
46
ts/print.go
|
@ -11,41 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func (ts *Tensor) ValueGo() interface{} {
|
||||
dtype := ts.DType()
|
||||
numel := ts.Numel()
|
||||
var dst interface{}
|
||||
switch dtype {
|
||||
case gotch.Uint8:
|
||||
dst = make([]uint8, numel)
|
||||
case gotch.Int8:
|
||||
dst = make([]int8, numel)
|
||||
case gotch.Int16:
|
||||
dst = make([]int16, numel)
|
||||
case gotch.Int:
|
||||
dst = make([]int32, numel)
|
||||
case gotch.Int64:
|
||||
dst = make([]int64, numel)
|
||||
case gotch.Float:
|
||||
dst = make([]float32, numel)
|
||||
case gotch.Double:
|
||||
dst = make([]float64, numel)
|
||||
case gotch.Bool:
|
||||
dst = make([]bool, numel)
|
||||
default:
|
||||
err := fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||
log.Fatal(err)
|
||||
}
|
||||
err := ts.CopyData(dst, ts.Numel())
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// convert []int32 -> int
|
||||
if reflect.TypeOf(dst).String() == "[]int32" {
|
||||
dst = sliceInt32ToInt(dst.([]int32))
|
||||
}
|
||||
|
||||
return dst
|
||||
return ts.Vals()
|
||||
}
|
||||
|
||||
func shapeToSize(shape []int64) int {
|
||||
|
@ -250,8 +216,8 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
|
|||
// var typ T
|
||||
typ := ts.DType()
|
||||
|
||||
switch typ.String() {
|
||||
case "float32", "float64":
|
||||
switch typ {
|
||||
case gotch.Half, gotch.BFloat16, gotch.Float, gotch.Double:
|
||||
switch f.verb {
|
||||
case 'f', 'e', 'E', 'G', 'b':
|
||||
// accepted. Do nothing
|
||||
|
@ -259,7 +225,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
|
|||
f.verb = 'g'
|
||||
}
|
||||
|
||||
case "uint8", "int8", "int16", "int32", "int64":
|
||||
case gotch.Uint8, gotch.Int8, gotch.Int16, gotch.Int, gotch.Int64:
|
||||
switch f.verb {
|
||||
case 'b':
|
||||
f.base = 2
|
||||
|
@ -273,7 +239,7 @@ func (f *fmtState) fmtVerb(ts *Tensor) {
|
|||
f.base = 10
|
||||
f.verb = 'd'
|
||||
}
|
||||
case "bool":
|
||||
case gotch.Bool:
|
||||
f.verb = 't'
|
||||
default:
|
||||
f.verb = 'v'
|
||||
|
@ -318,7 +284,7 @@ func (ts *Tensor) Format(s fmt.State, verb rune) {
|
|||
shape := toSliceInt(ts.MustSize())
|
||||
strides := shapeToStrides(shape)
|
||||
device := ts.MustDevice()
|
||||
dtype := ts.DType().String()
|
||||
dtype := ts.DType()
|
||||
defined := ts.MustDefined()
|
||||
if verb == 'i' {
|
||||
fmt.Fprintf(
|
||||
|
|
84
ts/scalar.go
84
ts/scalar.go
|
@ -1,26 +1,87 @@
|
|||
package ts
|
||||
|
||||
import (
|
||||
// "unsafe"
|
||||
"fmt"
|
||||
"log"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
type Scalar struct {
|
||||
name string
|
||||
cscalar lib.Cscalar
|
||||
}
|
||||
|
||||
// free releases C allocated memory.
|
||||
func freeCScalar(x *Scalar) error {
|
||||
nbytes := x.nbytes()
|
||||
atomic.AddInt64(&AllocatedMem, -nbytes)
|
||||
lock.Lock()
|
||||
delete(ExistingScalars, x.name)
|
||||
lock.Unlock()
|
||||
|
||||
lib.AtsFree(x.cscalar)
|
||||
if err := TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if gotch.Debug {
|
||||
log.Printf("INFO: Released scalar %q - C memory: %d bytes.\n", x.name, nbytes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newScalarName(nameOpt ...string) string {
|
||||
var name string
|
||||
if len(nameOpt) > 0 {
|
||||
name = nameOpt[0]
|
||||
} else {
|
||||
name = fmt.Sprintf("tensor_%09d", TensorCount)
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
func newScalar(cscalar lib.Cscalar, nameOpt ...string) *Scalar {
|
||||
x := &Scalar{
|
||||
cscalar: cscalar,
|
||||
name: newName(nameOpt...),
|
||||
}
|
||||
|
||||
atomic.AddInt64(&ScalarCount, 1)
|
||||
nbytes := x.nbytes()
|
||||
atomic.AddInt64(&AllocatedMem, nbytes)
|
||||
lock.Lock()
|
||||
ExistingScalars[x.name] = struct{}{}
|
||||
lock.Unlock()
|
||||
|
||||
if gotch.Debug {
|
||||
log.Printf("INFO: scalar %q added - Allocated memory (%d bytes).\n", x.name, nbytes)
|
||||
}
|
||||
|
||||
runtime.SetFinalizer(x, freeCScalar)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
func (sc *Scalar) nbytes() int64 {
|
||||
return 4 // either Int64 or Float64 scalar -> 4 bytes
|
||||
}
|
||||
|
||||
// IntScalar creates a integer scalar
|
||||
func IntScalar(v int64) *Scalar {
|
||||
cscalar := lib.AtsInt(v)
|
||||
return &Scalar{cscalar}
|
||||
return newScalar(cscalar)
|
||||
}
|
||||
|
||||
// FloatScalar creates a float scalar
|
||||
func FloatScalar(v float64) *Scalar {
|
||||
cscalar := lib.AtsFloat(v)
|
||||
return &Scalar{cscalar}
|
||||
return newScalar(cscalar)
|
||||
}
|
||||
|
||||
// ToInt returns a integer value
|
||||
|
@ -61,14 +122,17 @@ func (sc *Scalar) ToString() (retVal string, err error) {
|
|||
// TODO: Really? after running s.Drop() and s.ToInt()
|
||||
// it returns Zero.
|
||||
func (sc *Scalar) Drop() (err error) {
|
||||
lib.AtsFree(sc.cscalar)
|
||||
return TorchErr()
|
||||
// TODO. FIXME either remove or rewind for specific scenario
|
||||
return nil
|
||||
// lib.AtsFree(sc.cscalar)
|
||||
// return TorchErr()
|
||||
}
|
||||
|
||||
func (sc *Scalar) MustDrop() {
|
||||
lib.AtsFree(sc.cscalar)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// TODO. FIXME either remove or rewind for specific scenario
|
||||
return
|
||||
// lib.AtsFree(sc.cscalar)
|
||||
// if err := TorchErr(); err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
}
|
||||
|
|
24876
ts/tensor-generated.go
24876
ts/tensor-generated.go
File diff suppressed because it is too large
Load Diff
638
ts/tensor.go
638
ts/tensor.go
|
@ -11,34 +11,194 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
type Tensor struct {
|
||||
ctensor lib.Ctensor
|
||||
}
|
||||
var (
|
||||
TensorCount int64 // incremental counting created tensors
|
||||
ScalarCount int64 // incremental counting created scalars
|
||||
AllocatedMem int64 // bytes - keeping track of memory created and still occupied by gotch/tensor (excluding mem allocated by libtorch at C side)
|
||||
|
||||
func (ts *Tensor) Ctensor() unsafe.Pointer {
|
||||
return unsafe.Pointer(ts.ctensor)
|
||||
}
|
||||
ExistingTensors map[string]struct{} = make(map[string]struct{}) // keep track of existing tensors by name
|
||||
ExistingScalars map[string]struct{} = make(map[string]struct{}) // keep track of existing scalar by name
|
||||
lock sync.Mutex
|
||||
)
|
||||
|
||||
// None is an undefined tensor.
|
||||
// NOTE. None is an undefined tensor.
|
||||
// It can be used in optional tensor parameter where 'None' value used.
|
||||
// `ts.MustDefined()` function is used for checking 'null'
|
||||
var None = NewTensor()
|
||||
|
||||
type bigStruct struct {
|
||||
lots [1e5]byte // 100k - always on host memory.
|
||||
}
|
||||
|
||||
// Tensor is a Go wrapper of a "C tensor pointer" - 8 Bytes (64-bits OS)
|
||||
// or 4 Bytes (32-bits OS).
|
||||
// `ctensor` is just a "C pointer" to `torch::Tensor` (torch::Tensor *lib.Ctensor)
|
||||
//
|
||||
// NOTE.Tensor should be big enough to be in heap memory.
|
||||
// (yes, we choose to place tensor consistently in heap memory so that
|
||||
// it can be targeted by Go garbage collector).
|
||||
//
|
||||
// For heap allocation see. https://stackoverflow.com/questions/10866195
|
||||
type Tensor struct {
|
||||
d *bigStruct
|
||||
name string
|
||||
ctensor lib.Ctensor
|
||||
calledFrom string
|
||||
}
|
||||
|
||||
func newTensor(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
|
||||
if len(nameOpt) == 0 {
|
||||
nameOpt = []string{}
|
||||
}
|
||||
name := newName(nameOpt...)
|
||||
|
||||
x := new(Tensor)
|
||||
x.ctensor = ctensor
|
||||
x.d = new(bigStruct)
|
||||
|
||||
atomic.AddInt64(&TensorCount, 1)
|
||||
nbytes := x.nbytes()
|
||||
atomic.AddInt64(&AllocatedMem, nbytes)
|
||||
lock.Lock()
|
||||
if _, ok := ExistingTensors[name]; ok {
|
||||
name = fmt.Sprintf("%s_%09d", name, TensorCount)
|
||||
}
|
||||
ExistingTensors[name] = struct{}{}
|
||||
lock.Unlock()
|
||||
|
||||
x.name = name
|
||||
|
||||
if gotch.Debug {
|
||||
log.Printf("INFO: Added tensor %q - Allocated memory: %d bytes.\n", x.name, nbytes)
|
||||
}
|
||||
|
||||
x.calledFrom = "newTensor()"
|
||||
|
||||
runtime.SetFinalizer(x, freeCTensor)
|
||||
|
||||
return x
|
||||
}
|
||||
|
||||
// New creates new tensor from C tensor.
|
||||
func New(ctensor lib.Ctensor, nameOpt ...string) *Tensor {
|
||||
return newTensor(ctensor, nameOpt...)
|
||||
}
|
||||
|
||||
func CheckCMemLeak() string {
|
||||
tensors := []string{}
|
||||
lock.Lock()
|
||||
for n := range ExistingTensors {
|
||||
tensors = append(tensors, n)
|
||||
}
|
||||
memUsed := AllocatedMem
|
||||
lock.Unlock()
|
||||
|
||||
var msg string
|
||||
msg += fmt.Sprintf("============================= C MEMORY CHECK RESULT ==================================\n")
|
||||
msg += fmt.Sprintf("C memory allocated not been released: %v bytes\n", memUsed)
|
||||
msg += fmt.Sprintf("Tensors not been released: %q\n", tensors)
|
||||
msg += fmt.Sprintf("======================================================================================\n")
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// CleanUp calls double runtime.GC() with sleep time in between.
|
||||
func CleanUp(sleepTimeOpt ...int) {
|
||||
sleepTime := time.Duration(1000) // 1 second
|
||||
if len(sleepTimeOpt) > 0 {
|
||||
sleepTime = time.Duration(sleepTimeOpt[0])
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
time.Sleep(time.Millisecond * sleepTime)
|
||||
runtime.GC()
|
||||
}
|
||||
|
||||
// Ctensor return C pointer value.
|
||||
func (ts *Tensor) Ctensor() unsafe.Pointer {
|
||||
return unsafe.Pointer(ts.ctensor)
|
||||
}
|
||||
|
||||
// free releases C allocated memory.
|
||||
func freeCTensor(ts *Tensor) error {
|
||||
if ts == nil || ts.ctensor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
if _, ok := ExistingTensors[ts.name]; !ok {
|
||||
log.Printf("WARNING: Probably double free tensor %q. Called from %q. Just skipping...\n", ts.name, ts.calledFrom)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
if gotch.Debug {
|
||||
shape, err := ts.Size()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ERROR: failed to release tensor %q: %w\n", ts.name, err)
|
||||
}
|
||||
log.Printf(err.Error())
|
||||
|
||||
numel := uint(FlattenDim(shape))
|
||||
dtype := ts.DType()
|
||||
nbytes := int64(numel * dtype.Size())
|
||||
atomic.AddInt64(&AllocatedMem, -nbytes)
|
||||
|
||||
log.Printf("INFO: Released tensor %q - C memory(%d bytes).\n", ts.name, nbytes)
|
||||
}
|
||||
|
||||
lib.AtFree(ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
err := fmt.Errorf("ERROR: failed to release tensor %q - %w", ts.name, err)
|
||||
return err
|
||||
}
|
||||
|
||||
delete(ExistingTensors, ts.name)
|
||||
|
||||
// IMPORTANT. make it nil so won't double free.
|
||||
ts.ctensor = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newName(nameOpt ...string) string {
|
||||
var name string
|
||||
if len(nameOpt) > 0 {
|
||||
name = nameOpt[0]
|
||||
} else {
|
||||
name = fmt.Sprintf("tensor_%09d", TensorCount)
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
// NewTensor creates a new tensor
|
||||
func NewTensor() *Tensor {
|
||||
func NewTensor(nameOpt ...string) *Tensor {
|
||||
ctensor := lib.AtNewTensor()
|
||||
return &Tensor{ctensor}
|
||||
|
||||
return newTensor(ctensor, nameOpt...)
|
||||
}
|
||||
|
||||
func FromCtensor(ctensor unsafe.Pointer) *Tensor {
|
||||
cts := (lib.Ctensor)(ctensor)
|
||||
return &Tensor{cts}
|
||||
|
||||
return newTensor(cts)
|
||||
}
|
||||
|
||||
func (ts *Tensor) Name() string {
|
||||
return ts.name
|
||||
}
|
||||
|
||||
func (ts *Tensor) Dim() uint64 {
|
||||
|
@ -56,6 +216,11 @@ func (ts *Tensor) Dim() uint64 {
|
|||
// to that slice.
|
||||
func (ts *Tensor) Size() ([]int64, error) {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
if dim < 0 || dim > 100 {
|
||||
err := fmt.Errorf("Invalid dim: %v\n", dim)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sz := make([]int64, dim)
|
||||
szPtr, err := DataAsPtr(sz)
|
||||
if err != nil {
|
||||
|
@ -82,6 +247,34 @@ func (ts *Tensor) MustSize() []int64 {
|
|||
return shape
|
||||
}
|
||||
|
||||
func (ts *Tensor) Stride() ([]int64, error) {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
sz := make([]int64, dim)
|
||||
szPtr, err := DataAsPtr(sz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer C.free(unsafe.Pointer(szPtr))
|
||||
|
||||
lib.AtStride(ts.ctensor, szPtr)
|
||||
if err = TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
strides := decodeSize(szPtr, dim)
|
||||
|
||||
return strides, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustStride() []int64 {
|
||||
strides, err := ts.Stride()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return strides
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for 1D tensors.
|
||||
func (ts *Tensor) Size1() (int64, error) {
|
||||
shape, err := ts.Size()
|
||||
|
@ -142,16 +335,19 @@ func (ts *Tensor) Size4() ([]int64, error) {
|
|||
return shape, nil
|
||||
}
|
||||
|
||||
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
||||
// Decode sz
|
||||
// 1. Count number of elements in data
|
||||
elementNum := nsize
|
||||
// 2. Element size in bytes
|
||||
eltSizeInBytes, err := gotch.DTypeSize(gotch.Int64)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
// nbytes calculates tensor data size in bytes.
|
||||
func (ts *Tensor) nbytes() int64 {
|
||||
numel := ts.Numel()
|
||||
if numel == 0 {
|
||||
return 0 // ts.None
|
||||
}
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
|
||||
return int64(numel * ts.DType().Size())
|
||||
}
|
||||
|
||||
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
||||
dtype := gotch.Int64 // tensor size dtype = int64
|
||||
nbytes := int(nsize) * int(dtype.Size())
|
||||
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
|
||||
r := bytes.NewReader(dataSlice)
|
||||
dataIn := make([]int64, nsize)
|
||||
|
@ -162,20 +358,54 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
|||
return dataIn
|
||||
}
|
||||
|
||||
// TensorOptions constructs options to build/rebuild tensor.
|
||||
type TensorOptions struct {
|
||||
Name string
|
||||
DType gotch.DType
|
||||
Quantized bool
|
||||
// TODO. can expand as needed
|
||||
}
|
||||
|
||||
type TensorOpt func(*TensorOptions)
|
||||
|
||||
func DefaultTensorOptions() *TensorOptions {
|
||||
return &TensorOptions{
|
||||
Name: "",
|
||||
DType: gotch.Float,
|
||||
Quantized: false,
|
||||
}
|
||||
}
|
||||
|
||||
func WithName(v string) TensorOpt {
|
||||
return func(o *TensorOptions) {
|
||||
o.Name = v
|
||||
}
|
||||
}
|
||||
|
||||
func WithDType(v gotch.DType) TensorOpt {
|
||||
return func(o *TensorOptions) {
|
||||
o.DType = v
|
||||
}
|
||||
}
|
||||
|
||||
func WithQuantized(v bool) TensorOpt {
|
||||
return func(o *TensorOptions) {
|
||||
o.Quantized = v
|
||||
}
|
||||
}
|
||||
|
||||
// OfSlice creates tensor from a slice data
|
||||
func OfSlice(data interface{}) (*Tensor, error) {
|
||||
func OfSlice(data interface{}, opts ...TensorOpt) (*Tensor, error) {
|
||||
o := DefaultTensorOptions()
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
// convert []int -> int32. `binary.Write()` can't write `[]int` because it's not fixed-size!
|
||||
if reflect.TypeOf(data).String() == "[]int" {
|
||||
data = sliceIntToInt32(data.([]int))
|
||||
}
|
||||
|
||||
/*
|
||||
typ, dataLen, err := DataCheck(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
*/
|
||||
|
||||
v := reflect.ValueOf(data)
|
||||
kind := v.Kind().String()
|
||||
if kind != "slice" && kind != "array" {
|
||||
|
@ -183,10 +413,10 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
typ := reflect.TypeOf(data).Elem()
|
||||
elementKind := reflect.TypeOf(data).Elem().Kind()
|
||||
dataLen := v.Len()
|
||||
|
||||
dtype, err := gotch.ToDType(typ)
|
||||
dtype, err := gotch.GoKind2DType(elementKind, gotch.HalfDTypePref(o.DType), gotch.WithQuantized(o.Quantized))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -194,12 +424,7 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
|||
shape := []int64{int64(dataLen)}
|
||||
elementNum := ElementCount(shape)
|
||||
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
nbytes := int(dtype.Size()) * elementNum
|
||||
|
||||
dataPtr, buff := CMalloc(nbytes)
|
||||
defer C.free(unsafe.Pointer(dataPtr))
|
||||
|
@ -208,29 +433,25 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(dtype.Size()), int(dtype.CKind()))
|
||||
if err = TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor, o.Name), nil
|
||||
// return newTensor(ctensor), nil
|
||||
}
|
||||
|
||||
// OfDataSize creates Tensor from input byte data, shape and dtype.
|
||||
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
|
||||
|
||||
elementNum := ElementCount(shape)
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func OfDataSize(data []byte, shape []int64, dtype gotch.DType, opts ...TensorOpt) (*Tensor, error) {
|
||||
o := DefaultTensorOptions()
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
elementNum := ElementCount(shape)
|
||||
|
||||
nbytes := elementNum * int(dtype.Size())
|
||||
|
||||
if nbytes != len(data) {
|
||||
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
|
||||
|
@ -244,23 +465,19 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
if err = TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor, o.Name), nil
|
||||
// return newTensor(ctensor), nil
|
||||
}
|
||||
|
||||
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
||||
// or panic if error
|
||||
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
|
||||
ts, err := OfDataSize(data, size, dtype)
|
||||
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType, opts ...TensorOpt) *Tensor {
|
||||
ts, err := OfDataSize(data, size, dtype, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -269,8 +486,8 @@ func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
|
|||
}
|
||||
|
||||
// MustOfSlice create a tensor from slice of data. It will be panic if error.
|
||||
func MustOfSlice(data interface{}) *Tensor {
|
||||
ts, err := OfSlice(data)
|
||||
func MustOfSlice(data interface{}, opts ...TensorOpt) *Tensor {
|
||||
ts, err := OfSlice(data, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -279,8 +496,8 @@ func MustOfSlice(data interface{}) *Tensor {
|
|||
}
|
||||
|
||||
// TensorFrom create a tensor from slice of data. It will be panic if error.
|
||||
func TensorFrom(data interface{}) *Tensor {
|
||||
ts, err := OfSlice(data)
|
||||
func TensorFrom(data interface{}, opts ...TensorOpt) *Tensor {
|
||||
ts, err := OfSlice(data, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -299,7 +516,12 @@ func (ts *Tensor) Print() {
|
|||
}
|
||||
|
||||
// NewTensorFromData creates tensor from given data and shape
|
||||
func NewTensorFromData(data interface{}, shape []int64) (*Tensor, error) {
|
||||
func NewTensorFromData(data interface{}, shape []int64, opts ...TensorOpt) (*Tensor, error) {
|
||||
o := DefaultTensorOptions()
|
||||
for _, opt := range opts {
|
||||
opt(o)
|
||||
}
|
||||
|
||||
// 1. Check whether data and shape match
|
||||
elementNum, err := DataDim(data)
|
||||
if err != nil {
|
||||
|
@ -326,34 +548,18 @@ func NewTensorFromData(data interface{}, shape []int64) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
// defer C.free(unsafe.Pointer(ctensor))
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), dtype.Size(), int(dtype.CKind()))
|
||||
if err = TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor, o.Name), nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) DType() gotch.DType {
|
||||
cint := lib.AtScalarType(ts.ctensor)
|
||||
|
||||
dtype, err := gotch.CInt2DType(cint)
|
||||
if err != nil {
|
||||
log.Fatalf("Tensor DType error: %v\n", err)
|
||||
}
|
||||
|
||||
return dtype
|
||||
return gotch.CKind2DType(cint)
|
||||
}
|
||||
|
||||
func (ts *Tensor) Device() (gotch.Device, error) {
|
||||
|
@ -409,6 +615,7 @@ func (ts *Tensor) MustDevice() gotch.Device {
|
|||
* return retVal
|
||||
* }
|
||||
* */
|
||||
|
||||
// Float64Value returns a float value on tensors holding a single element.
|
||||
// An error is returned otherwise.
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
|
@ -498,6 +705,15 @@ func (ts *Tensor) DataPtr() (unsafe.Pointer, error) {
|
|||
return datPtr, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustDataPtr() unsafe.Pointer {
|
||||
p, err := ts.DataPtr()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
// Defined returns true is the tensor is defined.
|
||||
func (ts *Tensor) Defined() (bool, error) {
|
||||
state := lib.AtDefined(ts.ctensor)
|
||||
|
@ -528,6 +744,52 @@ func (ts *Tensor) IsSparse() (bool, error) {
|
|||
|
||||
return state, nil
|
||||
}
|
||||
func (ts *Tensor) MustIsSparse() bool {
|
||||
state, err := ts.IsSparse()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// IsContiguous returns true is the tensor is contiguous.
|
||||
func (ts *Tensor) IsContiguous() (bool, error) {
|
||||
state := lib.AtIsContiguous(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
func (ts *Tensor) MustIsContiguous() bool {
|
||||
state, err := ts.IsContiguous()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// IsMkldnn returns true is the tensor is mkldnn.
|
||||
func (ts *Tensor) IsMkldnn() (bool, error) {
|
||||
state := lib.AtIsMkldnn(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
func (ts *Tensor) MustIsMkldnn() bool {
|
||||
state, err := ts.IsMkldnn()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return state
|
||||
}
|
||||
|
||||
// ZeroGrad zeroes the gradient tensor attached to this tensor if defined.
|
||||
func (ts *Tensor) ZeroGrad() {
|
||||
|
@ -558,7 +820,7 @@ func (ts *Tensor) MustBackward() {
|
|||
}
|
||||
|
||||
// RunBackward runs the backward ...
|
||||
func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraphB bool) ([]Tensor, error) {
|
||||
func RunBackward(tensors []*Tensor, inputs []*Tensor, keepGraphB bool, createGraphB bool) ([]*Tensor, error) {
|
||||
// NOTE: outputs is a slice of tensors with length = len(inputs)
|
||||
var outputsPtr []*lib.Ctensor
|
||||
// Are they allocated contigously??? Definitely not.
|
||||
|
@ -599,10 +861,10 @@ func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraph
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var oTensors []Tensor
|
||||
var oTensors []*Tensor
|
||||
for i := 0; i < len(inputs); i++ {
|
||||
outputPtr := outputsPtr[i]
|
||||
oTensors = append(oTensors, Tensor{ctensor: *outputPtr})
|
||||
oTensors = append(oTensors, newTensor(*outputPtr))
|
||||
}
|
||||
|
||||
return oTensors, nil
|
||||
|
@ -612,7 +874,6 @@ func RunBackward(tensors []Tensor, inputs []Tensor, keepGraphB bool, createGraph
|
|||
//
|
||||
// NOTE: `dst` located in Go memory. Should it be?
|
||||
func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
|
||||
|
||||
// NOTE: we must make sure that `dst` has same len as `numel`. Otherwise,
|
||||
// there will be memory leak and or out of range error.
|
||||
if len(dst) < int(numel) {
|
||||
|
@ -621,12 +882,9 @@ func (ts *Tensor) CopyDataUint8(dst []uint8, numel uint) error {
|
|||
}
|
||||
|
||||
vs := unsafe.Pointer(&dst[0])
|
||||
elt_size_in_bytes, err := gotch.DTypeSize(gotch.Uint8)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
|
||||
if err = TorchErr(); err != nil {
|
||||
dtype := gotch.Uint8
|
||||
lib.AtCopyData(ts.ctensor, vs, numel, dtype.Size())
|
||||
if err := TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -648,55 +906,20 @@ func (ts *Tensor) MustCopyDataUint8(dst []uint8, numel uint) {
|
|||
// and number of elements to C land. This may break in the future
|
||||
// if Go policy changes.
|
||||
func (ts *Tensor) CopyData(dst interface{}, numel uint) error {
|
||||
|
||||
gotype, dlen, err := DataCheck(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dtype, err := gotch.ToDType(gotype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dtype, dlen, err := DataCheck(dst)
|
||||
if dlen < int(numel) {
|
||||
err = fmt.Errorf("CopyData Error: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
|
||||
err = fmt.Errorf("ts.CopyData() failed: length of destination slice data (%v) is smaller than \nnumber of elements to be copied (%v)", dlen, numel)
|
||||
return err
|
||||
}
|
||||
|
||||
if ts.DType() != dtype {
|
||||
err = fmt.Errorf("Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||
err = fmt.Errorf("ts.CopyData() failed: Type mismatched: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||
return err
|
||||
}
|
||||
|
||||
var vs unsafe.Pointer
|
||||
switch dtype {
|
||||
case gotch.Uint8:
|
||||
vs = unsafe.Pointer(&dst.([]uint8)[0])
|
||||
case gotch.Int8:
|
||||
vs = unsafe.Pointer(&dst.([]int8)[0])
|
||||
case gotch.Int16:
|
||||
vs = unsafe.Pointer(&dst.([]int16)[0])
|
||||
case gotch.Int:
|
||||
vs = unsafe.Pointer(&dst.([]int32)[0])
|
||||
case gotch.Int64:
|
||||
vs = unsafe.Pointer(&dst.([]int64)[0])
|
||||
case gotch.Float:
|
||||
vs = unsafe.Pointer(&dst.([]float32)[0])
|
||||
case gotch.Double:
|
||||
vs = unsafe.Pointer(&dst.([]float64)[0])
|
||||
case gotch.Bool:
|
||||
vs = unsafe.Pointer(&dst.([]bool)[0])
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported type: `dst` type: %v, tensor DType: %v", dtype, ts.DType())
|
||||
return err
|
||||
}
|
||||
// Get data pointer
|
||||
dataPtr := reflect.ValueOf(dst).UnsafePointer()
|
||||
|
||||
elt_size_in_bytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lib.AtCopyData(ts.ctensor, vs, numel, elt_size_in_bytes)
|
||||
lib.AtCopyData(ts.ctensor, dataPtr, numel, dtype.Size())
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -717,20 +940,24 @@ func (ts *Tensor) MustCopyData(dst interface{}, numel uint) {
|
|||
|
||||
// Numel returns the total number of elements stored in a tensor.
|
||||
func (ts *Tensor) Numel() uint {
|
||||
if !ts.MustDefined() {
|
||||
return 0 // ts.None case
|
||||
}
|
||||
|
||||
shape := ts.MustSize()
|
||||
return uint(FlattenDim(shape))
|
||||
}
|
||||
|
||||
// ShallowClone returns a new tensor that share storage with the input tensor.
|
||||
func (ts *Tensor) ShallowClone() (*Tensor, error) {
|
||||
|
||||
ctensor := lib.AtShallowClone(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
name := fmt.Sprintf("%s_cloned", ts.name)
|
||||
return newTensor(ctensor, name), nil
|
||||
}
|
||||
|
||||
// MustShallowClone returns a new tensor that share storage with the input
|
||||
|
@ -752,7 +979,7 @@ func (ts *Tensor) Get(index int) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor), nil
|
||||
}
|
||||
|
||||
// MustGet gets the sub-tensor at the given index. It will panic if error
|
||||
|
@ -804,19 +1031,19 @@ func (ts *Tensor) MustSave(path string) {
|
|||
}
|
||||
|
||||
// Load loads a tensor from a file.
|
||||
func Load(path string) (*Tensor, error) {
|
||||
func Load(path string, nameOpt ...string) (*Tensor, error) {
|
||||
|
||||
ctensor := lib.AtLoad(path)
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
return newTensor(ctensor, nameOpt...), nil
|
||||
}
|
||||
|
||||
// MustLoad loads a tensor to a file. It will panic if error
|
||||
func MustLoad(path string) *Tensor {
|
||||
ts, err := Load(path)
|
||||
func MustLoad(path string, nameOpt ...string) *Tensor {
|
||||
ts, err := Load(path, nameOpt...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -877,7 +1104,7 @@ func LoadMulti(path string) ([]NamedTensor, error) {
|
|||
for _, v := range data.NamedCtensors {
|
||||
namedTensor := NamedTensor{
|
||||
Name: v.Name,
|
||||
Tensor: &Tensor{v.Ctensor},
|
||||
Tensor: newTensor(v.Ctensor, v.Name),
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, namedTensor)
|
||||
|
@ -912,7 +1139,7 @@ func LoadMultiWithDevice(path string, device gotch.Device) ([]NamedTensor, error
|
|||
for _, v := range data.NamedCtensors {
|
||||
namedTensor := NamedTensor{
|
||||
Name: v.Name,
|
||||
Tensor: &Tensor{v.Ctensor},
|
||||
Tensor: newTensor(v.Ctensor, v.Name),
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, namedTensor)
|
||||
|
@ -959,18 +1186,22 @@ func (ts *Tensor) MustToString(lw int64) string {
|
|||
|
||||
// Drop drops (frees) the tensor
|
||||
func (ts *Tensor) Drop() error {
|
||||
lib.AtFree(ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
return err
|
||||
if ts.ctensor == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
// Clear SetFinalizer on ts so no double free tensor.
|
||||
// Ref. https://pkg.go.dev/runtime#SetFinalizer
|
||||
runtime.SetFinalizer(ts, nil)
|
||||
|
||||
ts.calledFrom = "ts.Drop()"
|
||||
return freeCTensor(ts)
|
||||
}
|
||||
|
||||
// MustDrop drops the tensor. It will be panic if error
|
||||
func (ts *Tensor) MustDrop() {
|
||||
if err := ts.Drop(); err != nil {
|
||||
log.Fatal(err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1025,31 +1256,14 @@ func MustGradSetEnabled(b bool) bool {
|
|||
}
|
||||
|
||||
// NoGrad runs a closure without keeping track of gradients.
|
||||
func NoGrad(fn interface{}) {
|
||||
|
||||
// TODO: This is weird but somehow we need to trigger C++ print
|
||||
// to get loss function updated. Probably it is related to
|
||||
// C++ cache clearing.
|
||||
// Next step would be creating a Go func that trigger C++ cache clean
|
||||
// instead of this ugly hacky way.
|
||||
newTs := NewTensor()
|
||||
newTs.Drop()
|
||||
|
||||
func NoGrad(fn func()) {
|
||||
// Switch off Grad
|
||||
prev := MustGradSetEnabled(false)
|
||||
MustGradSetEnabled(false)
|
||||
|
||||
// Analyze input as function. If not, throw error
|
||||
f, err := NewFunc(fn)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// invokes the function
|
||||
f.Invoke()
|
||||
fn()
|
||||
|
||||
// Switch on Grad
|
||||
_ = MustGradSetEnabled(prev)
|
||||
|
||||
MustGradSetEnabled(true)
|
||||
}
|
||||
|
||||
func NoGrad1(fn func() interface{}) interface{} {
|
||||
|
@ -1140,7 +1354,7 @@ func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
|
|||
float64Ts := ts.MustTotype(gotch.Double, false)
|
||||
|
||||
float64Ts.MustCopyData(vec, numel)
|
||||
float64Ts.MustDrop()
|
||||
// float64Ts.MustDrop()
|
||||
|
||||
if del {
|
||||
ts.MustDrop()
|
||||
|
@ -1175,33 +1389,18 @@ func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
|
|||
// E.g. res := xs.Vals().([]int64)
|
||||
func (ts *Tensor) Vals() interface{} {
|
||||
dtype := ts.DType()
|
||||
numel := ts.Numel()
|
||||
numel := int(ts.Numel())
|
||||
|
||||
var retVal interface{}
|
||||
|
||||
switch dtype.Name() {
|
||||
case "uint8":
|
||||
retVal = make([]uint8, numel)
|
||||
case "int8":
|
||||
retVal = make([]int8, numel)
|
||||
case "int16":
|
||||
retVal = make([]int16, numel)
|
||||
case "int32":
|
||||
retVal = make([]int32, numel)
|
||||
case "int64":
|
||||
retVal = make([]int64, numel)
|
||||
case "float32":
|
||||
retVal = make([]float32, numel)
|
||||
case "float64":
|
||||
retVal = make([]float64, numel)
|
||||
case "bool":
|
||||
retVal = make([]bool, numel)
|
||||
default:
|
||||
log.Fatalf("Unsupported dtype (%v)", dtype)
|
||||
typ, err := dtype.GoType()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ts.CopyData(retVal, numel)
|
||||
return retVal
|
||||
dataSlice := reflect.MakeSlice(reflect.SliceOf(typ), numel, numel).Interface()
|
||||
|
||||
ts.CopyData(dataSlice, uint(numel))
|
||||
|
||||
return dataSlice
|
||||
}
|
||||
|
||||
// FlatView flattens a tensor.
|
||||
|
@ -1292,7 +1491,8 @@ func (ts *Tensor) ConstantPadNdWithVal(pad []int64, value *Scalar, del bool) (re
|
|||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
retVal = newTensor(*ptr)
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -1306,3 +1506,51 @@ func (ts *Tensor) MustConstantPadNdWithVal(pad []int64, value *Scalar, del bool)
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// TT. Added some torch.cuda APIs for handling CUDA qt
|
||||
|
||||
// CudaCurrentDevice get device index of current CUDA device.
|
||||
func CudaCurrentDevice() (int, error) {
|
||||
currentDeviceIndex := lib.AtcGetDevice()
|
||||
if err := TorchErr(); err != nil {
|
||||
err = fmt.Errorf("ts.CudaCurrentDevice() failed: %w\n", err)
|
||||
return -99, err
|
||||
}
|
||||
|
||||
return currentDeviceIndex, nil
|
||||
}
|
||||
|
||||
// CudaSetDevice set new cuda device index and returns previous cuda index.
|
||||
func CudaSetDevice(cudaDeviceIndex int) (int, error) {
|
||||
currentDeviceIndex, err := CudaCurrentDevice()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
|
||||
return -99, err
|
||||
}
|
||||
|
||||
lib.AtcSetDevice(cudaDeviceIndex)
|
||||
if err := TorchErr(); err != nil {
|
||||
err = fmt.Errorf("ts.CudaSetDevice() failed: %w\n", err)
|
||||
return -99, err
|
||||
}
|
||||
return currentDeviceIndex, nil
|
||||
}
|
||||
|
||||
// CudaSynchronize waits for all kernels in all streams on a CUDA device to complete.
|
||||
func CudaSynchronize(cudaDeviceIndexOpt ...int) error {
|
||||
var cudaDeviceIndex int
|
||||
var err error
|
||||
if len(cudaDeviceIndexOpt) > 0 {
|
||||
cudaDeviceIndex = cudaDeviceIndexOpt[0]
|
||||
} else {
|
||||
cudaDeviceIndex, err = CudaCurrentDevice()
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ts.CudaSynchronize() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
lib.AtcSynchronize(int64(cudaDeviceIndex))
|
||||
|
||||
return TorchErr()
|
||||
}
|
||||
|
|
83
ts/tensor_mem_test.go
Normal file
83
ts/tensor_mem_test.go
Normal file
|
@ -0,0 +1,83 @@
|
|||
package ts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
var n int = 10
|
||||
|
||||
func newData() []float32 {
|
||||
n := 3 * 224 * 224 * 12
|
||||
data := make([]float32, n)
|
||||
for i := 0; i < n; i++ {
|
||||
data[i] = rand.Float32()
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
func printMemStats(message string, rtm runtime.MemStats) {
|
||||
fmt.Println("\n===", message, "===")
|
||||
fmt.Println("Mallocs: ", rtm.Mallocs)
|
||||
fmt.Println("Frees: ", rtm.Frees)
|
||||
fmt.Println("LiveObjects: ", rtm.Mallocs-rtm.Frees)
|
||||
fmt.Println("PauseTotalNs: ", rtm.PauseTotalNs)
|
||||
fmt.Println("NumGC: ", rtm.NumGC)
|
||||
fmt.Println("LastGC: ", time.UnixMilli(int64(rtm.LastGC/1_000_000)))
|
||||
fmt.Println("HeapObjects: ", rtm.HeapObjects)
|
||||
fmt.Println("HeapAlloc: ", rtm.HeapAlloc)
|
||||
}
|
||||
|
||||
func TestMem(t *testing.T) {
|
||||
var rtm runtime.MemStats
|
||||
runtime.ReadMemStats(&rtm)
|
||||
printMemStats("Start", rtm)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
x := MustOfSlice(newData())
|
||||
log.Printf("created tensor : %q\n", x.Name())
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&rtm)
|
||||
printMemStats("After completing loop", rtm)
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&rtm)
|
||||
printMemStats("After forced GC", rtm)
|
||||
|
||||
fmt.Printf(CheckCMemLeak())
|
||||
}
|
||||
|
||||
func TestMem1(t *testing.T) {
|
||||
var rtm runtime.MemStats
|
||||
runtime.ReadMemStats(&rtm)
|
||||
printMemStats("Start", rtm)
|
||||
|
||||
for i := 0; i < n; i++ {
|
||||
x, err := Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
log.Printf("created tensor : %q\n", x.Name())
|
||||
time.Sleep(time.Millisecond * 3000) // 3secs
|
||||
}
|
||||
|
||||
CleanUp()
|
||||
|
||||
runtime.ReadMemStats(&rtm)
|
||||
printMemStats("After completing loop", rtm)
|
||||
|
||||
runtime.GC()
|
||||
runtime.ReadMemStats(&rtm)
|
||||
printMemStats("After forced GC", rtm)
|
||||
|
||||
fmt.Printf(CheckCMemLeak())
|
||||
}
|
|
@ -131,3 +131,61 @@ func TestOfSlice(t *testing.T) {
|
|||
t.Errorf("Got dtype: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCudaCurrentDevice(t *testing.T) {
|
||||
cudaIdx, err := ts.CudaCurrentDevice()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
t.Logf("current CUDA index: %v\n", cudaIdx) // should be 0 if having 1 GPU device
|
||||
|
||||
x := ts.MustZeros([]int64{2, 3, 4}, gotch.Float, gotch.CudaIfAvailable())
|
||||
currentCudaIndex := x.MustDevice().Value
|
||||
t.Logf("x current cuda index: %v\n", currentCudaIndex) // 0
|
||||
|
||||
previousCudaIndex, err := ts.CudaSetDevice(currentCudaIndex)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Logf("Cuda index BEFORE set: %v\n", previousCudaIndex) // 0
|
||||
|
||||
cudaIdxAfter, err := ts.CudaCurrentDevice()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
|
||||
}
|
||||
|
||||
func TestTensor_Stride(t *testing.T) {
|
||||
shape := []int64{2, 3, 4}
|
||||
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||
|
||||
got := x.MustStride()
|
||||
want := []int64{12, 4, 1}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensor_IsContiguous(t *testing.T) {
|
||||
shape := []int64{2, 3, 4}
|
||||
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||
|
||||
got := x.MustIsContiguous()
|
||||
want := true
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTensor_IsMkldnn(t *testing.T) {
|
||||
shape := []int64{2, 3, 4}
|
||||
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||
|
||||
got := x.MustIsMkldnn()
|
||||
want := false
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
|
64
ts/util.go
64
ts/util.go
|
@ -75,7 +75,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
if err := w.WriteByte(b); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -86,14 +86,14 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
if v.Kind() == reflect.Slice {
|
||||
expected := int(shape[0])
|
||||
if v.Len() != expected {
|
||||
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
|
||||
return fmt.Errorf("EncodeTensor() failed: mismatched slice lengths: %d and %d", v.Len(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
||||
if len(shape) == 1 && v.Len() > 0 {
|
||||
switch v.Index(0).Kind() {
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
return binary.Write(w, nativeEndian, v.Interface())
|
||||
}
|
||||
}
|
||||
|
@ -107,7 +107,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported type %v", v.Type())
|
||||
return fmt.Errorf("EncodeTensor() failed: unsupported type %v", v.Type())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -122,7 +122,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
return err
|
||||
}
|
||||
ptr.Elem().SetBool(b == 1)
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
||||
if len(shape) == 1 && val.Len() > 0 {
|
||||
switch val.Index(0).Kind() {
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Uint16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
|
||||
return binary.Read(r, nativeEndian, val.Interface())
|
||||
}
|
||||
}
|
||||
|
@ -152,10 +152,10 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
|
|||
}
|
||||
|
||||
// ElementCount counts number of element in the tensor given a shape
|
||||
func ElementCount(shape []int64) int64 {
|
||||
n := int64(1)
|
||||
func ElementCount(shape []int64) int {
|
||||
n := 1
|
||||
for _, d := range shape {
|
||||
n *= d
|
||||
n *= int(d)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
@ -163,54 +163,48 @@ func ElementCount(shape []int64) int64 {
|
|||
// DataDim returns number of elements in data
|
||||
// NOTE: only support scalar and (nested) slice/array of scalar type
|
||||
func DataDim(data interface{}) (retVal int, err error) {
|
||||
|
||||
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
// DataCheck checks the input data for element Go type and number of elements.
|
||||
// It will return errors if element type is not supported.
|
||||
func DataCheck(data interface{}) (k reflect.Type, n int, err error) {
|
||||
|
||||
// It will return errors if element dtype is not supported.
|
||||
func DataCheck(data interface{}) (dtype gotch.DType, n int, err error) {
|
||||
return dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||
}
|
||||
|
||||
// NOTE: 0 is reflect.Kind() of Invalid
|
||||
// See: https://golang.org/pkg/reflect/#Kind
|
||||
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||
func dataCheck(data interface{}, count int) (dtype gotch.DType, n int, err error) {
|
||||
v := reflect.ValueOf(data)
|
||||
var goType reflect.Type = reflect.TypeOf(data)
|
||||
var total int = count
|
||||
var round = 0
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
|
||||
if round == 0 {
|
||||
round = v.Len()
|
||||
}
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
round--
|
||||
goType, total, err = dataCheck(v.Index(i).Interface(), total)
|
||||
dtype, total, err = dataCheck(v.Index(i).Interface(), total)
|
||||
|
||||
if err != nil {
|
||||
return reflect.TypeOf(reflect.Zero), 0, err
|
||||
return gotch.Invalid, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return goType, total, nil
|
||||
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
total++
|
||||
if goType.String() != "invalid" {
|
||||
goType = v.Type()
|
||||
}
|
||||
default:
|
||||
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
|
||||
return reflect.TypeOf(reflect.Zero), 0, err
|
||||
return dtype, total, nil
|
||||
}
|
||||
|
||||
return goType, total, nil
|
||||
total += 1
|
||||
dtype, err = gotch.GoKind2DType(v.Kind())
|
||||
if err != nil {
|
||||
err = fmt.Errorf("DataCheck() failed: unsupported data structure or type: %v\n", v.Kind())
|
||||
return gotch.Invalid, 0, err
|
||||
}
|
||||
|
||||
return dtype, total, nil
|
||||
}
|
||||
|
||||
// DataAsPtr write to C memory and returns a C pointer.
|
||||
|
@ -219,25 +213,19 @@ func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
|||
// Supported data types are scalar, slice/array of scalar type equivalent to
|
||||
// DType.
|
||||
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
|
||||
|
||||
// 1. Count number of elements in data
|
||||
elementNum, err := DataDim(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2. Element size in bytes
|
||||
// 2. Number of bytes
|
||||
dtype, err := gotch.DTypeFromData(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
nbytes := int(dtype.Size()) * int(elementNum)
|
||||
|
||||
// 3. Get C pointer and prepare C memory buffer for writing
|
||||
dataPtr, buff := CMalloc(nbytes)
|
||||
|
|
|
@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
|
|||
}
|
||||
|
||||
rgbTs := x.MustUnbind(-3, false)
|
||||
r := &rgbTs[0]
|
||||
g := &rgbTs[1]
|
||||
b := &rgbTs[2]
|
||||
r := rgbTs[0]
|
||||
g := rgbTs[1]
|
||||
b := rgbTs[2]
|
||||
|
||||
// This implementation closely follows the TF one:
|
||||
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
|
||||
|
@ -311,9 +311,9 @@ func adjustSaturation(x *ts.Tensor, sat float64) *ts.Tensor {
|
|||
|
||||
func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
||||
rgbTs := x.MustUnbind(-3, false)
|
||||
r := &rgbTs[0]
|
||||
g := &rgbTs[1]
|
||||
b := &rgbTs[2]
|
||||
r := rgbTs[0]
|
||||
g := rgbTs[1]
|
||||
b := rgbTs[2]
|
||||
|
||||
// # Implementation is based on https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/
|
||||
// # src/libImaging/Convert.c#L330
|
||||
|
@ -383,7 +383,7 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
|||
h3 := h2.MustFmod(ts.FloatScalar(1.0), true) // delete h2
|
||||
|
||||
// torch.stack((h, s, maxc), dim=-3)
|
||||
out := ts.MustStack([]ts.Tensor{*h3, *s, *maxC}, -3)
|
||||
out := ts.MustStack([]*ts.Tensor{h3, s, maxC}, -3)
|
||||
|
||||
// Delete intermediate tensors
|
||||
r.MustDrop()
|
||||
|
@ -409,9 +409,9 @@ func rgb2HSV(x *ts.Tensor) *ts.Tensor {
|
|||
|
||||
func hsv2RGB(x *ts.Tensor) *ts.Tensor {
|
||||
hsvTs := x.MustUnbind(-3, false)
|
||||
h := &hsvTs[0]
|
||||
s := &hsvTs[1]
|
||||
v := &hsvTs[2]
|
||||
h := hsvTs[0]
|
||||
s := hsvTs[1]
|
||||
v := hsvTs[2]
|
||||
// i = torch.floor(h * 6.0)
|
||||
i := h.MustMulScalar(ts.FloatScalar(6.0), false).MustFloor(true)
|
||||
// f = (h * 6.0) - i
|
||||
|
@ -448,12 +448,12 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
|
|||
// a2 = torch.stack((t, v, v, q, p, p), dim=-3)
|
||||
// a3 = torch.stack((p, p, t, v, v, q), dim=-3)
|
||||
// a4 = torch.stack((a1, a2, a3), dim=-4)
|
||||
a1 := ts.MustStack([]ts.Tensor{*v, *q, *p, *p, *t, *v}, -3)
|
||||
a2 := ts.MustStack([]ts.Tensor{*t, *v, *v, *q, *p, *p}, -3)
|
||||
a3 := ts.MustStack([]ts.Tensor{*p, *p, *t, *v, *v, *q}, -3)
|
||||
a4 := ts.MustStack([]ts.Tensor{*a1, *a2, *a3}, -4)
|
||||
a1 := ts.MustStack([]*ts.Tensor{v, q, p, p, t, v}, -3)
|
||||
a2 := ts.MustStack([]*ts.Tensor{t, v, v, q, p, p}, -3)
|
||||
a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
|
||||
a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
|
||||
|
||||
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []ts.Tensor{*mask, *a4})
|
||||
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1})
|
||||
|
||||
// Delete intermediate tensors
|
||||
h.MustDrop()
|
||||
|
@ -491,13 +491,13 @@ func adjustHue(x *ts.Tensor, hue float64) *ts.Tensor {
|
|||
hsvImg := rgb2HSV(imgFl)
|
||||
|
||||
hsvTs := hsvImg.MustUnbind(-3, true)
|
||||
h := &hsvTs[0]
|
||||
s := &hsvTs[1]
|
||||
v := &hsvTs[2]
|
||||
h := hsvTs[0]
|
||||
s := hsvTs[1]
|
||||
v := hsvTs[2]
|
||||
// h = (h + hue_factor) % 1.0
|
||||
hAdj := h.MustAddScalar(ts.FloatScalar(hue), false).MustRemainder(ts.FloatScalar(1.0), true)
|
||||
|
||||
hsvAdj := ts.MustStack([]ts.Tensor{*hAdj, *s, *v}, -3)
|
||||
hsvAdj := ts.MustStack([]*ts.Tensor{hAdj, s, v}, -3)
|
||||
|
||||
imgHueAdj := hsv2RGB(hsvAdj)
|
||||
|
||||
|
@ -568,7 +568,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
|
|||
dim := x.MustSize()
|
||||
c := dim[0]
|
||||
|
||||
var chans []ts.Tensor = make([]ts.Tensor, c)
|
||||
var chans []*ts.Tensor = make([]*ts.Tensor, c)
|
||||
hNar := ts.NewNarrow(top, top+height)
|
||||
wNar := ts.NewNarrow(left, left+width)
|
||||
for i := 0; i < int(c); i++ {
|
||||
|
@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
|
|||
x2 := x1T.Idx(wNar)
|
||||
x1T.MustDrop()
|
||||
out := x2.MustT(true)
|
||||
chans[i] = *out
|
||||
chans[i] = out
|
||||
}
|
||||
|
||||
cropTs := ts.MustStack(chans, 0)
|
||||
|
@ -728,7 +728,7 @@ func applyGridTransform(x, gridInput *ts.Tensor, mode string, fillValue []float6
|
|||
// dummy = torch.ones((img.shape[0], 1, img.shape[2], img.shape[3]), dtype=img.dtype, device=img.device)
|
||||
// img = torch.cat((img, dummy), dim=1)
|
||||
dummy := ts.MustOnes([]int64{img.MustSize()[0], 1, img.MustSize()[2], img.MustSize()[3]}, img.DType(), img.MustDevice())
|
||||
imgCat := ts.MustCat([]ts.Tensor{*img, *dummy}, 1)
|
||||
imgCat := ts.MustCat([]*ts.Tensor{img, dummy}, 1)
|
||||
dummy.MustDrop()
|
||||
img.MustDrop()
|
||||
|
||||
|
@ -779,9 +779,9 @@ func applyGridTransform(x, gridInput *ts.Tensor, mode string, fillValue []float6
|
|||
// (x, y) -> ( (ax + by + c) / (gx + hy + 1), (dx + ey + f) / (gx + hy + 1) )
|
||||
// Args:
|
||||
// - startpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
|
||||
// ``[top-left, top-right, bottom-right, bottom-left]`` of the original image.
|
||||
// “[top-left, top-right, bottom-right, bottom-left]“ of the original image.
|
||||
// - endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners
|
||||
// ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image.
|
||||
// “[top-left, top-right, bottom-right, bottom-left]“ of the transformed image.
|
||||
// Returns:
|
||||
// - octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
|
||||
func perspectiveCoeff(startPoints, endPoints [][]int64) []float64 {
|
||||
|
@ -929,7 +929,7 @@ func perspective(x *ts.Tensor, startPoints, endPoints [][]int64, mode string, fi
|
|||
|
||||
// Apply affine transformation on the image keeping image center invariant.
|
||||
//
|
||||
//If the image is torch Tensor, it is expected
|
||||
// If the image is torch Tensor, it is expected
|
||||
// to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
// Args:
|
||||
// - img (Tensor): image to transform.
|
||||
|
@ -940,8 +940,8 @@ func perspective(x *ts.Tensor, startPoints, endPoints [][]int64, mode string, fi
|
|||
// If a sequence is specified, the first value corresponds to a shear parallel to the x axis, while
|
||||
// the second value corresponds to a shear parallel to the y axis.
|
||||
// - interpolation (InterpolationMode): Desired interpolation enum defined by
|
||||
// :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
||||
// If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
||||
// :class:`torchvision.transforms.InterpolationMode`. Default is “InterpolationMode.NEAREST“.
|
||||
// If input is Tensor, only “InterpolationMode.NEAREST“, “InterpolationMode.BILINEAR“ are supported.
|
||||
// - fill (sequence or number, optional): Pixel fill value for the area outside the transformed
|
||||
// image. If given a number, the value is used for all bands respectively.
|
||||
func affine(img *ts.Tensor, angle float64, translations []int64, scale float64, shear []float64, interpolationMode string, fillValue []float64) *ts.Tensor {
|
||||
|
@ -982,17 +982,19 @@ func affine(img *ts.Tensor, angle float64, translations []int64, scale float64,
|
|||
// As it is explained in PIL.Image.rotate
|
||||
// We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
|
||||
// where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
|
||||
// C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
|
||||
// RSS is rotation with scale and shear matrix
|
||||
// RSS(a, s, (sx, sy)) =
|
||||
// = R(a) * S(s) * SHy(sy) * SHx(sx)
|
||||
// = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
|
||||
// [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
|
||||
// [ 0 , 0 , 1 ]
|
||||
//
|
||||
// C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
|
||||
// RSS is rotation with scale and shear matrix
|
||||
// RSS(a, s, (sx, sy)) =
|
||||
// = R(a) * S(s) * SHy(sy) * SHx(sx)
|
||||
// = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
|
||||
// [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
|
||||
// [ 0 , 0 , 1 ]
|
||||
//
|
||||
// where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
|
||||
// SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
|
||||
// [0, 1 ] [-tan(s), 1]
|
||||
//
|
||||
// [0, 1 ] [-tan(s), 1]
|
||||
//
|
||||
// Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
|
||||
func getInverseAffineMatrix(center []float64, angle float64, translate []float64, scale float64, shear []float64) []float64 {
|
||||
|
@ -1345,11 +1347,11 @@ func equalize(img *ts.Tensor) *ts.Tensor {
|
|||
}
|
||||
|
||||
// batched images
|
||||
var images []ts.Tensor
|
||||
var images []*ts.Tensor
|
||||
for i := 0; i < int(shape[0]); i++ {
|
||||
x := img.MustSelect(0, int64(i), false)
|
||||
o := equalizeSingleImage(x)
|
||||
images = append(images, *o)
|
||||
images = append(images, o)
|
||||
x.MustDrop()
|
||||
}
|
||||
|
||||
|
@ -1363,12 +1365,12 @@ func equalize(img *ts.Tensor) *ts.Tensor {
|
|||
|
||||
func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
|
||||
dim := img.MustSize()
|
||||
var scaledChans []ts.Tensor = make([]ts.Tensor, int(dim[0]))
|
||||
var scaledChans []*ts.Tensor = make([]*ts.Tensor, int(dim[0]))
|
||||
for i := 0; i < int(dim[0]); i++ {
|
||||
cTs := img.MustSelect(0, int64(i), false)
|
||||
scaledChan := scaleChannel(cTs)
|
||||
cTs.MustDrop()
|
||||
scaledChans[i] = *scaledChan
|
||||
scaledChans[i] = scaledChan
|
||||
}
|
||||
|
||||
out := ts.MustStack(scaledChans, 0)
|
||||
|
|
|
@ -83,8 +83,8 @@ func CFLoadDir(dir string) *Dataset {
|
|||
|
||||
testImages, testLabels := readFile(fmt.Sprintf("%v/test_batch.bin", dirAbs))
|
||||
|
||||
var trainImages []ts.Tensor
|
||||
var trainLabels []ts.Tensor
|
||||
var trainImages []*ts.Tensor
|
||||
var trainLabels []*ts.Tensor
|
||||
|
||||
trainFiles := []string{
|
||||
"data_batch_1.bin",
|
||||
|
@ -96,8 +96,8 @@ func CFLoadDir(dir string) *Dataset {
|
|||
|
||||
for _, f := range trainFiles {
|
||||
img, l := readFile(fmt.Sprintf("%v/%v", dirAbs, f))
|
||||
trainImages = append(trainImages, *img)
|
||||
trainLabels = append(trainLabels, *l)
|
||||
trainImages = append(trainImages, img)
|
||||
trainLabels = append(trainLabels, l)
|
||||
}
|
||||
|
||||
return &Dataset{
|
||||
|
|
|
@ -39,7 +39,7 @@ func (l *denseLayer) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
|
|||
ys := ys5.Apply(l.Conv2)
|
||||
ys5.MustDrop()
|
||||
|
||||
res := ts.MustCat([]ts.Tensor{*xs, *ys}, 1)
|
||||
res := ts.MustCat([]*ts.Tensor{xs, ys}, 1)
|
||||
ys.MustDrop()
|
||||
|
||||
return res
|
||||
|
|
|
@ -212,7 +212,7 @@ func LoadAndResize(path string, outW int64, outH int64) (*ts.Tensor, error) {
|
|||
// LoadDir loads all the images in a directory.
|
||||
func LoadDir(dir string, outW int64, outH int64) (*ts.Tensor, error) {
|
||||
var filePaths []string // "dir/filename.ext"
|
||||
var tensors []ts.Tensor
|
||||
var tensors []*ts.Tensor
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadDir - Read directory error: %v\n", err)
|
||||
|
@ -228,7 +228,7 @@ func LoadDir(dir string, outW int64, outH int64) (*ts.Tensor, error) {
|
|||
err = fmt.Errorf("LoadDir - LoadAndResize method call error: %v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, *tensor)
|
||||
tensors = append(tensors, tensor)
|
||||
}
|
||||
|
||||
stackedTs, err := ts.Stack(tensors, 0)
|
||||
|
|
|
@ -169,7 +169,7 @@ func (in *ImageNet) hasSuffix(path string) bool {
|
|||
}
|
||||
|
||||
func (in *ImageNet) loadImageFromDir(dir string) (*ts.Tensor, error) {
|
||||
var images []ts.Tensor
|
||||
var images []*ts.Tensor
|
||||
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
|
@ -188,7 +188,7 @@ func (in *ImageNet) loadImageFromDir(dir string) (*ts.Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
images = append(images, *img)
|
||||
images = append(images, img)
|
||||
}
|
||||
|
||||
if len(images) == 0 {
|
||||
|
@ -243,10 +243,10 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
|
|||
// fmt.Printf("Classess: %v\n", classes)
|
||||
|
||||
var (
|
||||
trainImages []ts.Tensor
|
||||
trainLabels []ts.Tensor
|
||||
testImages []ts.Tensor
|
||||
testLabels []ts.Tensor
|
||||
trainImages []*ts.Tensor
|
||||
trainLabels []*ts.Tensor
|
||||
testImages []*ts.Tensor
|
||||
testLabels []*ts.Tensor
|
||||
)
|
||||
|
||||
for labelIdx, labelDir := range classes {
|
||||
|
@ -261,10 +261,10 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
|
|||
}
|
||||
|
||||
ntrainTs := trainTs.MustSize()[0]
|
||||
trainImages = append(trainImages, *trainTs)
|
||||
trainImages = append(trainImages, trainTs)
|
||||
|
||||
trainLabelOnes := ts.MustOnes([]int64{ntrainTs}, gotch.Int64, gotch.CPU)
|
||||
trainLabels = append(trainLabels, *trainLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
|
||||
trainLabels = append(trainLabels, trainLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
|
||||
|
||||
// test
|
||||
testDir := fmt.Sprintf("%v/%v", validPath, labelDir)
|
||||
|
@ -274,10 +274,10 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
|
|||
return nil, err
|
||||
}
|
||||
ntestTs := testTs.MustSize()[0]
|
||||
testImages = append(testImages, *testTs)
|
||||
testImages = append(testImages, testTs)
|
||||
|
||||
testLabelOnes := ts.MustOnes([]int64{ntestTs}, gotch.Int64, gotch.CPU)
|
||||
testLabels = append(testLabels, *testLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
|
||||
testLabels = append(testLabels, testLabelOnes.MustMulScalar(ts.IntScalar(labelIndex), true))
|
||||
}
|
||||
|
||||
trainImageTs := ts.MustCat(trainImages, 0)
|
||||
|
@ -298,7 +298,7 @@ func (in *ImageNet) LoadFromDir(path string) (*Dataset, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func dropTsSlice(tensors []ts.Tensor) {
|
||||
func dropTsSlice(tensors []*ts.Tensor) {
|
||||
for i := 0; i < len(tensors); i++ {
|
||||
tensors[i].MustDrop()
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ func inceptionA(p *nn.Path, cIn, cPool int64) ts.ModuleT {
|
|||
bpoolTmp := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||
bpoolTs := bpoolTmp.ApplyT(bpool, train)
|
||||
|
||||
res := ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||
res := ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
|
||||
|
||||
return res
|
||||
})
|
||||
|
@ -104,7 +104,7 @@ func inceptionB(p *nn.Path, cIn int64) ts.ModuleT {
|
|||
|
||||
bpoolTs := inMaxPool2D(xs, 3, 2)
|
||||
|
||||
res := ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *bpoolTs}, 1)
|
||||
res := ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1)
|
||||
|
||||
return res
|
||||
})
|
||||
|
@ -148,7 +148,7 @@ func inceptionC(p *nn.Path, cIn int64, c7 int64) ts.ModuleT {
|
|||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||
|
||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||
return ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -177,7 +177,7 @@ func inceptionD(p *nn.Path, cIn int64) ts.ModuleT {
|
|||
|
||||
bpoolTs := inMaxPool2D(xs, 3, 2)
|
||||
|
||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *bpoolTs}, 1)
|
||||
return ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, bpoolTs}, 1)
|
||||
|
||||
})
|
||||
}
|
||||
|
@ -202,19 +202,19 @@ func inceptionE(p *nn.Path, cIn int64) ts.ModuleT {
|
|||
b2Tmp := xs.ApplyT(b21, train)
|
||||
b2aTs := b2Tmp.ApplyT(b22a, train)
|
||||
b2bTs := b2Tmp.ApplyT(b22b, train)
|
||||
b2Ts := ts.MustCat([]ts.Tensor{*b2aTs, *b2bTs}, 1)
|
||||
b2Ts := ts.MustCat([]*ts.Tensor{b2aTs, b2bTs}, 1)
|
||||
|
||||
b3Tmp1 := xs.ApplyT(b31, train)
|
||||
b3Tmp2 := b3Tmp1.ApplyT(b32, train)
|
||||
b3Tmp1.MustDrop()
|
||||
b3aTs := b3Tmp2.ApplyT(b33a, train)
|
||||
b3bTs := b3Tmp2.ApplyT(b33b, train)
|
||||
b3Ts := ts.MustCat([]ts.Tensor{*b3aTs, *b3bTs}, 1)
|
||||
b3Ts := ts.MustCat([]*ts.Tensor{b3aTs, b3bTs}, 1)
|
||||
|
||||
bpTmp1 := xs.MustAvgPool2d([]int64{3, 3}, []int64{1, 1}, []int64{1, 1}, false, true, []int64{9}, false)
|
||||
bpoolTs := bpTmp1.ApplyT(bpool, train)
|
||||
|
||||
return ts.MustCat([]ts.Tensor{*b1Ts, *b2Ts, *b3Ts, *bpoolTs}, 1)
|
||||
return ts.MustCat([]*ts.Tensor{b1Ts, b2Ts, b3Ts, bpoolTs}, 1)
|
||||
})
|
||||
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ func fire(p *nn.Path, cIn int64, cSqueeze int64, cExp1 int64, cExp3 int64) ts.Mo
|
|||
exp3Tmp := tmp2.Apply(exp3)
|
||||
exp3Ts := exp3Tmp.MustRelu(true)
|
||||
|
||||
return ts.MustCat([]ts.Tensor{*exp1Ts, *exp3Ts}, 1)
|
||||
return ts.MustCat([]*ts.Tensor{exp1Ts, exp3Ts}, 1)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user