added dtype option to nn package
This commit is contained in:
parent
523061eca6
commit
34e87b1302
27
dtype.go
27
dtype.go
|
@ -394,3 +394,30 @@ func DTypeFromData(data interface{}) (DType, error) {
|
||||||
// single element
|
// single element
|
||||||
return GoKind2DType(dataKind)
|
return GoKind2DType(dataKind)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsFloatDType returns whether dtype is floating point data type.
|
||||||
|
func IsFloatDType(dtype DType) bool {
|
||||||
|
switch dtype {
|
||||||
|
case Double, Float, Half, BFloat16:
|
||||||
|
return true
|
||||||
|
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default DType:
|
||||||
|
// ==============
|
||||||
|
var DefaultDType DType = Float
|
||||||
|
|
||||||
|
// SetDefaultDType set DefaultDType to new value and return the previous one.
|
||||||
|
func SetDefaultDType(dtype DType) DType {
|
||||||
|
odtype := DefaultDType
|
||||||
|
DefaultDType = dtype
|
||||||
|
|
||||||
|
if Debug {
|
||||||
|
log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change back to previous default to avoid unexpected outcome.\n", dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
return odtype
|
||||||
|
}
|
||||||
|
|
Binary file not shown.
Binary file not shown.
|
@ -1,67 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime/metrics"
|
|
||||||
// "math/rand"
|
|
||||||
// "time"
|
|
||||||
// "github.com/sugarme/gotch"
|
|
||||||
// "github.com/sugarme/gotch/ts"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Run with: `go build -gcflags="-m=3"`
|
|
||||||
|
|
||||||
//go:noinline
|
|
||||||
func main() {
|
|
||||||
s := []metrics.Sample{{Name: "/gc/stack/starting-size:bytes"}}
|
|
||||||
metrics.Read(s)
|
|
||||||
fmt.Printf("Initial stack size: %d\n", s[0].Value.Uint64())
|
|
||||||
|
|
||||||
// x, err := ts.Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
|
|
||||||
// if err != nil {
|
|
||||||
// panic(err)
|
|
||||||
// }
|
|
||||||
// fmt.Printf("x: %v\n", x.Name())
|
|
||||||
|
|
||||||
// x := ts.MustOfSlice([]float32{1, 2, 3})
|
|
||||||
// fmt.Printf("x: %v\n", x.Name())
|
|
||||||
|
|
||||||
x := new(foo)
|
|
||||||
// x := newFoo()
|
|
||||||
// x := &foo{
|
|
||||||
// name: "foo",
|
|
||||||
// f: &foo1{foo1Name: "foo1"},
|
|
||||||
// }
|
|
||||||
|
|
||||||
fmt.Printf("x: %q\n", x.name)
|
|
||||||
|
|
||||||
// b := new(ts.Tensor)
|
|
||||||
// fmt.Printf("b: %v\n", b)
|
|
||||||
|
|
||||||
// time.Sleep(time.Second * 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// foo is a deliberately large struct used by main in this scratch program.
type foo struct {
|
|
||||||
data [1e4]interface{} // 10_000 interface values, two machine words each: 160_000 bytes on 64-bit platforms (not 40_000)
|
|
||||||
name string
|
|
||||||
// f *foo1
|
|
||||||
}
|
|
||||||
|
|
||||||
// foo1 is a small helper struct; in this file it is referenced only from
// commented-out code (the disabled `f *foo1` field of foo and a disabled
// literal in main).
type foo1 struct {
|
|
||||||
foo1Name string
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFoo() *foo {
|
|
||||||
return new(foo)
|
|
||||||
}
|
|
||||||
|
|
||||||
// func newData() []float32 {
|
|
||||||
// // n := 3 * 224 * 224 * 12
|
|
||||||
// n := 3
|
|
||||||
// data := make([]float32, n)
|
|
||||||
// for i := 0; i < n; i++ {
|
|
||||||
// data[i] = rand.Float32()
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// return data
|
|
||||||
// }
|
|
Binary file not shown.
|
@ -1,6 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
|
@ -25,8 +26,10 @@ func main() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pickle.LoadInfo(modelFile)
|
m, err := pickle.LoadModelInfo(modelFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println(m)
|
||||||
}
|
}
|
||||||
|
|
57
nn/init.go
57
nn/init.go
|
@ -12,7 +12,7 @@ import (
|
||||||
|
|
||||||
type Init interface {
|
type Init interface {
|
||||||
// creates a new tensor with specified initiation
|
// creates a new tensor with specified initiation
|
||||||
InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
|
InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)
|
||||||
|
|
||||||
// re-initializes (in-place) an existing tensor with the specified initiation
|
// re-initializes (in-place) an existing tensor with the specified initiation
|
||||||
Set(tensor *ts.Tensor)
|
Set(tensor *ts.Tensor)
|
||||||
|
@ -25,18 +25,24 @@ type constInit struct {
|
||||||
value float64
|
value float64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Init = new(constInit)
|
||||||
|
|
||||||
func NewConstInit(v float64) constInit {
|
func NewConstInit(v float64) constInit {
|
||||||
return constInit{v}
|
return constInit{v}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||||
|
dtype := gotch.DefaultDType
|
||||||
|
if len(dtypeOpt) > 0 {
|
||||||
|
dtype = dtypeOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
kind := gotch.Float
|
|
||||||
switch {
|
switch {
|
||||||
case c.value == 0.0:
|
case c.value == 0.0:
|
||||||
retVal = ts.MustZeros(dims, kind, device)
|
retVal = ts.MustZeros(dims, dtype, device)
|
||||||
case c.value == 1.0:
|
case c.value == 1.0:
|
||||||
retVal = ts.MustOnes(dims, kind, device)
|
retVal = ts.MustOnes(dims, dtype, device)
|
||||||
default:
|
default:
|
||||||
data := make([]float64, ts.FlattenDim(dims))
|
data := make([]float64, ts.FlattenDim(dims))
|
||||||
for i := range data {
|
for i := range data {
|
||||||
|
@ -68,17 +74,24 @@ type randnInit struct {
|
||||||
stdev float64
|
stdev float64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Init = new(randnInit)
|
||||||
|
|
||||||
func NewRandnInit(mean, stdev float64) randnInit {
|
func NewRandnInit(mean, stdev float64) randnInit {
|
||||||
return randnInit{mean, stdev}
|
return randnInit{mean, stdev}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||||
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
dtype := gotch.DefaultDType
|
||||||
if r.mean == 0 {
|
if len(dtypeOpt) > 0 {
|
||||||
return ts.MustRandn(dims, gotch.Float, device)
|
dtype = dtypeOpt[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
initTs := ts.MustRandn(dims, gotch.Float, device)
|
// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
|
||||||
|
if r.mean == 0 {
|
||||||
|
return ts.MustRandn(dims, dtype, device)
|
||||||
|
}
|
||||||
|
|
||||||
|
initTs := ts.MustRandn(dims, dtype, device)
|
||||||
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -101,14 +114,20 @@ type uniformInit struct {
|
||||||
up float64
|
up float64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Init = new(uniformInit)
|
||||||
|
|
||||||
func NewUniformInit(lo, up float64) uniformInit {
|
func NewUniformInit(lo, up float64) uniformInit {
|
||||||
return uniformInit{lo, up}
|
return uniformInit{lo, up}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||||
|
dtype := gotch.DefaultDType
|
||||||
|
if len(dtypeOpt) > 0 {
|
||||||
|
dtype = dtypeOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
kind := gotch.Float
|
retVal = ts.MustZeros(dims, dtype, device)
|
||||||
retVal = ts.MustZeros(dims, kind, device)
|
|
||||||
retVal.Uniform_(u.lo, u.up)
|
retVal.Uniform_(u.lo, u.up)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
|
||||||
|
@ -174,6 +193,8 @@ type kaimingUniformInit struct {
|
||||||
NonLinearity string
|
NonLinearity string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var _ Init = new(kaimingUniformInit)
|
||||||
|
|
||||||
func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
|
func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
|
||||||
o := DefaultKaimingOptions()
|
o := DefaultKaimingOptions()
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
|
@ -187,7 +208,12 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
|
func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
|
||||||
|
dtype := gotch.DefaultDType
|
||||||
|
if len(dtypeOpt) > 0 {
|
||||||
|
dtype = dtypeOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
fanIn, _, err := CalculateFans(dims)
|
fanIn, _, err := CalculateFans(dims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
@ -204,8 +230,7 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retV
|
||||||
// Calculate uniform bounds from standard deviation
|
// Calculate uniform bounds from standard deviation
|
||||||
bound := math.Sqrt(3.0) * std
|
bound := math.Sqrt(3.0) * std
|
||||||
|
|
||||||
kind := gotch.Float
|
retVal = ts.MustZeros(dims, dtype, device)
|
||||||
retVal = ts.MustZeros(dims, kind, device)
|
|
||||||
retVal.Uniform_(-bound, bound)
|
retVal.Uniform_(-bound, bound)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
|
|
29
nn/linear.go
29
nn/linear.go
|
@ -40,11 +40,14 @@ type Linear struct {
|
||||||
// outDim - output dimension (y) [output features - columns]
|
// outDim - output dimension (y) [output features - columns]
|
||||||
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
||||||
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||||
|
dtype := gotch.DefaultDType
|
||||||
var bs *ts.Tensor
|
var bs *ts.Tensor
|
||||||
// bs has size of output dimension
|
// bs has size of output dimension
|
||||||
switch c.Bias {
|
switch c.Bias {
|
||||||
case false:
|
case false:
|
||||||
bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
|
// FIXME. do we need this? or just remove it and in the `Forward` creating on-fly
|
||||||
|
// with same dtype and device to the input.
|
||||||
|
bs = ts.MustZeros([]int64{outDim}, dtype, vs.Device())
|
||||||
case true:
|
case true:
|
||||||
switch {
|
switch {
|
||||||
case c.BsInit == nil:
|
case c.BsInit == nil:
|
||||||
|
@ -87,20 +90,19 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
||||||
//
|
//
|
||||||
// Example:
|
// Example:
|
||||||
//
|
//
|
||||||
// inDim := 3
|
// inDim := 3
|
||||||
// outDim := 2
|
// outDim := 2
|
||||||
// batchSize := 4
|
// batchSize := 4
|
||||||
// weights: 2x3
|
// weights: 2x3
|
||||||
// [ 1 1 1
|
// [ 1 1 1
|
||||||
// 1 1 1 ]
|
// 1 1 1 ]
|
||||||
//
|
//
|
||||||
// input node: 3x4
|
// input node: 3x4
|
||||||
// [ 1 1 1
|
// [ 1 1 1
|
||||||
// 1 1 1
|
// 1 1 1
|
||||||
// 1 1 1
|
// 1 1 1
|
||||||
// 1 1 1 ]
|
// 1 1 1 ]
|
||||||
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
||||||
|
|
||||||
mul := xs.MustMatmul(l.Ws, false)
|
mul := xs.MustMatmul(l.Ws, false)
|
||||||
return mul.MustAdd(l.Bs, true)
|
return mul.MustAdd(l.Bs, true)
|
||||||
}
|
}
|
||||||
|
@ -109,7 +111,6 @@ func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
|
||||||
//
|
//
|
||||||
// NOTE: train param will not be used.
|
// NOTE: train param will not be used.
|
||||||
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
|
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
|
||||||
|
|
||||||
mul := xs.MustMatmul(l.Ws, false)
|
mul := xs.MustMatmul(l.Ws, false)
|
||||||
return mul.MustAdd(l.Bs, true)
|
return mul.MustAdd(l.Bs, true)
|
||||||
}
|
}
|
||||||
|
|
|
@ -68,7 +68,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tenso
|
||||||
reduction := options.Reduction
|
reduction := options.Reduction
|
||||||
ignoreIndex := options.IgnoreIndex
|
ignoreIndex := options.IgnoreIndex
|
||||||
|
|
||||||
logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
|
logSm := logits.MustLogSoftmax(-1, dtype, false)
|
||||||
loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
|
loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
|
||||||
ws.MustDrop()
|
ws.MustDrop()
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"math"
|
"math"
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
"github.com/sugarme/gotch/ts"
|
"github.com/sugarme/gotch/ts"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -418,6 +417,10 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
|
||||||
)
|
)
|
||||||
|
|
||||||
device := opt.varstore.device
|
device := opt.varstore.device
|
||||||
|
|
||||||
|
// FIXME. What about mixed-precision?
|
||||||
|
dtype := parameters[0].DType()
|
||||||
|
|
||||||
if o.NormType == math.Inf(1) {
|
if o.NormType == math.Inf(1) {
|
||||||
for _, v := range opt.varstore.vars {
|
for _, v := range opt.varstore.vars {
|
||||||
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
|
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
|
||||||
|
@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
|
||||||
|
|
||||||
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
|
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
|
||||||
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
|
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
|
||||||
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
|
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
|
||||||
norms = append(norms, x)
|
norms = append(norms, x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
|
// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
|
||||||
// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
||||||
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
|
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
|
||||||
for _, x := range norms {
|
for _, x := range norms {
|
||||||
x.MustDrop()
|
x.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
|
@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
|
||||||
|
|
||||||
layerDim := l.config.NumLayers * numDirections
|
layerDim := l.config.NumLayers * numDirections
|
||||||
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
shape := []int64{layerDim, batchDim, l.hiddenDim}
|
||||||
zeros := ts.MustZeros(shape, gotch.Float, l.device)
|
|
||||||
|
dtype := l.flatWeights[0].DType()
|
||||||
|
zeros := ts.MustZeros(shape, dtype, l.device)
|
||||||
|
|
||||||
retVal := &LSTMState{
|
retVal := &LSTMState{
|
||||||
Tensor1: zeros.MustShallowClone(),
|
Tensor1: zeros.MustShallowClone(),
|
||||||
|
@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
|
||||||
layerDim := g.config.NumLayers * numDirections
|
layerDim := g.config.NumLayers * numDirections
|
||||||
shape := []int64{layerDim, batchDim, g.hiddenDim}
|
shape := []int64{layerDim, batchDim, g.hiddenDim}
|
||||||
|
|
||||||
tensor := ts.MustZeros(shape, gotch.Float, g.device)
|
dtype := g.flatWeights[0].DType()
|
||||||
|
tensor := ts.MustZeros(shape, dtype, g.device)
|
||||||
|
|
||||||
return &GRUState{Tensor: tensor}
|
return &GRUState{Tensor: tensor}
|
||||||
}
|
}
|
||||||
|
|
|
@ -417,12 +417,21 @@ func (vs *VarStore) Summary() {
|
||||||
layers = append(layers, name)
|
layers = append(layers, name)
|
||||||
}
|
}
|
||||||
sort.Strings(layers)
|
sort.Strings(layers)
|
||||||
|
var dtype gotch.DType
|
||||||
|
isFirst := true
|
||||||
for _, l := range layers {
|
for _, l := range layers {
|
||||||
var x *ts.Tensor
|
var x *ts.Tensor
|
||||||
var isBuffer bool
|
var isBuffer bool
|
||||||
for name, v := range vars {
|
for name, v := range vars {
|
||||||
if name == l {
|
if name == l {
|
||||||
x = v.Tensor
|
x = v.Tensor
|
||||||
|
|
||||||
|
// Get DType of first tensor for representation only
|
||||||
|
if isFirst {
|
||||||
|
dtype = x.DType()
|
||||||
|
}
|
||||||
|
isFirst = false
|
||||||
|
|
||||||
isBuffer = v.Type == "buffer"
|
isBuffer = v.Type == "buffer"
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -435,25 +444,19 @@ func (vs *VarStore) Summary() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Num of layers: %v\n", len(vars))
|
fmt.Printf("Num of layers: %v\n", len(vars))
|
||||||
|
fmt.Printf("DType: %v\n", dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToDType casts all variables in VarStore to specified DType.
|
// ToDType casts all variables in VarStore to specified DType.
|
||||||
//
|
//
|
||||||
// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
|
// NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible.
|
||||||
func (vs *VarStore) ToDType(dtype gotch.DType) {
|
func (vs *VarStore) ToDType(dtype gotch.DType) {
|
||||||
vs.Root().ToDType(dtype)
|
vs.Root().ToDType(dtype)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
|
|
||||||
//
|
|
||||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
|
||||||
func (vs *VarStore) ToHalf() {
|
|
||||||
vs.Root().ToHalf()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
|
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
|
||||||
//
|
//
|
||||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
|
||||||
func (vs *VarStore) ToFloat() {
|
func (vs *VarStore) ToFloat() {
|
||||||
vs.Root().ToFloat()
|
vs.Root().ToFloat()
|
||||||
}
|
}
|
||||||
|
@ -465,6 +468,20 @@ func (vs *VarStore) ToDouble() {
|
||||||
vs.Root().ToDouble()
|
vs.Root().ToDouble()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
|
||||||
|
//
|
||||||
|
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||||
|
func (vs *VarStore) ToHalf() {
|
||||||
|
vs.Root().ToHalf()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
|
||||||
|
//
|
||||||
|
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||||
|
func (vs *VarStore) ToBFloat16() {
|
||||||
|
vs.Root().ToBFloat16()
|
||||||
|
}
|
||||||
|
|
||||||
// Path methods:
|
// Path methods:
|
||||||
// =============
|
// =============
|
||||||
|
|
||||||
|
@ -664,7 +681,7 @@ func (p *Path) SetGroup(g uint) {
|
||||||
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
|
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
|
||||||
//
|
//
|
||||||
// NOTE. this method should be used for floating-point conversion, i.e.,
|
// NOTE. this method should be used for floating-point conversion, i.e.,
|
||||||
// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
|
// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
|
||||||
func (p *Path) ToDType(dtype gotch.DType) {
|
func (p *Path) ToDType(dtype gotch.DType) {
|
||||||
p.varstore.Lock()
|
p.varstore.Lock()
|
||||||
defer p.varstore.Unlock()
|
defer p.varstore.Unlock()
|
||||||
|
@ -686,7 +703,7 @@ func (p *Path) toFloat(dtype gotch.DType) {
|
||||||
for name, v := range p.varstore.vars {
|
for name, v := range p.varstore.vars {
|
||||||
if strings.Contains(name, path) {
|
if strings.Contains(name, path) {
|
||||||
dtype := v.Tensor.DType()
|
dtype := v.Tensor.DType()
|
||||||
if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
|
if gotch.IsFloatDType(dtype) {
|
||||||
newVar := v
|
newVar := v
|
||||||
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
|
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
|
||||||
p.varstore.vars[name] = newVar
|
p.varstore.vars[name] = newVar
|
||||||
|
@ -695,19 +712,37 @@ func (p *Path) toFloat(dtype gotch.DType) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToHalf casts all variables in current path and subpaths to `Half` precision.
|
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
||||||
|
func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
|
||||||
|
dtype := gotch.Float
|
||||||
|
if len(floatDTypeOpt) > 0 {
|
||||||
|
dt := floatDTypeOpt[0]
|
||||||
|
if !gotch.IsFloatDType(dt) {
|
||||||
|
// Ingore the option
|
||||||
|
if gotch.Debug {
|
||||||
|
log.Printf("WARNING: nn.Path.ToFloat() input dtype is invalid float DType %v. Just ignoring...\n", dt)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
dtype = dt
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.toFloat(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
|
||||||
|
func (p *Path) ToDouble() {
|
||||||
|
p.toFloat(gotch.Double)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
|
||||||
func (p *Path) ToHalf() {
|
func (p *Path) ToHalf() {
|
||||||
p.toFloat(gotch.Half)
|
p.toFloat(gotch.Half)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
// ToBFloat16() converts all variables in current path and subpaths to `BFloat16` dtype.
|
||||||
func (p *Path) ToFloat() {
|
func (p *Path) ToBFloat16() {
|
||||||
p.toFloat(gotch.Float)
|
p.toFloat(gotch.BFloat16)
|
||||||
}
|
|
||||||
|
|
||||||
// ToDouble casts all variables in current path and subpaths to `Double` precision.
|
|
||||||
func (p *Path) ToDouble() {
|
|
||||||
p.toFloat(gotch.Double)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ZerosNoTrain creates a new variable initialized with zeros.
|
// ZerosNoTrain creates a new variable initialized with zeros.
|
||||||
|
@ -718,7 +753,8 @@ func (p *Path) ToDouble() {
|
||||||
// The variable uses a float tensor initialized with zeros.
|
// The variable uses a float tensor initialized with zeros.
|
||||||
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||||
device := p.Device()
|
device := p.Device()
|
||||||
z, err := ts.Zeros(dims, gotch.Float, device)
|
dtype := gotch.DefaultDType
|
||||||
|
z, err := ts.Zeros(dims, dtype, device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
|
err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -755,7 +791,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.T
|
||||||
// The variable uses a float tensor initialized with ones.
|
// The variable uses a float tensor initialized with ones.
|
||||||
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||||
device := p.Device()
|
device := p.Device()
|
||||||
z, err := ts.Ones(dims, gotch.Float, device)
|
dtype := gotch.DefaultDType
|
||||||
|
z, err := ts.Ones(dims, dtype, device)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
|
err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -792,7 +829,8 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
|
||||||
// The variable uses a float tensor initialized as per the
|
// The variable uses a float tensor initialized as per the
|
||||||
// related argument.
|
// related argument.
|
||||||
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
|
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
|
||||||
v := ini.InitTensor(dims, p.varstore.device)
|
dtype := gotch.DefaultDType
|
||||||
|
v := ini.InitTensor(dims, p.varstore.device, dtype)
|
||||||
out, err := p.Add(name, v, true, opts...)
|
out, err := p.Add(name, v, true, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1098,7 +1136,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
|
||||||
|
|
||||||
// OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
|
// OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
|
||||||
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||||
o := ts.MustOnes(dims, gotch.Float, e.path.Device())
|
dtype := gotch.DefaultDType
|
||||||
|
o := ts.MustOnes(dims, dtype, e.path.Device())
|
||||||
out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
|
out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -1161,7 +1200,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.
|
||||||
|
|
||||||
// OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
|
// OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
|
||||||
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
|
||||||
z := ts.MustZeros(dims, gotch.Float, e.path.Device())
|
dtype := gotch.DefaultDType
|
||||||
|
z := ts.MustZeros(dims, dtype, e.path.Device())
|
||||||
out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
|
out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Reference in New Issue
Block a user