added dtype option to nn package

sugarme 2023-07-07 12:24:56 +10:00
parent 523061eca6
commit 34e87b1302
12 changed files with 164 additions and 129 deletions

View File

@ -394,3 +394,30 @@ func DTypeFromData(data interface{}) (DType, error) {
// single element
return GoKind2DType(dataKind)
}
// IsFloatDType returns whether dtype is floating point data type.
func IsFloatDType(dtype DType) bool {
switch dtype {
case Double, Float, Half, BFloat16:
return true
default:
return false
}
}
// Default DType:
// ==============
var DefaultDType DType = Float
// SetDefaultDType set DefaultDType to new value and return the previous one.
func SetDefaultDType(dtype DType) DType {
odtype := DefaultDType
DefaultDType = dtype
if Debug {
log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change back to previous default to avoid unexpected outcome.\n", dtype)
}
return odtype
}
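
A minimal usage sketch (not part of this diff) of the new default-dtype helpers; it assumes the existing `ts.MustZeros`, `gotch.CPU` and `Tensor.DType()` APIs:

// Switch the package-wide default to half precision, build a tensor with it,
// then restore the previous default.
prev := gotch.SetDefaultDType(gotch.Half)
x := ts.MustZeros([]int64{2, 3}, gotch.DefaultDType, gotch.CPU)
fmt.Println(gotch.IsFloatDType(x.DType())) // true: Half is a floating point dtype
gotch.SetDefaultDType(prev) // restore the old default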

Binary file not shown.

Binary file not shown.

View File

@ -1,67 +0,0 @@
package main
import (
"fmt"
"runtime/metrics"
// "math/rand"
// "time"
// "github.com/sugarme/gotch"
// "github.com/sugarme/gotch/ts"
)
// Run with: `go build -gcflags="-m=3"`
//go:noinline
func main() {
s := []metrics.Sample{{Name: "/gc/stack/starting-size:bytes"}}
metrics.Read(s)
fmt.Printf("Initial stack size: %d\n", s[0].Value.Uint64())
// x, err := ts.Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
// if err != nil {
// panic(err)
// }
// fmt.Printf("x: %v\n", x.Name())
// x := ts.MustOfSlice([]float32{1, 2, 3})
// fmt.Printf("x: %v\n", x.Name())
x := new(foo)
// x := newFoo()
// x := &foo{
// name: "foo",
// f: &foo1{foo1Name: "foo1"},
// }
fmt.Printf("x: %q\n", x.name)
// b := new(ts.Tensor)
// fmt.Printf("b: %v\n", b)
// time.Sleep(time.Second * 2)
}
type foo struct {
data [1e4]interface{} // 10_000 * 16 = 160_000 bytes (interface{} is 16 bytes on 64-bit)
name string
// f *foo1
}
type foo1 struct {
foo1Name string
}
func newFoo() *foo {
return new(foo)
}
// func newData() []float32 {
// // n := 3 * 224 * 224 * 12
// n := 3
// data := make([]float32, n)
// for i := 0; i < n; i++ {
// data[i] = rand.Float32()
// }
//
// return data
// }

Binary file not shown.

View File

@ -1,6 +1,7 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
@ -25,8 +26,10 @@ func main() {
panic(err)
}
err = pickle.LoadInfo(modelFile)
m, err := pickle.LoadModelInfo(modelFile)
if err != nil {
log.Fatal(err)
}
fmt.Println(m)
}

View File

@ -12,7 +12,7 @@ import (
type Init interface {
// creates a new tensor with specified initiation
InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)
// re-initializes (in-place) an existing tensor with the specified initiation
Set(tensor *ts.Tensor)
@ -25,18 +25,24 @@ type constInit struct {
value float64
}
var _ Init = new(constInit)
func NewConstInit(v float64) constInit {
return constInit{v}
}
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
var err error
kind := gotch.Float
switch {
case c.value == 0.0:
retVal = ts.MustZeros(dims, kind, device)
retVal = ts.MustZeros(dims, dtype, device)
case c.value == 1.0:
retVal = ts.MustOnes(dims, kind, device)
retVal = ts.MustOnes(dims, dtype, device)
default:
data := make([]float64, ts.FlattenDim(dims))
for i := range data {
@ -68,17 +74,24 @@ type randnInit struct {
stdev float64
}
var _ Init = new(randnInit)
func NewRandnInit(mean, stdev float64) randnInit {
return randnInit{mean, stdev}
}
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
if r.mean == 0 {
return ts.MustRandn(dims, gotch.Float, device)
func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
initTs := ts.MustRandn(dims, gotch.Float, device)
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
if r.mean == 0 {
return ts.MustRandn(dims, dtype, device)
}
initTs := ts.MustRandn(dims, dtype, device)
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
}
@ -101,14 +114,20 @@ type uniformInit struct {
up float64
}
var _ Init = new(uniformInit)
func NewUniformInit(lo, up float64) uniformInit {
return uniformInit{lo, up}
}
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
var err error
kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device)
retVal = ts.MustZeros(dims, dtype, device)
retVal.Uniform_(u.lo, u.up)
if err != nil {
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
@ -174,6 +193,8 @@ type kaimingUniformInit struct {
NonLinearity string
}
var _ Init = new(kaimingUniformInit)
func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
o := DefaultKaimingOptions()
for _, opt := range opts {
@ -187,7 +208,12 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
}
}
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
dtype := gotch.DefaultDType
if len(dtypeOpt) > 0 {
dtype = dtypeOpt[0]
}
fanIn, _, err := CalculateFans(dims)
if err != nil {
panic(err)
@ -204,8 +230,7 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retV
// Calculate uniform bounds from standard deviation
bound := math.Sqrt(3.0) * std
kind := gotch.Float
retVal = ts.MustZeros(dims, kind, device)
retVal = ts.MustZeros(dims, dtype, device)
retVal.Uniform_(-bound, bound)
return retVal
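
A small usage sketch (not part of the diff) of the new variadic dtype parameter on `Init.InitTensor`; the constructors are the ones shown above and `gotch.CPU` is assumed:

// Without the option the tensor uses gotch.DefaultDType.
w := nn.NewRandnInit(0.0, 1.0).InitTensor([]int64{4, 4}, gotch.CPU)
// With the option the requested dtype is used instead.
h := nn.NewConstInit(1.0).InitTensor([]int64{4, 4}, gotch.CPU, gotch.Half)
fmt.Println(w.DType(), h.DType()) // Float (default), Half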

View File

@ -40,11 +40,14 @@ type Linear struct {
// outDim - output dimension (y) [output features - columns]
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
dtype := gotch.DefaultDType
var bs *ts.Tensor
// bs has size of output dimension
switch c.Bias {
case false:
bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
// FIXME: do we need this? Or just remove it and create the bias on the fly in `Forward`,
// with the same dtype and device as the input.
bs = ts.MustZeros([]int64{outDim}, dtype, vs.Device())
case true:
switch {
case c.BsInit == nil:
@ -87,20 +90,19 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
//
// Example:
//
// inDim := 3
// outDim := 2
// batchSize := 4
// weights: 2x3
// [ 1 1 1
// 1 1 1 ]
// inDim := 3
// outDim := 2
// batchSize := 4
// weights: 2x3
// [ 1 1 1
// 1 1 1 ]
//
// input node: 3x4
// [ 1 1 1
// 1 1 1
// 1 1 1
// 1 1 1 ]
// input node: 3x4
// [ 1 1 1
// 1 1 1
// 1 1 1
// 1 1 1 ]
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false)
return mul.MustAdd(l.Bs, true)
}
@ -109,7 +111,6 @@ func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
//
// NOTE: train param will not be used.
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
mul := xs.MustMatmul(l.Ws, false)
return mul.MustAdd(l.Bs, true)
}
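
A brief sketch (not part of the diff): `NewLinear` now picks up `gotch.DefaultDType` for its weight and bias. It assumes `nn.NewVarStore` and `nn.DefaultLinearConfig` from the existing API:

vs := nn.NewVarStore(gotch.CPU)
lin := nn.NewLinear(vs.Root(), 784, 10, nn.DefaultLinearConfig())
fmt.Println(lin.Ws.DType()) // gotch.DefaultDType (Float unless changed via SetDefaultDType)

x := ts.MustRandn([]int64{4, 784}, gotch.DefaultDType, gotch.CPU)
out := lin.Forward(x)
fmt.Println(out.MustSize()) // [4 10]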

View File

@ -68,7 +68,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tenso
reduction := options.Reduction
ignoreIndex := options.IgnoreIndex
logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
logSm := logits.MustLogSoftmax(-1, dtype, false)
loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
ws.MustDrop()
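
A short usage sketch (not part of the diff); `dtype` above is presumably derived from the logits earlier in the function (not shown in this hunk). `ts.MustOfSlice` is assumed for the Int64 class targets:

logits := ts.MustRandn([]int64{4, 10}, gotch.Float, gotch.CPU)
target := ts.MustOfSlice([]int64{1, 0, 3, 7}) // Int64 class indices
loss := nn.CrossEntropyLoss(logits, target)
fmt.Println(loss.DType()) // presumably follows the logits' dtype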

View File

@ -7,7 +7,6 @@ import (
"log"
"math"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/ts"
)
@ -418,6 +417,10 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
)
device := opt.varstore.device
// FIXME. What about mixed-precision?
dtype := parameters[0].DType()
if o.NormType == math.Inf(1) {
for _, v := range opt.varstore.vars {
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
norms = append(norms, x)
}
}
// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
for _, x := range norms {
x.MustDrop()
}
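
A hedged sketch (not part of the diff): the call site is unchanged, only the norm dtype now follows the parameters. The usual `nn.DefaultAdamConfig().Build(vs, lr)` builder is assumed:

opt, err := nn.DefaultAdamConfig().Build(vs, 1e-3)
if err != nil {
	log.Fatal(err)
}
// ... forward pass, then loss.MustBackward() ...
if err := opt.ClipGradNorm(1.0); err != nil { // norms computed in the first parameter's dtype
	log.Fatal(err)
}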

View File

@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
layerDim := l.config.NumLayers * numDirections
shape := []int64{layerDim, batchDim, l.hiddenDim}
zeros := ts.MustZeros(shape, gotch.Float, l.device)
dtype := l.flatWeights[0].DType()
zeros := ts.MustZeros(shape, dtype, l.device)
retVal := &LSTMState{
Tensor1: zeros.MustShallowClone(),
@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
layerDim := g.config.NumLayers * numDirections
shape := []int64{layerDim, batchDim, g.hiddenDim}
tensor := ts.MustZeros(shape, gotch.Float, g.device)
dtype := g.flatWeights[0].DType()
tensor := ts.MustZeros(shape, dtype, g.device)
return &GRUState{Tensor: tensor}
}
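
A sketch (not part of the diff): the zero state now inherits the dtype of the flattened weights instead of hard-coded `gotch.Float`. The usual `nn.NewLSTM`/`nn.DefaultRNNConfig` constructors are assumed:

vs := nn.NewVarStore(gotch.CPU)
lstm := nn.NewLSTM(vs.Root(), 128, 256, nn.DefaultRNNConfig())
state := lstm.ZeroState(8) // hidden/cell tensors created with the weights' dtype
_ = state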

View File

@ -417,12 +417,21 @@ func (vs *VarStore) Summary() {
layers = append(layers, name)
}
sort.Strings(layers)
var dtype gotch.DType
isFirst := true
for _, l := range layers {
var x *ts.Tensor
var isBuffer bool
for name, v := range vars {
if name == l {
x = v.Tensor
// Grab the DType of the first tensor, just to display it in the summary
if isFirst {
dtype = x.DType()
}
isFirst = false
isBuffer = v.Type == "buffer"
break
}
@ -435,25 +444,19 @@ func (vs *VarStore) Summary() {
}
fmt.Printf("Num of layers: %v\n", len(vars))
fmt.Printf("DType: %v\n", dtype)
}
// ToDType casts all variables in VarStore to specified DType.
//
// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
// NOTE. only float-like types (Half, BFloat16, Float, Double) are guaranteed to be convertible.
func (vs *VarStore) ToDType(dtype gotch.DType) {
vs.Root().ToDType(dtype)
}
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
//
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
func (vs *VarStore) ToHalf() {
vs.Root().ToHalf()
}
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
//
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
func (vs *VarStore) ToFloat() {
vs.Root().ToFloat()
}
@ -465,6 +468,20 @@ func (vs *VarStore) ToDouble() {
vs.Root().ToDouble()
}
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
//
// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
func (vs *VarStore) ToHalf() {
vs.Root().ToHalf()
}
// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
//
// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
func (vs *VarStore) ToBFloat16() {
vs.Root().ToBFloat16()
}
// Path methods:
// =============
@ -664,7 +681,7 @@ func (p *Path) SetGroup(g uint) {
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
//
// NOTE. this method should be used for floating-point conversion, i.e.,
// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
func (p *Path) ToDType(dtype gotch.DType) {
p.varstore.Lock()
defer p.varstore.Unlock()
@ -686,7 +703,7 @@ func (p *Path) toFloat(dtype gotch.DType) {
for name, v := range p.varstore.vars {
if strings.Contains(name, path) {
dtype := v.Tensor.DType()
if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
if gotch.IsFloatDType(dtype) {
newVar := v
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
p.varstore.vars[name] = newVar
@ -695,19 +712,37 @@ func (p *Path) toFloat(dtype gotch.DType) {
}
}
// ToHalf casts all variables in current path and subpaths to `Half` precision.
// ToFloat casts all variables in current path and subpaths to `Float` precision, or to an optional float dtype passed as the first argument.
func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
dtype := gotch.Float
if len(floatDTypeOpt) > 0 {
dt := floatDTypeOpt[0]
if !gotch.IsFloatDType(dt) {
// Ignore the option
if gotch.Debug {
log.Printf("WARNING: nn.Path.ToFloat() input dtype is invalid float DType %v. Just ignoring...\n", dt)
}
} else {
dtype = dt
}
}
p.toFloat(dtype)
}
// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
func (p *Path) ToDouble() {
p.toFloat(gotch.Double)
}
// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
func (p *Path) ToHalf() {
p.toFloat(gotch.Half)
}
// ToFloat casts all variables in current path and subpaths to `Float` precision.
func (p *Path) ToFloat() {
p.toFloat(gotch.Float)
}
// ToDouble casts all variables in current path and subpaths to `Double` precision.
func (p *Path) ToDouble() {
p.toFloat(gotch.Double)
// ToBFloat16 casts all variables in current path and subpaths to `BFloat16` dtype.
func (p *Path) ToBFloat16() {
p.toFloat(gotch.BFloat16)
}
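
A brief sketch (not part of the diff) of the casting helpers on a populated VarStore:

vs := nn.NewVarStore(gotch.CPU)
// ... build model variables under vs.Root() ...
vs.ToBFloat16()         // cast all float-like variables to BFloat16
vs.ToDType(gotch.Float) // or cast back explicitly
vs.Summary()            // now also prints the DType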
// ZerosNoTrain creates a new variable initialized with zeros.
@ -718,7 +753,8 @@ func (p *Path) ToDouble() {
// The variable uses a float tensor initialized with zeros.
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
device := p.Device()
z, err := ts.Zeros(dims, gotch.Float, device)
dtype := gotch.DefaultDType
z, err := ts.Zeros(dims, dtype, device)
if err != nil {
err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
return nil, err
@ -755,7 +791,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.T
// The variable uses a float tensor initialized with ones.
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
device := p.Device()
z, err := ts.Ones(dims, gotch.Float, device)
dtype := gotch.DefaultDType
z, err := ts.Ones(dims, dtype, device)
if err != nil {
err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
return nil, err
@ -792,7 +829,8 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
// The variable uses a float tensor initialized as per the
// related argument.
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
v := ini.InitTensor(dims, p.varstore.device)
dtype := gotch.DefaultDType
v := ini.InitTensor(dims, p.varstore.device, dtype)
out, err := p.Add(name, v, true, opts...)
if err != nil {
return nil, err
@ -1098,7 +1136,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
// OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
o := ts.MustOnes(dims, gotch.Float, e.path.Device())
dtype := gotch.DefaultDType
o := ts.MustOnes(dims, dtype, e.path.Device())
out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
if err != nil {
return nil, err
@ -1161,7 +1200,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.
// OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
dtype := gotch.DefaultDType
z := ts.MustZeros(dims, dtype, e.path.Device())
out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
if err != nil {
return nil, err
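
A final sketch (not part of the diff): the no-train helpers now use `gotch.DefaultDType` rather than hard-coded `gotch.Float`, so they follow a changed default on an existing VarStore root:

prev := gotch.SetDefaultDType(gotch.Half)
runningMean := vs.Root().MustZerosNoTrain("running_mean", []int64{256}) // Half
gotch.SetDefaultDType(prev)
_ = runningMean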