added dtype option to nn package

sugarme 2023-07-07 12:24:56 +10:00
parent 523061eca6
commit 34e87b1302
12 changed files with 164 additions and 129 deletions


@@ -394,3 +394,30 @@ func DTypeFromData(data interface{}) (DType, error) {
    // single element
    return GoKind2DType(dataKind)
}
+
+// IsFloatDType returns whether dtype is a floating-point data type.
+func IsFloatDType(dtype DType) bool {
+    switch dtype {
+    case Double, Float, Half, BFloat16:
+        return true
+    default:
+        return false
+    }
+}
+
+// Default DType:
+// ==============
+
+var DefaultDType DType = Float
+
+// SetDefaultDType sets DefaultDType to a new value and returns the previous one.
+func SetDefaultDType(dtype DType) DType {
+    odtype := DefaultDType
+    DefaultDType = dtype
+
+    if Debug {
+        log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change it back to the previous default to avoid unexpected outcomes.\n", dtype)
+    }
+
+    return odtype
+}
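
Aside (not part of the diff): a minimal sketch of how the new default-dtype helpers above might be used from application code. It assumes only the APIs added in this hunk plus gotch.CPU, the ts package, and fmt; the program is purely illustrative.

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/ts"
    )

    func main() {
        // Switch the package-wide default to Half and restore the previous value on exit.
        prev := gotch.SetDefaultDType(gotch.Half)
        defer gotch.SetDefaultDType(prev)

        // Code that falls back to gotch.DefaultDType now creates Half tensors.
        x := ts.MustZeros([]int64{2, 3}, gotch.DefaultDType, gotch.CPU)

        // x was created with the Half default; IsFloatDType(Half) reports true.
        fmt.Println(x.DType(), gotch.IsFloatDType(x.DType()))
    }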

Binary file not shown.

Binary file not shown.


@@ -1,67 +0,0 @@
-package main
-
-import (
-    "fmt"
-    "runtime/metrics"
-    // "math/rand"
-    // "time"
-
-    // "github.com/sugarme/gotch"
-    // "github.com/sugarme/gotch/ts"
-)
-
-// Run with: `go build -gcflags="-m=3"`
-
-//go:noinline
-func main() {
-    s := []metrics.Sample{{Name: "/gc/stack/starting-size:bytes"}}
-    metrics.Read(s)
-    fmt.Printf("Initial stack size: %d\n", s[0].Value.Uint64())
-
-    // x, err := ts.Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
-    // if err != nil {
-    //     panic(err)
-    // }
-    // fmt.Printf("x: %v\n", x.Name())
-
-    // x := ts.MustOfSlice([]float32{1, 2, 3})
-    // fmt.Printf("x: %v\n", x.Name())
-
-    x := new(foo)
-    // x := newFoo()
-    // x := &foo{
-    //     name: "foo",
-    //     f:    &foo1{foo1Name: "foo1"},
-    // }
-    fmt.Printf("x: %q\n", x.name)
-
-    // b := new(ts.Tensor)
-    // fmt.Printf("b: %v\n", b)
-
-    // time.Sleep(time.Second * 2)
-}
-
-type foo struct {
-    data [1e4]interface{} // 10_000 * 4 = 40_000 bytes
-    name string
-    // f *foo1
-}
-
-type foo1 struct {
-    foo1Name string
-}
-
-func newFoo() *foo {
-    return new(foo)
-}
-
-// func newData() []float32 {
-//     // n := 3 * 224 * 224 * 12
-//     n := 3
-//     data := make([]float32, n)
-//     for i := 0; i < n; i++ {
-//         data[i] = rand.Float32()
-//     }
-//
-//     return data
-// }

Binary file not shown.


@@ -1,6 +1,7 @@
package main

import (
+    "fmt"
    "log"

    "github.com/sugarme/gotch"
@@ -25,8 +26,10 @@ func main() {
        panic(err)
    }

-    err = pickle.LoadInfo(modelFile)
+    m, err := pickle.LoadModelInfo(modelFile)
    if err != nil {
        log.Fatal(err)
    }
+
+    fmt.Println(m)
}


@@ -12,7 +12,7 @@ import (
type Init interface {
    // creates a new tensor with specified initiation
-    InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
+    InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)

    // re-initializes (in-place) an existing tensor with the specified initiation
    Set(tensor *ts.Tensor)
@@ -25,18 +25,24 @@ type constInit struct {
    value float64
}

+var _ Init = new(constInit)
+
func NewConstInit(v float64) constInit {
    return constInit{v}
}

-func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
    var err error
-    kind := gotch.Float
    switch {
    case c.value == 0.0:
-        retVal = ts.MustZeros(dims, kind, device)
+        retVal = ts.MustZeros(dims, dtype, device)
    case c.value == 1.0:
-        retVal = ts.MustOnes(dims, kind, device)
+        retVal = ts.MustOnes(dims, dtype, device)
    default:
        data := make([]float64, ts.FlattenDim(dims))
        for i := range data {
@@ -68,17 +74,24 @@ type randnInit struct {
    stdev float64
}

+var _ Init = new(randnInit)
+
func NewRandnInit(mean, stdev float64) randnInit {
    return randnInit{mean, stdev}
}

-func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
    // if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
    if r.mean == 0 {
-        return ts.MustRandn(dims, gotch.Float, device)
+        return ts.MustRandn(dims, dtype, device)
    }
-    initTs := ts.MustRandn(dims, gotch.Float, device)
+
+    initTs := ts.MustRandn(dims, dtype, device)

    return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
}
@@ -101,14 +114,20 @@ type uniformInit struct {
    up float64
}

+var _ Init = new(uniformInit)
+
func NewUniformInit(lo, up float64) uniformInit {
    return uniformInit{lo, up}
}

-func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
    var err error
-    kind := gotch.Float
-    retVal = ts.MustZeros(dims, kind, device)
+    retVal = ts.MustZeros(dims, dtype, device)
    retVal.Uniform_(u.lo, u.up)
    if err != nil {
        log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)
@@ -174,6 +193,8 @@ type kaimingUniformInit struct {
    NonLinearity string
}

+var _ Init = new(kaimingUniformInit)
+
func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
    o := DefaultKaimingOptions()
    for _, opt := range opts {
@@ -187,7 +208,12 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
    }
}

-func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+    dtype := gotch.DefaultDType
+    if len(dtypeOpt) > 0 {
+        dtype = dtypeOpt[0]
+    }
+
    fanIn, _, err := CalculateFans(dims)
    if err != nil {
        panic(err)
@@ -204,8 +230,7 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
    // Calculate uniform bounds from standard deviation
    bound := math.Sqrt(3.0) * std

-    kind := gotch.Float
-    retVal = ts.MustZeros(dims, kind, device)
+    retVal = ts.MustZeros(dims, dtype, device)
    retVal.Uniform_(-bound, bound)

    return retVal
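
Aside (not part of the commit): a sketch of how the new variadic dtype option on Init.InitTensor can be used from calling code. It assumes the gotch and nn import paths seen elsewhere in this diff; shapes and values are arbitrary.

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        ini := nn.NewRandnInit(0.0, 1.0)

        // No dtype option: falls back to gotch.DefaultDType (Float unless changed).
        w := ini.InitTensor([]int64{4, 4}, gotch.CPU)

        // Explicit option: overrides the default for this tensor only.
        wDouble := ini.InitTensor([]int64{4, 4}, gotch.CPU, gotch.Double)

        fmt.Println(w.DType(), wDouble.DType())
    }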


@@ -40,11 +40,14 @@ type Linear struct {
// outDim - output dimension (y) [output features - columns]
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
+    dtype := gotch.DefaultDType
    var bs *ts.Tensor
    // bs has size of output dimension
    switch c.Bias {
    case false:
+        // FIXME: do we need this? Or remove it and create bs on the fly in `Forward`,
+        // with the same dtype and device as the input.
-        bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
+        bs = ts.MustZeros([]int64{outDim}, dtype, vs.Device())
    case true:
        switch {
        case c.BsInit == nil:
@@ -87,20 +90,19 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
//
// Example:
//
// inDim := 3
// outDim := 2
// batchSize := 4
// weights: 2x3
// [ 1 1 1
//   1 1 1 ]
//
// input node: 3x4
// [ 1 1 1
//   1 1 1
//   1 1 1
//   1 1 1 ]
func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
    mul := xs.MustMatmul(l.Ws, false)
    return mul.MustAdd(l.Bs, true)
}
@@ -109,7 +111,6 @@ func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
//
// NOTE: train param will not be used.
func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
    mul := xs.MustMatmul(l.Ws, false)
    return mul.MustAdd(l.Bs, true)
}
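
Another illustrative aside (not in the commit): with the change above, NewLinear picks up gotch.DefaultDType for its zero bias, and its weights follow the same default through Path.NewVar. A rough sketch, assuming nn.NewVarStore and nn.DefaultLinearConfig behave as in existing gotch examples (neither is part of this diff):

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        prev := gotch.SetDefaultDType(gotch.Double)
        defer gotch.SetDefaultDType(prev)

        vs := nn.NewVarStore(gotch.CPU)
        l := nn.NewLinear(vs.Root(), 3, 2, nn.DefaultLinearConfig())

        // Both should now report Double rather than the previously hard-coded Float.
        fmt.Println(l.Ws.DType(), l.Bs.DType())
    }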


@@ -68,7 +68,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
    reduction := options.Reduction
    ignoreIndex := options.IgnoreIndex

-    logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
+    logSm := logits.MustLogSoftmax(-1, dtype, false)
    loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
    ws.MustDrop()


@@ -7,7 +7,6 @@ import (
    "log"
    "math"

-    "github.com/sugarme/gotch"
    "github.com/sugarme/gotch/ts"
)
@@ -418,6 +417,10 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
    )
    device := opt.varstore.device

+    // FIXME. What about mixed-precision?
+    dtype := parameters[0].DType()
+
    if o.NormType == math.Inf(1) {
        for _, v := range opt.varstore.vars {
            n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
@@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
            // NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
            // Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
-            x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+            x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
            norms = append(norms, x)
        }
    }

    // totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
    // total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
-    totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+    totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
    for _, x := range norms {
        x.MustDrop()
    }


@@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
    layerDim := l.config.NumLayers * numDirections
    shape := []int64{layerDim, batchDim, l.hiddenDim}

-    zeros := ts.MustZeros(shape, gotch.Float, l.device)
+    dtype := l.flatWeights[0].DType()
+    zeros := ts.MustZeros(shape, dtype, l.device)

    retVal := &LSTMState{
        Tensor1: zeros.MustShallowClone(),
@@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
    layerDim := g.config.NumLayers * numDirections
    shape := []int64{layerDim, batchDim, g.hiddenDim}

-    tensor := ts.MustZeros(shape, gotch.Float, g.device)
+    dtype := g.flatWeights[0].DType()
+    tensor := ts.MustZeros(shape, dtype, g.device)

    return &GRUState{Tensor: tensor}
}


@@ -417,12 +417,21 @@ func (vs *VarStore) Summary() {
        layers = append(layers, name)
    }
    sort.Strings(layers)
+
+    var dtype gotch.DType
+    isFirst := true
    for _, l := range layers {
        var x *ts.Tensor
        var isBuffer bool
        for name, v := range vars {
            if name == l {
                x = v.Tensor
+                // Record the DType of the first tensor, for reporting only.
+                if isFirst {
+                    dtype = x.DType()
+                }
+                isFirst = false
                isBuffer = v.Type == "buffer"
                break
            }
@@ -435,25 +444,19 @@ func (vs *VarStore) Summary() {
    }

    fmt.Printf("Num of layers: %v\n", len(vars))
+    fmt.Printf("DType: %v\n", dtype)
}

// ToDType casts all variables in VarStore to specified DType.
//
-// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
+// NOTE. only float-like types (Half, BFloat16, Float, Double) can ensure convertible.
func (vs *VarStore) ToDType(dtype gotch.DType) {
    vs.Root().ToDType(dtype)
}

-// ToHalf casts all float-like variables in VarStore to `Half` dtype.
-//
-// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
-func (vs *VarStore) ToHalf() {
-    vs.Root().ToHalf()
-}
-
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
//
-// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
func (vs *VarStore) ToFloat() {
    vs.Root().ToFloat()
}
@@ -465,6 +468,20 @@ func (vs *VarStore) ToDouble() {
    vs.Root().ToDouble()
}

+// ToHalf casts all float-like variables in VarStore to `Half` dtype.
+//
+// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
+func (vs *VarStore) ToHalf() {
+    vs.Root().ToHalf()
+}
+
+// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
+//
+// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
+func (vs *VarStore) ToBFloat16() {
+    vs.Root().ToBFloat16()
+}
+
// Path methods:
// =============
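
Not part of the diff: a short sketch of the VarStore-level casting helpers after this change. nn.NewVarStore and gotch.CPU are assumed from the existing gotch API; with an empty store the calls are no-ops, they are shown only to illustrate the entry points.

    package main

    import (
        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        vs := nn.NewVarStore(gotch.CPU)

        // The float-like helpers convert float variables in place and leave
        // integer variables untouched; ToDType casts to an explicit target.
        vs.ToHalf()
        vs.ToBFloat16() // newly added in this commit
        vs.ToFloat()
        vs.ToDType(gotch.Double)
    }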
@@ -664,7 +681,7 @@ func (p *Path) SetGroup(g uint) {
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
//
// NOTE. this method should be used for floating-point conversion, i.e.,
-// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
+// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
func (p *Path) ToDType(dtype gotch.DType) {
    p.varstore.Lock()
    defer p.varstore.Unlock()
@@ -686,7 +703,7 @@ func (p *Path) toFloat(dtype gotch.DType) {
    for name, v := range p.varstore.vars {
        if strings.Contains(name, path) {
            dtype := v.Tensor.DType()
-            if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
+            if gotch.IsFloatDType(dtype) {
                newVar := v
                newVar.Tensor = v.Tensor.MustTotype(dtype, true)
                p.varstore.vars[name] = newVar
            }
        }
    }
@@ -695,19 +712,37 @@ func (p *Path) toFloat(dtype gotch.DType) {
-// ToHalf casts all variables in current path and subpaths to `Half` precision.
+// ToFloat casts all variables in current path and subpaths to `Float` precision.
+func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
+    dtype := gotch.Float
+    if len(floatDTypeOpt) > 0 {
+        dt := floatDTypeOpt[0]
+        if !gotch.IsFloatDType(dt) {
+            // Ignore the option
+            if gotch.Debug {
+                log.Printf("WARNING: nn.Path.ToFloat() input dtype is invalid float DType %v. Just ignoring...\n", dt)
+            }
+        } else {
+            dtype = dt
+        }
+    }
+
+    p.toFloat(dtype)
+}
+
+// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
+func (p *Path) ToDouble() {
+    p.toFloat(gotch.Double)
+}
+
+// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
func (p *Path) ToHalf() {
    p.toFloat(gotch.Half)
}

-// ToFloat casts all variables in current path and subpaths to `Float` precision.
-func (p *Path) ToFloat() {
-    p.toFloat(gotch.Float)
-}
-
-// ToDouble casts all variables in current path and subpaths to `Double` precision.
-func (p *Path) ToDouble() {
-    p.toFloat(gotch.Double)
-}
+// ToBFloat16() converts all variables in current path and subpaths to `BFloat16` dtype.
+func (p *Path) ToBFloat16() {
+    p.toFloat(gotch.BFloat16)
+}

// ZerosNoTrain creates a new variable initialized with zeros.
@@ -718,7 +753,8 @@ func (p *Path) ToDouble() {
// The variable uses a float tensor initialized with zeros.
func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
    device := p.Device()
-    z, err := ts.Zeros(dims, gotch.Float, device)
+    dtype := gotch.DefaultDType
+    z, err := ts.Zeros(dims, dtype, device)
    if err != nil {
        err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
        return nil, err
@@ -755,7 +791,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
// The variable uses a float tensor initialized with ones.
func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
    device := p.Device()
-    z, err := ts.Ones(dims, gotch.Float, device)
+    dtype := gotch.DefaultDType
+    z, err := ts.Ones(dims, dtype, device)
    if err != nil {
        err = fmt.Errorf("Path.OneNoTrain() failed: %w", err)
        return nil, err
@@ -792,7 +829,8 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
// The variable uses a float tensor initialized as per the
// related argument.
func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
-    v := ini.InitTensor(dims, p.varstore.device)
+    dtype := gotch.DefaultDType
+    v := ini.InitTensor(dims, p.varstore.device, dtype)
    out, err := p.Add(name, v, true, opts...)
    if err != nil {
        return nil, err
@@ -1098,7 +1136,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
// OrOnesNoTrain returns the existing entry if found, otherwise create a new variable.
func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
-    o := ts.MustOnes(dims, gotch.Float, e.path.Device())
+    dtype := gotch.DefaultDType
+    o := ts.MustOnes(dims, dtype, e.path.Device())
    out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
    if err != nil {
        return nil, err
@@ -1161,7 +1200,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.Tensor {
// OrZerosNoTrain returns the existing entry if found, otherwise create a new variable.
func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
-    z := ts.MustZeros(dims, gotch.Float, e.path.Device())
+    dtype := gotch.DefaultDType
+    z := ts.MustZeros(dims, dtype, e.path.Device())
    out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
    if err != nil {
        return nil, err
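
To close, an illustrative sketch (not part of the commit) of how the varstore-level defaults now behave: untrained variables created through Path, and variables created through an Init, pick up gotch.DefaultDType instead of a hard-coded gotch.Float. nn.NewVarStore, gotch.CPU, and the helper names outside this diff are assumed from the existing gotch API; shapes and values are arbitrary.

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        // Make BFloat16 the package-wide default, restoring the old value afterwards.
        prev := gotch.SetDefaultDType(gotch.BFloat16)
        defer gotch.SetDefaultDType(prev)

        vs := nn.NewVarStore(gotch.CPU)
        root := vs.Root()

        // Both creation paths below now honour gotch.DefaultDType.
        z := root.MustZerosNoTrain("pe", []int64{8, 16})
        w, err := root.NewVar("w", []int64{8, 16}, nn.NewRandnInit(0.0, 0.02))
        if err != nil {
            panic(err)
        }

        fmt.Println(z.DType(), w.DType()) // expected: BFloat16 for both
    }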