diff --git a/dtype.go b/dtype.go
index 02d7225..de7ae27 100644
--- a/dtype.go
+++ b/dtype.go
@@ -394,3 +394,30 @@ func DTypeFromData(data interface{}) (DType, error) {
     // single element
     return GoKind2DType(dataKind)
 }
+
+// IsFloatDType returns whether dtype is a floating-point data type.
+func IsFloatDType(dtype DType) bool {
+    switch dtype {
+    case Double, Float, Half, BFloat16:
+        return true
+
+    default:
+        return false
+    }
+}
+
+// Default DType:
+// ==============
+var DefaultDType DType = Float
+
+// SetDefaultDType sets DefaultDType to a new value and returns the previous one.
+func SetDefaultDType(dtype DType) DType {
+    odtype := DefaultDType
+    DefaultDType = dtype
+
+    if Debug {
+        log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change it back to the previous default to avoid unexpected results.\n", dtype)
+    }
+
+    return odtype
+}
diff --git a/example/heap/heap b/example/heap/heap
deleted file mode 100755
index d6d1c68..0000000
Binary files a/example/heap/heap and /dev/null differ
diff --git a/example/heap/main b/example/heap/main
deleted file mode 100755
index dc18bff..0000000
Binary files a/example/heap/main and /dev/null differ
diff --git a/example/heap/main.go b/example/heap/main.go
deleted file mode 100644
index cefe541..0000000
--- a/example/heap/main.go
+++ /dev/null
@@ -1,67 +0,0 @@
-package main
-
-import (
-    "fmt"
-    "runtime/metrics"
-    // "math/rand"
-    // "time"
-    // "github.com/sugarme/gotch"
-    // "github.com/sugarme/gotch/ts"
-)
-
-// Run with: `go build -gcflags="-m=3"`
-
-//go:noinline
-func main() {
-    s := []metrics.Sample{{Name: "/gc/stack/starting-size:bytes"}}
-    metrics.Read(s)
-    fmt.Printf("Initial stack size: %d\n", s[0].Value.Uint64())
-
-    // x, err := ts.Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
-    // if err != nil {
-    //     panic(err)
-    // }
-    // fmt.Printf("x: %v\n", x.Name())
-
-    // x := ts.MustOfSlice([]float32{1, 2, 3})
-    // fmt.Printf("x: %v\n", x.Name())
-
-    x := new(foo)
-    // x := newFoo()
-    // x := &foo{
-    //     name: "foo",
-    //     f: &foo1{foo1Name: "foo1"},
-    // }
-
-    fmt.Printf("x: %q\n", x.name)
-
-    // b := new(ts.Tensor)
-    // fmt.Printf("b: %v\n", b)
-
-    // time.Sleep(time.Second * 2)
-}
-
-type foo struct {
-    data [1e4]interface{} // 10_000 * 4 = 40_000 bytes
-    name string
-    // f *foo1
-}
-
-type foo1 struct {
-    foo1Name string
-}
-
-func newFoo() *foo {
-    return new(foo)
-}
-
-// func newData() []float32 {
-//     // n := 3 * 224 * 224 * 12
-//     n := 3
-//     data := make([]float32, n)
-//     for i := 0; i < n; i++ {
-//         data[i] = rand.Float32()
-//     }
-//
-//     return data
-// }
diff --git a/example/heap/main.o b/example/heap/main.o
deleted file mode 100644
index 287c2c8..0000000
Binary files a/example/heap/main.o and /dev/null differ
diff --git a/example/pickle/main.go b/example/pickle/main.go
index bec4742..fa79824 100644
--- a/example/pickle/main.go
+++ b/example/pickle/main.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+    "fmt"
     "log"
 
     "github.com/sugarme/gotch"
@@ -25,8 +26,10 @@ func main() {
         panic(err)
     }
 
-    err = pickle.LoadInfo(modelFile)
+    m, err := pickle.LoadModelInfo(modelFile)
    if err != nil {
        log.Fatal(err)
    }
+
+    fmt.Println(m)
 }
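For context, a small usage sketch of the `DefaultDType` API introduced in `dtype.go` above (not part of this patch). It only assumes `gotch.CPU`, `ts.MustZeros`, and the tensor `DType()`/`MustDrop()` methods that already exist in the library:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	// Switch the package-wide default to half precision and restore it later.
	prev := gotch.SetDefaultDType(gotch.Half)
	defer gotch.SetDefaultDType(prev)

	// Factories that honour the default now produce Half tensors.
	x := ts.MustZeros([]int64{2, 3}, gotch.DefaultDType, gotch.CPU)
	fmt.Println(x.DType())                     // Half
	fmt.Println(gotch.IsFloatDType(x.DType())) // true for Double, Float, Half, BFloat16
	x.MustDrop()
}
```

Returning the previous value from `SetDefaultDType` makes the `defer`-based restore above possible, which is the pattern the log message hints at.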
diff --git a/nn/init.go b/nn/init.go
index 91197f7..87add04 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -12,7 +12,7 @@ import (
 
 type Init interface {
     // creates a new tensor with specified initiation
-    InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
+    InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)
 
     // re-initializes (in-place) an existing tensor with the specified initiation
     Set(tensor *ts.Tensor)
@@ -25,18 +25,24 @@ type constInit struct {
     value float64
 }
 
+var _ Init = new(constInit)
+
 func NewConstInit(v float64) constInit {
     return constInit{v}
 }
 
-func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
     var err error
-    kind := gotch.Float
     switch {
     case c.value == 0.0:
-        retVal = ts.MustZeros(dims, kind, device)
+        retVal = ts.MustZeros(dims, dtype, device)
     case c.value == 1.0:
-        retVal = ts.MustOnes(dims, kind, device)
+        retVal = ts.MustOnes(dims, dtype, device)
     default:
         data := make([]float64, ts.FlattenDim(dims))
         for i := range data {
@@ -68,17 +74,24 @@ type randnInit struct {
     stdev float64
 }
 
+var _ Init = new(randnInit)
+
 func NewRandnInit(mean, stdev float64) randnInit {
     return randnInit{mean, stdev}
 }
 
-func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
-    // if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
-    if r.mean == 0 {
-        return ts.MustRandn(dims, gotch.Float, device)
+func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
     }
-    initTs := ts.MustRandn(dims, gotch.Float, device)
+    // if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
+    if r.mean == 0 {
+        return ts.MustRandn(dims, dtype, device)
+    }
+
+    initTs := ts.MustRandn(dims, dtype, device)
     return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
 }
@@ -101,14 +114,20 @@ type uniformInit struct {
     up float64
 }
 
+var _ Init = new(uniformInit)
+
 func NewUniformInit(lo, up float64) uniformInit {
     return uniformInit{lo, up}
 }
 
-func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
     var err error
-    kind := gotch.Float
-    retVal = ts.MustZeros(dims, kind, device)
+    retVal = ts.MustZeros(dims, dtype, device)
     retVal.Uniform_(u.lo, u.up)
     if err != nil {
         log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
@@ -174,6 +193,8 @@ type kaimingUniformInit struct {
     NonLinearity string
 }
 
+var _ Init = new(kaimingUniformInit)
+
 func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
     o := DefaultKaimingOptions()
     for _, opt := range opts {
@@ -187,7 +208,12 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
     }
 }
 
-func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
     fanIn, _, err := CalculateFans(dims)
     if err != nil {
         panic(err)
@@ -204,8 +230,7 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retV
     // Calculate uniform bounds from standard deviation
     bound := math.Sqrt(3.0) * std
 
-    kind := gotch.Float
-    retVal = ts.MustZeros(dims, kind, device)
+    retVal = ts.MustZeros(dims, dtype, device)
     retVal.Uniform_(-bound, bound)
 
     return retVal
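A usage sketch of the new optional `dtypeOpt` argument on `Init.InitTensor` (illustrative only, with made-up shapes). It relies on the `NewConstInit` and `NewKaimingUniformInit` constructors shown above plus `gotch.CPU`:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	dims := []int64{4, 5}

	// No dtype passed: the initializer falls back to gotch.DefaultDType.
	w1 := nn.NewKaimingUniformInit().InitTensor(dims, gotch.CPU)

	// Explicit dtype: overrides the default for this tensor only.
	w2 := nn.NewConstInit(0.0).InitTensor(dims, gotch.CPU, gotch.Half)

	fmt.Println(w1.DType(), w2.DType())
	w1.MustDrop()
	w2.MustDrop()
}
```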
diff --git a/nn/linear.go b/nn/linear.go
index 9551536..f7013c4 100644
--- a/nn/linear.go
+++ b/nn/linear.go
@@ -40,11 +40,14 @@ type Linear struct {
 // outDim - output dimension (y) [output features - columns]
 // NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
 func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
+    dtype := gotch.DefaultDType
     var bs *ts.Tensor
     // bs has size of output dimension
     switch c.Bias {
     case false:
+        // FIXME. Do we need this? Or should we drop it and create the bias on the fly
+        // in `Forward`, with the same dtype and device as the input?
-        bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
+        bs = ts.MustZeros([]int64{outDim}, dtype, vs.Device())
     case true:
         switch {
         case c.BsInit == nil:
@@ -87,20 +90,19 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
 //
 // Example:
 //
-// inDim := 3
-// outDim := 2
-// batchSize := 4
-// weights: 2x3
-// [ 1 1 1
-// 1 1 1 ]
+//    inDim := 3
+//    outDim := 2
+//    batchSize := 4
+//    weights: 2x3
+//    [ 1 1 1
+//      1 1 1 ]
 //
-// input node: 3x4
-// [ 1 1 1
-// 1 1 1
-// 1 1 1
-// 1 1 1 ]
+//    input node: 3x4
+//    [ 1 1 1
+//      1 1 1
+//      1 1 1
+//      1 1 1 ]
 func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
-
     mul := xs.MustMatmul(l.Ws, false)
     return mul.MustAdd(l.Bs, true)
 }
@@ -109,7 +111,6 @@ func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
 //
 // NOTE: train param will not be used.
 func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
-
     mul := xs.MustMatmul(l.Ws, false)
     return mul.MustAdd(l.Bs, true)
 }
diff --git a/nn/loss.go b/nn/loss.go
index 6564a63..a30bb28 100644
--- a/nn/loss.go
+++ b/nn/loss.go
@@ -68,7 +68,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tenso
     reduction := options.Reduction
     ignoreIndex := options.IgnoreIndex
 
-    logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
+    logSm := logits.MustLogSoftmax(-1, dtype, false)
     loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
 
     ws.MustDrop()
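A short sketch of calling `CrossEntropyLoss` after this change; the log-softmax now runs in `dtype` (presumably derived from the input logits) rather than hard-coded `gotch.Float`. The values are made up, and the example assumes the existing `ts.MustRandn`, `ts.MustOfSlice`, and `Float64Values()` helpers:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/ts"
)

func main() {
	// 4 samples, 3 classes: the loss is computed in the logits' own dtype (Float here).
	logits := ts.MustRandn([]int64{4, 3}, gotch.Float, gotch.CPU)
	target := ts.MustOfSlice([]int64{0, 2, 1, 1})

	loss := nn.CrossEntropyLoss(logits, target)
	fmt.Println(loss.Float64Values())

	loss.MustDrop()
	logits.MustDrop()
	target.MustDrop()
}
```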
diff --git a/nn/optimizer.go b/nn/optimizer.go
index f71c45a..e882e86 100644
--- a/nn/optimizer.go
+++ b/nn/optimizer.go
@@ -7,7 +7,6 @@ import (
     "log"
     "math"
 
-    "github.com/sugarme/gotch"
     "github.com/sugarme/gotch/ts"
 )
 
@@ -418,6 +417,10 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
     )
 
     device := opt.varstore.device
+
+    // FIXME. What about mixed-precision?
+    dtype := parameters[0].DType()
+
     if o.NormType == math.Inf(1) {
         for _, v := range opt.varstore.vars {
             n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
@@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
 
             // NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
             // Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
-            x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+            x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
             norms = append(norms, x)
         }
     }
 
     // totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
     // total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
-    totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+    totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
     for _, x := range norms {
         x.MustDrop()
     }
diff --git a/nn/rnn.go b/nn/rnn.go
index c72f75a..b7d02bf 100644
--- a/nn/rnn.go
+++ b/nn/rnn.go
@@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
 
     layerDim := l.config.NumLayers * numDirections
     shape := []int64{layerDim, batchDim, l.hiddenDim}
-    zeros := ts.MustZeros(shape, gotch.Float, l.device)
+
+    dtype := l.flatWeights[0].DType()
+    zeros := ts.MustZeros(shape, dtype, l.device)
 
     retVal := &LSTMState{
         Tensor1: zeros.MustShallowClone(),
@@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
 
     layerDim := g.config.NumLayers * numDirections
     shape := []int64{layerDim, batchDim, g.hiddenDim}
-    tensor := ts.MustZeros(shape, gotch.Float, g.device)
+    dtype := g.flatWeights[0].DType()
+    tensor := ts.MustZeros(shape, dtype, g.device)
 
     return &GRUState{Tensor: tensor}
 }
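With `ZeroState` now reading the dtype from the flattened weights, the initial state follows however the network was built. A sketch under assumptions: it presumes the existing `nn.NewLSTM`/`nn.DefaultRNNConfig` constructors keep their current signatures and that the returned `State` is the exported `*nn.LSTMState`; sizes are arbitrary:

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)

	// Weights are created through the var store, so they follow gotch.DefaultDType.
	lstm := nn.NewLSTM(vs.Root(), 10, 20, nn.DefaultRNNConfig())

	// The zero state is allocated with the weights' dtype instead of a
	// hard-coded gotch.Float.
	state := lstm.ZeroState(8).(*nn.LSTMState)
	fmt.Println(state.Tensor1.DType(), state.Tensor2.DType())
}
```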
diff --git a/nn/varstore.go b/nn/varstore.go
index 1cbc0a0..e2620a7 100644
--- a/nn/varstore.go
+++ b/nn/varstore.go
@@ -417,12 +417,21 @@ func (vs *VarStore) Summary() {
         layers = append(layers, name)
     }
     sort.Strings(layers)
+    var dtype gotch.DType
+    isFirst := true
     for _, l := range layers {
         var x *ts.Tensor
         var isBuffer bool
         for name, v := range vars {
             if name == l {
                 x = v.Tensor
+
+                // Get the DType of the first tensor, for reporting only.
+                if isFirst {
+                    dtype = x.DType()
+                }
+                isFirst = false
+
                 isBuffer = v.Type == "buffer"
                 break
             }
@@ -435,25 +444,19 @@ func (vs *VarStore) Summary() {
     }
 
     fmt.Printf("Num of layers: %v\n", len(vars))
+    fmt.Printf("DType: %v\n", dtype)
 }
 
 // ToDType casts all variables in VarStore to specified DType.
 //
-// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
+// NOTE. only float-like types (Half, BFloat16, Float, Double) can be safely converted.
 func (vs *VarStore) ToDType(dtype gotch.DType) {
     vs.Root().ToDType(dtype)
 }
 
-// ToHalf casts all float-like variables in VarStore to `Half` dtype.
-//
-// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
-func (vs *VarStore) ToHalf() {
-    vs.Root().ToHalf()
-}
-
 // ToFloat casts all float-like variables in VarStore to `Float` dtype.
 //
-// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
 func (vs *VarStore) ToFloat() {
     vs.Root().ToFloat()
 }
@@ -465,6 +468,20 @@ func (vs *VarStore) ToDouble() {
     vs.Root().ToDouble()
 }
 
+// ToHalf casts all float-like variables in VarStore to `Half` dtype.
+//
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
+func (vs *VarStore) ToHalf() {
+    vs.Root().ToHalf()
+}
+
+// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
+//
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
+func (vs *VarStore) ToBFloat16() {
+    vs.Root().ToBFloat16()
+}
+
 // Path methods:
 // =============
 
@@ -664,7 +681,7 @@ func (p *Path) SetGroup(g uint) {
 // ToDType casts all variables in this path and its sub-paths to the specified dtype.
 //
 // NOTE. this method should be used for floating-point conversion, i.e.,
-// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
+// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
 func (p *Path) ToDType(dtype gotch.DType) {
     p.varstore.Lock()
     defer p.varstore.Unlock()
@@ -686,7 +703,7 @@ func (p *Path) toFloat(dtype gotch.DType) {
     for name, v := range p.varstore.vars {
         if strings.Contains(name, path) {
             dtype := v.Tensor.DType()
-            if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
+            if gotch.IsFloatDType(dtype) {
                 newVar := v
                 newVar.Tensor = v.Tensor.MustTotype(dtype, true)
                 p.varstore.vars[name] = newVar
@@ -695,19 +712,37 @@
             }
         }
     }
 
-// ToHalf casts all variables in current path and subpaths to `Half` precision.
+// ToFloat casts all variables in current path and subpaths to `Float` precision.
+func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
+    dtype := gotch.Float
+    if len(floatDTypeOpt) > 0 {
+        dt := floatDTypeOpt[0]
+        if !gotch.IsFloatDType(dt) {
+            // Ignore the option
+            if gotch.Debug {
+                log.Printf("WARNING: nn.Path.ToFloat() - input dtype %v is not a float DType. Ignoring it...\n", dt)
+            }
+        } else {
+            dtype = dt
+        }
+    }
+
+    p.toFloat(dtype)
+}
+
+// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
+func (p *Path) ToDouble() {
+    p.toFloat(gotch.Double)
+}
+
+// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
 func (p *Path) ToHalf() {
     p.toFloat(gotch.Half)
 }
 
-// ToFloat casts all variables in current path and subpaths to `Float` precision.
-func (p *Path) ToFloat() {
-    p.toFloat(gotch.Float)
-}
-
-// ToDouble casts all variables in current path and subpaths to `Double` precision.
-func (p *Path) ToDouble() {
-    p.toFloat(gotch.Double)
+// ToBFloat16 casts all variables in current path and subpaths to `BFloat16` dtype.
+func (p *Path) ToBFloat16() {
+    p.toFloat(gotch.BFloat16)
 }
 
 // ZerosNoTrain creates a new variable initialized with zeros.
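A sketch of the reworked casting helpers (illustrative only; it assumes the existing `nn.NewVarStore`, `nn.NewLinear`, and `nn.DefaultLinearConfig` constructors, with made-up layer sizes):

```go
package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	_ = nn.NewLinear(vs.Root(), 3, 2, nn.DefaultLinearConfig())

	// Cast every float-like variable held by the store to BFloat16,
	// then print the summary, which now also reports a DType.
	vs.ToBFloat16()
	vs.Summary()

	// Or pick the target float precision through the new optional
	// argument on Path.ToFloat.
	vs.Root().ToFloat(gotch.Half)
}
```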
@@ -718,7 +753,8 @@
 // The variable uses a float tensor initialized with zeros.
 func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
     device := p.Device()
-    z, err := ts.Zeros(dims, gotch.Float, device)
+    dtype := gotch.DefaultDType
+    z, err := ts.Zeros(dims, dtype, device)
     if err != nil {
         err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
         return nil, err
@@ -755,7 +791,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.T
 // The variable uses a float tensor initialized with ones.
 func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
     device := p.Device()
-    z, err := ts.Ones(dims, gotch.Float, device)
+    dtype := gotch.DefaultDType
+    z, err := ts.Ones(dims, dtype, device)
     if err != nil {
         err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
         return nil, err
@@ -792,7 +829,8 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
 // The variable uses a float tensor initialized as per the
 // related argument.
 func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
-    v := ini.InitTensor(dims, p.varstore.device)
+    dtype := gotch.DefaultDType
+    v := ini.InitTensor(dims, p.varstore.device, dtype)
     out, err := p.Add(name, v, true, opts...)
     if err != nil {
         return nil, err
@@ -1098,7 +1136,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
 
 // OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
 func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
-    o := ts.MustOnes(dims, gotch.Float, e.path.Device())
+    dtype := gotch.DefaultDType
+    o := ts.MustOnes(dims, dtype, e.path.Device())
     out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
     if err != nil {
         return nil, err
@@ -1161,7 +1200,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.
 
 // OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
 func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
-    z := ts.MustZeros(dims, gotch.Float, e.path.Device())
+    dtype := gotch.DefaultDType
+    z := ts.MustZeros(dims, dtype, e.path.Device())
     out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
     if err != nil {
         return nil, err
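Finally, a sketch of how the new default flows into variables created through a `Path` (hypothetical variable name and shape; it uses only APIs touched by this patch plus the existing `nn.NewVarStore` and `Path.MustZerosNoTrain`):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	prev := gotch.SetDefaultDType(gotch.BFloat16)
	defer gotch.SetDefaultDType(prev)

	vs := nn.NewVarStore(gotch.CPU)

	// Variables created through a Path now pick up gotch.DefaultDType
	// instead of always being gotch.Float.
	w := vs.Root().MustZerosNoTrain("stats", []int64{16})
	fmt.Println(w.DType()) // BFloat16
}
```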