added dtype option to nn package

This commit is contained in:
parent 523061eca6
commit 34e87b1302
dtype.go | 27

@@ -394,3 +394,30 @@ func DTypeFromData(data interface{}) (DType, error) {
 	// single element
 	return GoKind2DType(dataKind)
 }
+
+// IsFloatDType returns whether dtype is a floating-point data type.
+func IsFloatDType(dtype DType) bool {
+	switch dtype {
+	case Double, Float, Half, BFloat16:
+		return true
+
+	default:
+		return false
+	}
+}
+
+// Default DType:
+// ==============
+var DefaultDType DType = Float
+
+// SetDefaultDType sets DefaultDType to a new value and returns the previous one.
+func SetDefaultDType(dtype DType) DType {
+	odtype := DefaultDType
+	DefaultDType = dtype
+
+	if Debug {
+		log.Printf("INFO: gotch 'DefaultDType' has changed to %v. Remember to change it back to the previous default to avoid unexpected outcomes.\n", dtype)
+	}
+
+	return odtype
+}
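For context, a minimal usage sketch of the new default-dtype API (the constants and functions come from this commit and the existing gotch package; the example itself is illustrative, not part of the commit):

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
    )

    func main() {
        // Temporarily switch the package-wide default dtype, then restore
        // the previous value, as the Debug log message advises.
        prev := gotch.SetDefaultDType(gotch.Half)
        defer gotch.SetDefaultDType(prev)

        fmt.Println(gotch.IsFloatDType(gotch.Half))  // true
        fmt.Println(gotch.IsFloatDType(gotch.Int64)) // false
    }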
Binary file not shown.
Binary file not shown.
(deleted file; 67 lines)

@@ -1,67 +0,0 @@
-package main
-
-import (
-	"fmt"
-	"runtime/metrics"
-	// "math/rand"
-	// "time"
-
-	// "github.com/sugarme/gotch"
-	// "github.com/sugarme/gotch/ts"
-)
-
-// Run with: `go build -gcflags="-m=3"`
-
-//go:noinline
-func main() {
-	s := []metrics.Sample{{Name: "/gc/stack/starting-size:bytes"}}
-	metrics.Read(s)
-	fmt.Printf("Initial stack size: %d\n", s[0].Value.Uint64())
-
-	// x, err := ts.Randn([]int64{2, 3, 224, 224}, gotch.Float, gotch.CPU)
-	// if err != nil {
-	// 	panic(err)
-	// }
-	// fmt.Printf("x: %v\n", x.Name())
-
-	// x := ts.MustOfSlice([]float32{1, 2, 3})
-	// fmt.Printf("x: %v\n", x.Name())
-
-	x := new(foo)
-	// x := newFoo()
-	// x := &foo{
-	// 	name: "foo",
-	// 	f:    &foo1{foo1Name: "foo1"},
-	// }
-
-	fmt.Printf("x: %q\n", x.name)
-
-	// b := new(ts.Tensor)
-	// fmt.Printf("b: %v\n", b)
-
-	// time.Sleep(time.Second * 2)
-}
-
-type foo struct {
-	data [1e4]interface{} // 10_000 * 16 bytes = 160_000 bytes
-	name string
-	// f *foo1
-}
-
-type foo1 struct {
-	foo1Name string
-}
-
-func newFoo() *foo {
-	return new(foo)
-}
-
-// func newData() []float32 {
-// 	// n := 3 * 224 * 224 * 12
-// 	n := 3
-// 	data := make([]float32, n)
-// 	for i := 0; i < n; i++ {
-// 		data[i] = rand.Float32()
-// 	}
-//
-// 	return data
-// }
Binary file not shown.
(modified file)

@@ -1,6 +1,7 @@
 package main
 
 import (
 	"fmt"
+	"log"
 
 	"github.com/sugarme/gotch"

@@ -25,8 +26,10 @@ func main() {
 		panic(err)
 	}
 
-	err = pickle.LoadInfo(modelFile)
+	m, err := pickle.LoadModelInfo(modelFile)
 	if err != nil {
 		log.Fatal(err)
 	}
+
+	fmt.Println(m)
 }
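For reference, a minimal sketch of the updated pickle call as a standalone program (the `LoadModelInfo` signature is taken from this diff; the checkpoint path is a hypothetical placeholder):

    package main

    import (
        "fmt"
        "log"

        "github.com/sugarme/gotch/pickle"
    )

    func main() {
        modelFile := "./model.pt" // hypothetical local checkpoint
        m, err := pickle.LoadModelInfo(modelFile)
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println(m)
    }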
nn/init.go | 57
@@ -12,7 +12,7 @@ import (
 
 type Init interface {
 	// creates a new tensor with specified initialization
-	InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor)
+	InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor)
 
 	// re-initializes (in-place) an existing tensor with the specified initialization
 	Set(tensor *ts.Tensor)

@@ -25,18 +25,24 @@ type constInit struct {
 	value float64
 }
 
+var _ Init = new(constInit)
+
 func NewConstInit(v float64) constInit {
 	return constInit{v}
 }
 
-func (c constInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (c constInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+	dtype := gotch.DefaultDType
+	if len(dtypeOpt) > 0 {
+		dtype = dtypeOpt[0]
+	}
+
 	var err error
-	kind := gotch.Float
 	switch {
 	case c.value == 0.0:
-		retVal = ts.MustZeros(dims, kind, device)
+		retVal = ts.MustZeros(dims, dtype, device)
 	case c.value == 1.0:
-		retVal = ts.MustOnes(dims, kind, device)
+		retVal = ts.MustOnes(dims, dtype, device)
 	default:
 		data := make([]float64, ts.FlattenDim(dims))
 		for i := range data {

@@ -68,17 +74,24 @@ type randnInit struct {
 	stdev float64
 }
 
+var _ Init = new(randnInit)
+
 func NewRandnInit(mean, stdev float64) randnInit {
 	return randnInit{mean, stdev}
 }
 
-func (r randnInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (r randnInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+	dtype := gotch.DefaultDType
+	if len(dtypeOpt) > 0 {
+		dtype = dtypeOpt[0]
+	}
+
 	// if r.mean == 0 && math.Abs(r.stdev-1) <= math.SmallestNonzeroFloat64 {
 	if r.mean == 0 {
-		return ts.MustRandn(dims, gotch.Float, device)
+		return ts.MustRandn(dims, dtype, device)
 	}
 
-	initTs := ts.MustRandn(dims, gotch.Float, device)
+	initTs := ts.MustRandn(dims, dtype, device)
 	return initTs.MustMulScalar(ts.FloatScalar(r.stdev), true).MustAddScalar(ts.FloatScalar(r.mean), true)
 }
 

@@ -101,14 +114,20 @@ type uniformInit struct {
 	up float64
 }
 
+var _ Init = new(uniformInit)
+
 func NewUniformInit(lo, up float64) uniformInit {
 	return uniformInit{lo, up}
 }
 
-func (u uniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (u uniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+	dtype := gotch.DefaultDType
+	if len(dtypeOpt) > 0 {
+		dtype = dtypeOpt[0]
+	}
+
 	var err error
-	kind := gotch.Float
-	retVal = ts.MustZeros(dims, kind, device)
+	retVal = ts.MustZeros(dims, dtype, device)
 	retVal.Uniform_(u.lo, u.up)
 	if err != nil {
 		log.Fatalf("uniformInit - InitTensor method call error: %v\n", err)

@@ -174,6 +193,8 @@ type kaimingUniformInit struct {
 	NonLinearity string
 }
 
+var _ Init = new(kaimingUniformInit)
+
 func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
 	o := DefaultKaimingOptions()
 	for _, opt := range opts {

@@ -187,7 +208,12 @@ func NewKaimingUniformInit(opts ...KaimingOption) *kaimingUniformInit {
 	}
 }
 
-func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal *ts.Tensor) {
+func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device, dtypeOpt ...gotch.DType) (retVal *ts.Tensor) {
+	dtype := gotch.DefaultDType
+	if len(dtypeOpt) > 0 {
+		dtype = dtypeOpt[0]
+	}
+
 	fanIn, _, err := CalculateFans(dims)
 	if err != nil {
 		panic(err)

@@ -204,8 +230,7 @@ func (k *kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retV
 	// Calculate uniform bounds from standard deviation
 	bound := math.Sqrt(3.0) * std
 
-	kind := gotch.Float
-	retVal = ts.MustZeros(dims, kind, device)
+	retVal = ts.MustZeros(dims, dtype, device)
 	retVal.Uniform_(-bound, bound)
 
 	return retVal
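The variadic `dtypeOpt` parameter keeps old call sites compiling while allowing an explicit dtype. A small sketch of both call styles (illustrative only; `NewKaimingUniformInit` and `InitTensor` are from this diff):

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        init := nn.NewKaimingUniformInit()

        // No option: falls back to gotch.DefaultDType (Float unless changed).
        w := init.InitTensor([]int64{128, 64}, gotch.CPU)
        fmt.Println(w.DType())

        // Explicit option: overrides the default for this tensor only.
        wHalf := init.InitTensor([]int64{128, 64}, gotch.CPU, gotch.Half)
        fmt.Println(wHalf.DType())
    }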
nn/linear.go | 29
@@ -40,11 +40,14 @@ type Linear struct {
 // outDim - output dimension (y) [output features - columns]
 // NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
 func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
+	dtype := gotch.DefaultDType
 	var bs *ts.Tensor
 	// bs has size of output dimension
 	switch c.Bias {
 	case false:
-		bs = ts.MustZeros([]int64{outDim}, gotch.Float, vs.Device())
+		// FIXME: do we need this? Or remove it and create bs on the fly in
+		// `Forward`, with the same dtype and device as the input.
+		bs = ts.MustZeros([]int64{outDim}, dtype, vs.Device())
 	case true:
 		switch {
 		case c.BsInit == nil:

@@ -87,20 +90,19 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
 //
 // Example:
 //
-// inDim := 3
-// outDim := 2
-// batchSize := 4
-// weights: 2x3
-// [ 1 1 1
-//   1 1 1 ]
+//	inDim := 3
+//	outDim := 2
+//	batchSize := 4
+//	weights: 2x3
+//	[ 1 1 1
+//	  1 1 1 ]
 //
-// input node: 3x4
-// [ 1 1 1
-//   1 1 1
-//   1 1 1
-//   1 1 1 ]
+//	input node: 3x4
+//	[ 1 1 1
+//	  1 1 1
+//	  1 1 1
+//	  1 1 1 ]
 func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
-
 	mul := xs.MustMatmul(l.Ws, false)
 	return mul.MustAdd(l.Bs, true)
 }

@@ -109,7 +111,6 @@ func (l *Linear) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
 //
 // NOTE: train param will not be used.
 func (l *Linear) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
-
 	mul := xs.MustMatmul(l.Ws, false)
 	return mul.MustAdd(l.Bs, true)
 }
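A short usage sketch of `Linear` (a minimal illustration, assuming the existing `nn.NewVarStore` and `nn.DefaultLinearConfig` helpers; shapes follow the doc comment above):

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
        "github.com/sugarme/gotch/ts"
    )

    func main() {
        vs := nn.NewVarStore(gotch.CPU)
        linear := nn.NewLinear(vs.Root(), 3, 2, nn.DefaultLinearConfig())

        // Batch of 4 samples with 3 features each.
        x := ts.MustRand([]int64{4, 3}, gotch.Float, gotch.CPU)
        y := linear.Forward(x)
        fmt.Println(y.MustSize()) // [4 2]
    }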
nn/loss.go

@@ -68,7 +68,7 @@ func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tenso
 	reduction := options.Reduction
 	ignoreIndex := options.IgnoreIndex
 
-	logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
+	logSm := logits.MustLogSoftmax(-1, dtype, false)
 	loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
 	ws.MustDrop()
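Here `dtype` is presumably derived from the logits' dtype elsewhere in the function (outside this hunk), so the log-softmax no longer silently upcasts Half/BFloat16 logits to Float. A call sketch (illustrative; the `CrossEntropyLoss` signature is taken from the hunk header):

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
        "github.com/sugarme/gotch/ts"
    )

    func main() {
        // 4 samples, 10 classes; integer class targets.
        logits := ts.MustRandn([]int64{4, 10}, gotch.Float, gotch.CPU)
        target := ts.MustOfSlice([]int64{1, 0, 3, 7})

        loss := nn.CrossEntropyLoss(logits, target)
        fmt.Println(loss.Float64Values())
    }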
nn/optimizer.go

@@ -7,7 +7,6 @@ import (
 	"log"
 	"math"
 
-	"github.com/sugarme/gotch"
 	"github.com/sugarme/gotch/ts"
 )
 

@@ -418,6 +417,10 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
 	)
 
 	device := opt.varstore.device
+
+	// FIXME: what about mixed precision?
+	dtype := parameters[0].DType()
+
 	if o.NormType == math.Inf(1) {
 		for _, v := range opt.varstore.vars {
 			n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)

@@ -431,14 +434,14 @@ func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
 
 			// NOTE. tensor.Norm() is going to be deprecated, so use linalg_norm instead.
 			// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
-			x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+			x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
 			norms = append(norms, x)
 		}
 	}
 
 	// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
 	// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
-	totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+	totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, dtype, true)
 	for _, x := range norms {
 		x.MustDrop()
 	}
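A typical call site for the updated `ClipGradNorm` in a training step (a sketch, not part of this commit; it assumes a scalar `loss` tensor from the training loop and the optimizer's Must* helpers):

    vs := nn.NewVarStore(gotch.CPU)
    // ... build the model on vs.Root() ...
    opt, err := nn.DefaultAdamConfig().Build(vs, 1e-3)
    if err != nil {
        log.Fatal(err)
    }

    // Backward pass, clip the global grad norm to 1.0, then update.
    opt.MustZeroGrad()
    loss.MustBackward()
    if err := opt.ClipGradNorm(1.0); err != nil {
        log.Fatal(err)
    }
    opt.MustStep()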
nn/rnn.go

@@ -149,7 +149,9 @@ func (l *LSTM) ZeroState(batchDim int64) State {
 
 	layerDim := l.config.NumLayers * numDirections
 	shape := []int64{layerDim, batchDim, l.hiddenDim}
-	zeros := ts.MustZeros(shape, gotch.Float, l.device)
+
+	dtype := l.flatWeights[0].DType()
+	zeros := ts.MustZeros(shape, dtype, l.device)
 
 	retVal := &LSTMState{
 		Tensor1: zeros.MustShallowClone(),

@@ -269,7 +271,8 @@ func (g *GRU) ZeroState(batchDim int64) State {
 	layerDim := g.config.NumLayers * numDirections
 	shape := []int64{layerDim, batchDim, g.hiddenDim}
 
-	tensor := ts.MustZeros(shape, gotch.Float, g.device)
+	dtype := g.flatWeights[0].DType()
+	tensor := ts.MustZeros(shape, dtype, g.device)
 
 	return &GRUState{Tensor: tensor}
 }
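With this change, recurrent zero states inherit the dtype of the layer weights instead of hard-coding Float, so a non-Float model stays consistent. A sketch (assumes `nn.NewLSTM`, `nn.DefaultRNNConfig`, and the `LSTMState` type from this package; illustrative only):

    package main

    import (
        "fmt"

        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        // Build the LSTM while DefaultDType is Half, then restore it.
        prev := gotch.SetDefaultDType(gotch.Half)
        vs := nn.NewVarStore(gotch.CPU)
        lstm := nn.NewLSTM(vs.Root(), 10, 20, nn.DefaultRNNConfig())
        gotch.SetDefaultDType(prev)

        // ZeroState now creates zeros with the weights' dtype (Half here).
        st := lstm.ZeroState(4).(*nn.LSTMState)
        fmt.Println(st.Tensor1.DType())
    }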
nn/varstore.go

@@ -417,12 +417,21 @@ func (vs *VarStore) Summary() {
 		layers = append(layers, name)
 	}
 	sort.Strings(layers)
+
+	var dtype gotch.DType
+	isFirst := true
 	for _, l := range layers {
 		var x *ts.Tensor
 		var isBuffer bool
 		for name, v := range vars {
 			if name == l {
 				x = v.Tensor
+
+				// Get the DType of the first tensor, for display only.
+				if isFirst {
+					dtype = x.DType()
+				}
+				isFirst = false
 				isBuffer = v.Type == "buffer"
 				break
 			}

@@ -435,25 +444,19 @@ func (vs *VarStore) Summary() {
 	}
 
 	fmt.Printf("Num of layers: %v\n", len(vars))
+	fmt.Printf("DType: %v\n", dtype)
 }
 
 // ToDType casts all variables in VarStore to specified DType.
 //
-// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
+// NOTE. only float-like types (Half, BFloat16, Float, Double) are guaranteed convertible.
 func (vs *VarStore) ToDType(dtype gotch.DType) {
 	vs.Root().ToDType(dtype)
 }
 
-// ToHalf casts all float-like variables in VarStore to `Half` dtype.
-//
-// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
-func (vs *VarStore) ToHalf() {
-	vs.Root().ToHalf()
-}
-
 // ToFloat casts all float-like variables in VarStore to `Float` dtype.
 //
-// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
 func (vs *VarStore) ToFloat() {
 	vs.Root().ToFloat()
 }

@@ -465,6 +468,20 @@ func (vs *VarStore) ToDouble() {
 	vs.Root().ToDouble()
 }
 
+// ToHalf casts all float-like variables in VarStore to `Half` dtype.
+//
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
+func (vs *VarStore) ToHalf() {
+	vs.Root().ToHalf()
+}
+
+// ToBFloat16 casts all float-like variables in VarStore to `BFloat16` dtype.
+//
+// NOTE. float-like includes `Half`, `BFloat16`, `Float` and `Double` dtype.
+func (vs *VarStore) ToBFloat16() {
+	vs.Root().ToBFloat16()
+}
+
 // Path methods:
 // =============

@@ -664,7 +681,7 @@ func (p *Path) SetGroup(g uint) {
 // ToDType casts all variables in this path and its sub-paths to the specified dtype.
 //
 // NOTE. this method should be used for floating-point conversion, i.e.,
-// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
+// "gotch.Float", "gotch.Half", "gotch.BFloat16", "gotch.Double".
 func (p *Path) ToDType(dtype gotch.DType) {
 	p.varstore.Lock()
 	defer p.varstore.Unlock()

@@ -686,7 +703,7 @@ func (p *Path) toFloat(dtype gotch.DType) {
 	for name, v := range p.varstore.vars {
 		if strings.Contains(name, path) {
 			currDType := v.Tensor.DType()
-			if currDType == gotch.Half || currDType == gotch.Float || currDType == gotch.Double {
+			if gotch.IsFloatDType(currDType) {
 				newVar := v
 				newVar.Tensor = v.Tensor.MustTotype(dtype, true)
 				p.varstore.vars[name] = newVar

@@ -695,19 +712,37 @@ func (p *Path) toFloat(dtype gotch.DType) {
 		}
 	}
 }
 
-// ToHalf casts all variables in current path and subpaths to `Half` precision.
-func (p *Path) ToHalf() {
-	p.toFloat(gotch.Half)
-}
-
-// ToFloat casts all variables in current path and subpaths to `Float` precision.
-func (p *Path) ToFloat() {
-	p.toFloat(gotch.Float)
-}
-
-// ToDouble casts all variables in current path and subpaths to `Double` precision.
-func (p *Path) ToDouble() {
-	p.toFloat(gotch.Double)
+// ToFloat casts all variables in current path and subpaths to `Float` precision (or to an optionally given float dtype).
+func (p *Path) ToFloat(floatDTypeOpt ...gotch.DType) {
+	dtype := gotch.Float
+	if len(floatDTypeOpt) > 0 {
+		dt := floatDTypeOpt[0]
+		if !gotch.IsFloatDType(dt) {
+			// Ignore the option
+			if gotch.Debug {
+				log.Printf("WARNING: nn.Path.ToFloat() input dtype %v is not a valid float DType. Ignoring...\n", dt)
+			}
+		} else {
+			dtype = dt
+		}
+	}
+
+	p.toFloat(dtype)
+}
+
+// ToDouble casts all variables in current path and subpaths to `Double` precision dtype.
+func (p *Path) ToDouble() {
+	p.toFloat(gotch.Double)
+}
+
+// ToHalf casts all variables in current path and subpaths to `Half` precision dtype.
+func (p *Path) ToHalf() {
+	p.toFloat(gotch.Half)
+}
+
+// ToBFloat16 converts all variables in current path and subpaths to `BFloat16` dtype.
+func (p *Path) ToBFloat16() {
+	p.toFloat(gotch.BFloat16)
 }
 
 // ZerosNoTrain creates a new variable initialized with zeros.

@@ -718,7 +753,8 @@ func (p *Path) ToDouble() {
 // The variable uses a float tensor initialized with zeros.
 func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
 	device := p.Device()
-	z, err := ts.Zeros(dims, gotch.Float, device)
+	dtype := gotch.DefaultDType
+	z, err := ts.Zeros(dims, dtype, device)
 	if err != nil {
 		err = fmt.Errorf("Path.ZerosNoTrain() failed: %w", err)
 		return nil, err

@@ -755,7 +791,8 @@ func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.T
 // The variable uses a float tensor initialized with ones.
 func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
 	device := p.Device()
-	z, err := ts.Ones(dims, gotch.Float, device)
+	dtype := gotch.DefaultDType
+	z, err := ts.Ones(dims, dtype, device)
 	if err != nil {
 		err = fmt.Errorf("Path.OnesNoTrain() failed: %w", err)
 		return nil, err

@@ -792,7 +829,8 @@ func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Te
 // The variable uses a float tensor initialized as per the
 // related argument.
 func (p *Path) NewVar(name string, dims []int64, ini Init, opts ...AddOpt) (*ts.Tensor, error) {
-	v := ini.InitTensor(dims, p.varstore.device)
+	dtype := gotch.DefaultDType
+	v := ini.InitTensor(dims, p.varstore.device, dtype)
 	out, err := p.Add(name, v, true, opts...)
 	if err != nil {
 		return nil, err

@@ -1098,7 +1136,8 @@ func (e *Entry) MustOrOnes(dims []int64, opts ...AddOpt) *ts.Tensor {
 
 // OrOnesNoTrain returns the existing entry if found, otherwise creates a new variable.
 func (e *Entry) OrOnesNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
-	o := ts.MustOnes(dims, gotch.Float, e.path.Device())
+	dtype := gotch.DefaultDType
+	o := ts.MustOnes(dims, dtype, e.path.Device())
 	out, err := e.path.getOrAddWithLock(e.name, o, true, opts...)
 	if err != nil {
 		return nil, err

@@ -1161,7 +1200,8 @@ func (e *Entry) MustOrUniform(dims []int64, lo, up float64, opts ...AddOpt) *ts.
 
 // OrZerosNoTrain returns the existing entry if found, otherwise creates a new variable.
 func (e *Entry) OrZerosNoTrain(dims []int64, opts ...AddOpt) (*ts.Tensor, error) {
-	z := ts.MustZeros(dims, gotch.Float, e.path.Device())
+	dtype := gotch.DefaultDType
+	z := ts.MustZeros(dims, dtype, e.path.Device())
 	out, err := e.path.getOrAddWithLock(e.name, z, true, opts...)
 	if err != nil {
 		return nil, err
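End-to-end, the new dtype plumbing can be exercised like this (a hedged sketch combining APIs from this commit; the model layer is a placeholder):

    package main

    import (
        "github.com/sugarme/gotch"
        "github.com/sugarme/gotch/nn"
    )

    func main() {
        vs := nn.NewVarStore(gotch.CPU)

        // New variables are created with gotch.DefaultDType (Float by default).
        nn.NewLinear(vs.Root(), 784, 128, nn.DefaultLinearConfig())

        // Cast every float-like variable after the fact...
        vs.ToBFloat16()

        // ...and Summary() now reports the (first tensor's) dtype.
        vs.Summary()
    }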