varstore reworked and update
This commit is contained in:
parent
d95eaba5b3
commit
5a6fac51f3
5
dtype.go
5
dtype.go
|
@ -3,6 +3,7 @@ package gotch
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
// "log"
|
||||
"reflect"
|
||||
)
|
||||
|
@ -31,6 +32,7 @@ var (
|
|||
Int DType = DType{reflect.TypeOf(int32(1))} // 3
|
||||
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
|
||||
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
|
||||
Half DType = DType{reflect.TypeOf(float32(1))} // 5
|
||||
Float DType = DType{reflect.TypeOf(float32(1))} // 6
|
||||
Double DType = DType{reflect.TypeOf(float64(1))} // 7
|
||||
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
|
||||
|
@ -45,6 +47,7 @@ var dtypeGoType = map[DType]reflect.Type{
|
|||
Int16: reflect.TypeOf(int16(1)),
|
||||
Int: reflect.TypeOf(int32(1)),
|
||||
Int64: reflect.TypeOf(int64(1)),
|
||||
Half: reflect.TypeOf(float32(1)),
|
||||
Float: reflect.TypeOf(float32(1)),
|
||||
Double: reflect.TypeOf(float64(1)),
|
||||
Bool: reflect.TypeOf(true),
|
||||
|
@ -87,6 +90,7 @@ var dtypeCInt = map[DType]CInt{
|
|||
Int16: 2,
|
||||
Int: 3,
|
||||
Int64: 4,
|
||||
Half: 5,
|
||||
Float: 6,
|
||||
Double: 7,
|
||||
Bool: 11,
|
||||
|
@ -137,6 +141,7 @@ var dtypeSize = map[DType]uint{
|
|||
Int16: 2,
|
||||
Int: 4,
|
||||
Int64: 8,
|
||||
Half: 4, // Should it be?
|
||||
Float: 4,
|
||||
Double: 8,
|
||||
Bool: 1,
|
||||
|
|
|
@ -42,7 +42,7 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De
|
|||
input.MustDrop()
|
||||
inputView.MustDrop()
|
||||
|
||||
forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||
forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueezeDim(0, true).MustSoftmax(-1, gotch.Float, true)
|
||||
sampledY := forwardTs.MustMultinomial(1, false, true)
|
||||
lastLabel = sampledY.Int64Values()[0]
|
||||
sampledY.MustDrop()
|
||||
|
|
|
@ -73,7 +73,7 @@ func fastResnet(p *nn.Path) *nn.SequentialT {
|
|||
|
||||
seq.Add(nn.NewLinear(p.Sub("linear"), 512, 10, nn.DefaultLinearConfig()))
|
||||
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
|
||||
return xs.MustMul1(ts.FloatScalar(0.125), false)
|
||||
return xs.MustMulScalar(ts.FloatScalar(0.125), false)
|
||||
}))
|
||||
|
||||
return seq
|
||||
|
|
|
@ -71,8 +71,13 @@ func runCNN1() {
|
|||
|
||||
var ds *vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDirNN)
|
||||
testImages := ds.TestImages
|
||||
testLabels := ds.TestLabels
|
||||
// ds.TrainImages [60000, 784]
|
||||
// ds.TrainLabels [60000, 784]
|
||||
testImages := ds.TestImages // [10000, 784]
|
||||
testLabels := ds.TestLabels // [10000, 784]
|
||||
|
||||
fmt.Printf("testImages: %v\n", testImages.MustSize())
|
||||
fmt.Printf("testLabels: %v\n", testLabels.MustSize())
|
||||
|
||||
device := gotch.CudaIfAvailable()
|
||||
vs := nn.NewVarStore(device)
|
||||
|
@ -87,16 +92,17 @@ func runCNN1() {
|
|||
startTime := time.Now()
|
||||
|
||||
for epoch := 0; epoch < epochsCNN; epoch++ {
|
||||
|
||||
totalSize := ds.TrainImages.MustSize()[0]
|
||||
samples := int(totalSize)
|
||||
// Shuffling
|
||||
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
|
||||
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
|
||||
index.MustDrop()
|
||||
|
||||
batches := samples / batchSize
|
||||
batchIndex := 0
|
||||
var epocLoss *ts.Tensor
|
||||
var epocLoss float64
|
||||
for i := 0; i < batches; i++ {
|
||||
start := batchIndex * batchSize
|
||||
size := batchSize
|
||||
|
@ -106,37 +112,33 @@ func runCNN1() {
|
|||
batchIndex += 1
|
||||
|
||||
// Indexing
|
||||
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||
bImages := imagesTs.Idx(narrowIndex)
|
||||
bLabels := labelsTs.Idx(narrowIndex)
|
||||
bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
|
||||
bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
|
||||
|
||||
bImages = bImages.MustTo(vs.Device(), true)
|
||||
bLabels = bLabels.MustTo(vs.Device(), true)
|
||||
|
||||
logits := net.ForwardT(bImages, true)
|
||||
bImages.MustDrop()
|
||||
loss := logits.CrossEntropyForLogits(bLabels)
|
||||
logits.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
|
||||
// loss = loss.MustSetRequiresGrad(true, false)
|
||||
loss = loss.MustSetRequiresGrad(true, true)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
epocLoss = loss.MustShallowClone()
|
||||
epocLoss.Detach_()
|
||||
|
||||
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
epocLoss = loss.Float64Values()[0]
|
||||
loss.MustDrop()
|
||||
}
|
||||
|
||||
// vs.Freeze()
|
||||
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
||||
// vs.Unfreeze()
|
||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
|
||||
if testAccuracy > bestAccuracy {
|
||||
bestAccuracy = testAccuracy
|
||||
}
|
||||
ts.NoGrad(func() {
|
||||
testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
|
||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
|
||||
if testAccuracy > bestAccuracy {
|
||||
bestAccuracy = testAccuracy
|
||||
}
|
||||
})
|
||||
|
||||
epocLoss.MustDrop()
|
||||
imagesTs.MustDrop()
|
||||
labelsTs.MustDrop()
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ func gramMatrix(m *ts.Tensor) *ts.Tensor {
|
|||
gram := mview.MustMatmul(mviewT, true)
|
||||
mviewT.MustDrop()
|
||||
|
||||
return gram.MustDiv1(ts.IntScalar(a*b*c*d), true)
|
||||
return gram.MustDivScalar(ts.IntScalar(a*b*c*d), true)
|
||||
}
|
||||
|
||||
func styleLoss(m1 *ts.Tensor, m2 *ts.Tensor) *ts.Tensor {
|
||||
|
@ -138,7 +138,7 @@ func main() {
|
|||
|
||||
vs := nn.NewVarStore(device)
|
||||
path := vs.Root()
|
||||
inputVar := path.VarCopy("img", contentImg)
|
||||
inputVar := path.MustVarCopy("img", contentImg)
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, LearningRate)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -168,7 +168,7 @@ func main() {
|
|||
t.MustDrop()
|
||||
}
|
||||
|
||||
lossMul := sLoss.MustMul1(styleWeight, true)
|
||||
lossMul := sLoss.MustMulScalar(styleWeight, true)
|
||||
loss := lossMul.MustAdd(cLoss, true)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
|
|
|
@ -248,8 +248,8 @@ func conv(vs *nn.Path, index uint, p int64, b *Block) (retVal1 int64, retVal2 in
|
|||
|
||||
var res *ts.Tensor
|
||||
if leaky {
|
||||
tmp2Mul := tmp2.MustMul1(ts.FloatScalar(0.1), false)
|
||||
res = tmp2.MustMax1(tmp2Mul, true)
|
||||
tmp2Mul := tmp2.MustMulScalar(ts.FloatScalar(0.1), false)
|
||||
res = tmp2.MustMaximum(tmp2Mul, true)
|
||||
tmp2Mul.MustDrop()
|
||||
} else {
|
||||
res = tmp2
|
||||
|
@ -434,7 +434,7 @@ func detect(xs *ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) *
|
|||
})
|
||||
|
||||
sliceApplyAndSet(xsTs, 0, 4, func(xs *ts.Tensor) *ts.Tensor {
|
||||
return xs.MustMul1(ts.IntScalar(stride), false)
|
||||
return xs.MustMulScalar(ts.IntScalar(stride), false)
|
||||
})
|
||||
|
||||
// TODO: delete all middle tensors.
|
||||
|
|
|
@ -3,14 +3,15 @@ package main
|
|||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
"log"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -273,7 +274,7 @@ func main() {
|
|||
|
||||
imgTmp1 := imageTs.MustUnsqueeze(0, true)
|
||||
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
|
||||
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
img := imgTmp2.MustDivScalar(ts.FloatScalar(255.0), true)
|
||||
predictTmp := model.ForwardT(img, false)
|
||||
|
||||
predictions := predictTmp.MustSqueeze(true)
|
||||
|
|
|
@ -41,10 +41,10 @@ type BatchNorm struct {
|
|||
func NewBatchNorm(vs *Path, nd uint, outDim int64, config *BatchNormConfig) *BatchNorm {
|
||||
return &BatchNorm{
|
||||
config: config,
|
||||
RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}),
|
||||
RunningVar: vs.OnesNoTrain("running_var", []int64{outDim}),
|
||||
Ws: vs.NewVar("weight", []int64{outDim}, config.WsInit),
|
||||
Bs: vs.NewVar("bias", []int64{outDim}, config.BsInit),
|
||||
RunningMean: vs.MustZerosNoTrain("running_mean", []int64{outDim}),
|
||||
RunningVar: vs.MustOnesNoTrain("running_var", []int64{outDim}),
|
||||
Ws: vs.MustNewVar("weight", []int64{outDim}, config.WsInit),
|
||||
Bs: vs.MustNewVar("bias", []int64{outDim}, config.BsInit),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -73,10 +73,10 @@ func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
|
|||
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
|
||||
return &ConvTranspose1D{
|
||||
|
@ -104,11 +104,11 @@ func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
|
|||
)
|
||||
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return &ConvTranspose2D{
|
||||
Ws: ws,
|
||||
|
@ -134,11 +134,11 @@ func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
|
|||
)
|
||||
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return &ConvTranspose3D{
|
||||
Ws: ws,
|
||||
|
|
24
nn/conv.go
24
nn/conv.go
|
@ -289,11 +289,11 @@ func NewConv1D(vs *Path, inDim, outDim, k int64, cfg *Conv1DConfig) *Conv1D {
|
|||
bs *ts.Tensor = ts.NewTensor()
|
||||
)
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, k)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return &Conv1D{
|
||||
Ws: ws,
|
||||
|
@ -316,11 +316,11 @@ func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg *Conv2DConfig) *Conv2
|
|||
bs *ts.Tensor = ts.NewTensor()
|
||||
)
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, k, k)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return &Conv2D{
|
||||
Ws: ws,
|
||||
|
@ -343,11 +343,11 @@ func NewConv3D(vs *Path, inDim, outDim, k int64, cfg *Conv3DConfig) *Conv3D {
|
|||
bs *ts.Tensor = ts.NewTensor()
|
||||
)
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, k, k, k)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
|
||||
return &Conv3D{
|
||||
Ws: ws,
|
||||
|
@ -418,11 +418,11 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
|
|||
case len(ksizes) == 1 && configT.String() == "*nn.Conv1DConfig":
|
||||
cfg := config.(*Conv1DConfig)
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
return &Conv1D{
|
||||
Ws: ws,
|
||||
Bs: bs,
|
||||
|
@ -431,11 +431,11 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
|
|||
case len(ksizes) == 2 && configT.String() == "*nn.Conv2DConfig":
|
||||
cfg := config.(*Conv2DConfig)
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
return &Conv2D{
|
||||
Ws: ws,
|
||||
Bs: bs,
|
||||
|
@ -444,11 +444,11 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
|
|||
case len(ksizes) == 3 && configT.String() == "*nn.Conv3DConfig":
|
||||
cfg := config.(*Conv3DConfig)
|
||||
if cfg.Bias {
|
||||
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||
}
|
||||
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||
weightSize = append(weightSize, ksizes...)
|
||||
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||
ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
|
||||
return &Conv3D{
|
||||
Ws: ws,
|
||||
Bs: bs,
|
||||
|
|
|
@ -36,7 +36,7 @@ func TrainableCModuleLoad(p *Path, file string) (*TrainableCModule, error) {
|
|||
// NOTE: return is a newly created and added tensor in varstore.
|
||||
// This tensor is different from input named tensor.
|
||||
// If not using, just ignore it. Drop it, will drop tensor at varstore.
|
||||
_ = p.Add(name, namedTensor.Tensor, requiresGrad)
|
||||
_ = p.MustAdd(name, namedTensor.Tensor, requiresGrad)
|
||||
|
||||
// Clean-up named tensors.
|
||||
namedTensor.Tensor.MustDrop()
|
||||
|
@ -62,7 +62,7 @@ func TrainableCModuleLoadData(p *Path, stream io.Reader) (*TrainableCModule, err
|
|||
// NOTE: return is a newly created and added tensor in varstore.
|
||||
// This tensor is different from input named tensor.
|
||||
// If not using, just ignore it. Drop it, will drop tensor at varstore.
|
||||
_ = p.Add(name, namedTensor.Tensor, requiresGrad)
|
||||
_ = p.MustAdd(name, namedTensor.Tensor, requiresGrad)
|
||||
|
||||
// Clean-up named tensors.
|
||||
namedTensor.Tensor.MustDrop()
|
||||
|
|
|
@ -39,8 +39,8 @@ func NewLayerNorm(vs *Path, normalizedShape []int64, config *LayerNormConfig) *L
|
|||
bs *ts.Tensor
|
||||
)
|
||||
if config.ElementwiseAffine {
|
||||
ws = vs.NewVar("weight", normalizedShape, config.WsInit)
|
||||
bs = vs.NewVar("bias", normalizedShape, config.BsInit)
|
||||
ws = vs.MustNewVar("weight", normalizedShape, config.WsInit)
|
||||
bs = vs.MustNewVar("bias", normalizedShape, config.BsInit)
|
||||
}
|
||||
|
||||
return &LayerNorm{config, ws, bs, normalizedShape}
|
||||
|
|
|
@ -49,14 +49,14 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
|
|||
case c.BsInit == nil:
|
||||
bound := 1.0 / math.Sqrt(float64(inDim))
|
||||
bsInit := NewUniformInit(-bound, bound)
|
||||
bs = vs.NewVar("bias", []int64{outDim}, bsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, bsInit)
|
||||
case c.BsInit != nil:
|
||||
bs = vs.NewVar("bias", []int64{outDim}, c.BsInit)
|
||||
bs = vs.MustNewVar("bias", []int64{outDim}, c.BsInit)
|
||||
}
|
||||
}
|
||||
|
||||
return &Linear{
|
||||
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
||||
Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
||||
Bs: bs,
|
||||
}
|
||||
}
|
||||
|
|
295
nn/optimizer.go
295
nn/optimizer.go
|
@ -5,14 +5,18 @@ package nn
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Optimizer is a struct object to run gradient descent.
|
||||
type Optimizer struct {
|
||||
opt *ts.COptimizer
|
||||
variablesInOptimizer uint8
|
||||
varstore *VarStore
|
||||
opt *ts.COptimizer
|
||||
// variablesInOptimizer uint8
|
||||
variablesInOptimizer map[string]struct{}
|
||||
config interface{}
|
||||
stepCount int
|
||||
}
|
||||
|
@ -34,25 +38,27 @@ type OptimizerConfig interface {
|
|||
}
|
||||
|
||||
// defaultBuild is `default` Build method for OptimizerConfig interface
|
||||
func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (retVal *Optimizer, err error) {
|
||||
func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (*Optimizer, error) {
|
||||
opt, err := config.buildCOpt(lr)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(vs.Vars.TrainableVariables) > 0 {
|
||||
for _, v := range vs.Vars.TrainableVariables {
|
||||
names := make(map[string]struct{})
|
||||
for name, v := range vs.vars {
|
||||
if v.Trainable {
|
||||
if err = opt.AddParameter(v.Tensor, v.Group); err != nil {
|
||||
err = fmt.Errorf("Optimizer defaultBuild - AddParameter failed: %w\n", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
names[name] = struct{}{}
|
||||
}
|
||||
|
||||
return &Optimizer{
|
||||
opt: opt,
|
||||
// variables: vs.Vars,
|
||||
variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)),
|
||||
varstore: vs,
|
||||
opt: opt,
|
||||
variablesInOptimizer: names,
|
||||
config: config,
|
||||
stepCount: 0,
|
||||
}, nil
|
||||
|
@ -215,51 +221,79 @@ func (c *RMSPropConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
|
|||
|
||||
// Optimizer methods:
|
||||
// ==================
|
||||
|
||||
func (opt *Optimizer) addMissingVariables() {
|
||||
|
||||
// missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
|
||||
//
|
||||
// if missingVariables > 0 {
|
||||
// var tensors []ts.Tensor
|
||||
// for _, t := range opt.variables.TrainableVariables[opt.variablesInOptimizer:] {
|
||||
// tensor := t.MustShallowClone()
|
||||
// tensor.Detach_()
|
||||
// tensors = append(tensors, tensor)
|
||||
// }
|
||||
//
|
||||
// opt.opt.AddParameters(tensors)
|
||||
// opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
|
||||
// }
|
||||
|
||||
type param struct {
|
||||
tensor *ts.Tensor
|
||||
group uint
|
||||
}
|
||||
trainables := make(map[string]param)
|
||||
for name, v := range opt.varstore.vars {
|
||||
if v.Trainable {
|
||||
trainables[name] = param{tensor: v.Tensor, group: v.Group}
|
||||
}
|
||||
}
|
||||
missingVariables := len(trainables) - len(opt.variablesInOptimizer)
|
||||
if missingVariables > 0 {
|
||||
log.Println("INFO: Optimizer.addMissingVariables()...")
|
||||
for name, x := range trainables {
|
||||
if _, ok := opt.variablesInOptimizer[name]; !ok {
|
||||
opt.opt.AddParameter(x.tensor, x.group)
|
||||
opt.variablesInOptimizer[name] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ZeroGrad zeroes the gradient for the tensors tracked by this optimizer.
|
||||
func (opt *Optimizer) ZeroGrad() {
|
||||
opt.addMissingVariables()
|
||||
func (opt *Optimizer) ZeroGrad() error {
|
||||
if err := opt.opt.ZeroGrad(); err != nil {
|
||||
log.Fatalf("Optimizer - ZeroGrad method call error: %v\n", err)
|
||||
err = fmt.Errorf("Optimizer.ZeroGrad() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustZeroGrad zeroes the gradient for the tensors tracked by this optimizer.
|
||||
func (opt *Optimizer) MustZeroGrad() {
|
||||
err := opt.ZeroGrad()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clips gradient value at some specified maximum value.
|
||||
func (opt *Optimizer) ClipGradValue(max float64) {
|
||||
opt.varstore.Lock()
|
||||
defer opt.varstore.Unlock()
|
||||
|
||||
// opt.variables.mutex.Lock()
|
||||
// defer opt.variables.mutex.Unlock()
|
||||
|
||||
// for _, tensor := range opt.variables.TrainableVariables {
|
||||
// tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
// }
|
||||
for _, v := range opt.varstore.vars {
|
||||
if v.Trainable {
|
||||
// v.Tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
gradTs := v.Tensor.MustGrad(false)
|
||||
gradTs.Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step performs an optimization step, updating the tracked tensors based on their gradients.
|
||||
func (opt *Optimizer) Step() {
|
||||
opt.addMissingVariables()
|
||||
func (opt *Optimizer) Step() error {
|
||||
err := opt.opt.Step()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - Step method call error: %v\n", err)
|
||||
err = fmt.Errorf("Optimizer.Step() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
opt.stepCount += 1
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustStep performs an optimization step, updating the tracked tensors based on their gradients.
|
||||
func (opt *Optimizer) MustStep() {
|
||||
err := opt.Step()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ResetStepCount set step count to zero.
|
||||
|
@ -273,51 +307,208 @@ func (opt *Optimizer) StepCount() int {
|
|||
}
|
||||
|
||||
// BackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
func (opt *Optimizer) BackwardStep(loss *ts.Tensor) {
|
||||
opt.addMissingVariables()
|
||||
func (opt *Optimizer) BackwardStep(loss *ts.Tensor) error {
|
||||
err := opt.opt.ZeroGrad()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
|
||||
err = fmt.Errorf("Optimizer.BackwardStep() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
loss.MustBackward()
|
||||
err = opt.opt.Step()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStep method call - Step() error: %v\n", err)
|
||||
err = fmt.Errorf("Optimizer.BackwardStep() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustBackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
func (opt *Optimizer) MustBackwardStep(loss *ts.Tensor) {
|
||||
err := opt.BackwardStep(loss)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// BackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
//
|
||||
// The gradients are clipped based on `max` before being applied.
|
||||
func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) {
|
||||
opt.addMissingVariables()
|
||||
func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) error {
|
||||
err := opt.opt.ZeroGrad()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStepClip method call - ZeroGrad error: %v\n", err)
|
||||
err = fmt.Errorf("Optimizer.BackwardStepClip() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
loss.MustBackward()
|
||||
opt.ClipGradValue(max)
|
||||
err = opt.opt.Step()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - BackwardStepClip method call - Step() error: %v\n", err)
|
||||
err = fmt.Errorf("Optimizer.BackwardStepClip() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustBackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
//
|
||||
// The gradients are clipped based on `max` before being applied.
|
||||
func (opt *Optimizer) MustBackwardStepClip(loss *ts.Tensor, max float64) {
|
||||
err := opt.BackwardStepClip(loss, max)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// TODO. Clips gradient L2 norm over all trainable parameters.
|
||||
type ClipOpts struct {
|
||||
NormType float64
|
||||
ErrorIfNonFinite bool
|
||||
}
|
||||
|
||||
type ClipOpt func(*ClipOpts)
|
||||
|
||||
func defaultClipOpts() *ClipOpts {
|
||||
return &ClipOpts{
|
||||
NormType: 2.0,
|
||||
ErrorIfNonFinite: false, // will switch to "true" in the future.
|
||||
}
|
||||
}
|
||||
|
||||
func WithNormType(v float64) ClipOpt {
|
||||
return func(o *ClipOpts) {
|
||||
o.NormType = v
|
||||
}
|
||||
}
|
||||
|
||||
func WithErrorIfNonFinite(v bool) ClipOpt {
|
||||
return func(o *ClipOpts) {
|
||||
o.ErrorIfNonFinite = v
|
||||
}
|
||||
}
|
||||
|
||||
/// Clips gradient L2 norm over all trainable parameters.
|
||||
//
|
||||
// The norm is computed over all gradients together, as if they were
|
||||
// concatenated into a single vector.
|
||||
func (opt *Optimizer) ClipGradNorm(max float64) {
|
||||
// TODO.
|
||||
log.Fatalf("Not implemented yet!")
|
||||
//
|
||||
/// Args:
|
||||
// - max: max norm of the gradient
|
||||
// - o.NormType. Type of the used p-norm, can be "inf" for infinity norm. Default= 2.0
|
||||
// - o.ErrorIfNonFinite bool. If true, throw error if total norm of the gradients from paramters is "nan", "inf" or "-inf". Default=false
|
||||
// Returns: total norm of the parameters (viewed as a single vector)
|
||||
// ref. https://github.com/pytorch/pytorch/blob/cb4aeff7d8e4c70bb638cf159878c5204d0cc2da/torch/nn/utils/clip_grad.py#L59
|
||||
func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
|
||||
o := defaultClipOpts()
|
||||
for _, option := range opts {
|
||||
option(o)
|
||||
}
|
||||
|
||||
opt.varstore.Lock()
|
||||
defer opt.varstore.Unlock()
|
||||
parameters := opt.varstore.TrainableVariables()
|
||||
if len(parameters) == 0 {
|
||||
// return ts.MustOfSlice([]float64{0.0}), nil
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
norms []ts.Tensor
|
||||
totalNorm *ts.Tensor
|
||||
)
|
||||
|
||||
device := opt.varstore.device
|
||||
if o.NormType == math.Inf(1) {
|
||||
for _, v := range opt.varstore.vars {
|
||||
n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
|
||||
norms = append(norms, *n)
|
||||
}
|
||||
// total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
|
||||
totalNorm = ts.MustStack(norms, 0).MustMax(true)
|
||||
} else {
|
||||
for _, v := range opt.varstore.vars {
|
||||
// x := v.Tensor.MustGrad(false).MustNorm(true)
|
||||
|
||||
// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
|
||||
x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
|
||||
norms = append(norms, *x)
|
||||
}
|
||||
}
|
||||
|
||||
// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
|
||||
// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
|
||||
totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
|
||||
for _, x := range norms {
|
||||
x.MustDrop()
|
||||
}
|
||||
|
||||
totalNormVal := totalNorm.Float64Values(true)[0]
|
||||
// if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
|
||||
if o.ErrorIfNonFinite && (math.IsNaN(totalNormVal) || math.IsInf(totalNormVal, 1)) {
|
||||
err := fmt.Errorf("The total norm of order (%v) for gradients from 'parameters' is non-finite, so it cannot be clipped. To disable this error and scale the gradients by the non-finite norm anyway, set option.ErrorIfNonFinite= false", o.NormType)
|
||||
return err
|
||||
}
|
||||
|
||||
// clip_coef = max_norm / (total_norm + 1e-6)
|
||||
// clipCoefTs := ts.TensorFrom([]float64{max}).MustDiv(totalNorm, true)
|
||||
clipCoef := max / (totalNormVal + 1e-6)
|
||||
// NOTE: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
|
||||
// avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
|
||||
// when the gradients do not reside in CPU memory.
|
||||
// clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
|
||||
if clipCoef > 1.0 {
|
||||
clipCoef = 1.0
|
||||
}
|
||||
for _, v := range opt.varstore.vars {
|
||||
if v.Trainable {
|
||||
// p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
|
||||
// v.Tensor.MustGrad(false).MustDetach(true).MustMulScalar_(ts.FloatScalar(clipCoef))
|
||||
v.Tensor.MustGrad(false).MustMulScalar_(ts.FloatScalar(clipCoef))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO. Applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
// BackwardStepClipNorm applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
//
|
||||
// The gradients L2 norm is clipped based on `max`.
|
||||
func (opt *Optimizer) BackwardStepClipNorm(loss *ts.Tensor, max float64) {
|
||||
// TODO.
|
||||
log.Fatalf("Not implemented yet!")
|
||||
func (opt *Optimizer) BackwardStepClipNorm(loss *ts.Tensor, max float64, opts ...ClipOpt) error {
|
||||
err := opt.opt.ZeroGrad()
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
err = loss.Backward()
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = opt.ClipGradNorm(max, opts...)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
err = opt.Step()
|
||||
if err != nil {
|
||||
err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MustBackwardStepClipNorm applies a backward step pass, update the gradients, and performs an optimization step.
|
||||
//
|
||||
// The gradients L2 norm is clipped based on `max`.
|
||||
func (opt *Optimizer) MustBackwardStepClipNorm(loss *ts.Tensor, max float64, opts ...ClipOpt) {
|
||||
err := opt.BackwardStepClipNorm(loss, max, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// SetLR sets the optimizer learning rate.
|
||||
|
|
|
@ -1,69 +1,73 @@
|
|||
package nn_test
|
||||
|
||||
/*
|
||||
* import (
|
||||
* // "reflect"
|
||||
* "fmt"
|
||||
* "log"
|
||||
* "testing"
|
||||
*
|
||||
* "github.com/sugarme/gotch"
|
||||
* "github.com/sugarme/gotch/nn"
|
||||
* ts "github.com/sugarme/gotch/tensor"
|
||||
* )
|
||||
*
|
||||
* func TestOptimizer(t *testing.T) {
|
||||
*
|
||||
* var data []float32
|
||||
* for i := 0; i < 15; i++ {
|
||||
* data = append(data, float32(i))
|
||||
* }
|
||||
* xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
|
||||
* if err != nil {
|
||||
* log.Fatal(err)
|
||||
* }
|
||||
*
|
||||
* ys := xs.MustMul1(ts.FloatScalar(0.42), false).MustAdd1(ts.FloatScalar(1.337), false)
|
||||
*
|
||||
* vs := nn.NewVarStore(gotch.CPU)
|
||||
*
|
||||
* optCfg := nn.DefaultSGDConfig()
|
||||
* opt, err := optCfg.Build(vs, 1e-2)
|
||||
* if err != nil {
|
||||
* t.Errorf("Failed building SGD optimizer")
|
||||
* }
|
||||
*
|
||||
* cfg := nn.LinearConfig{
|
||||
* WsInit: nn.NewConstInit(0.0),
|
||||
* BsInit: nn.NewConstInit(0.0),
|
||||
* Bias: true,
|
||||
* }
|
||||
*
|
||||
* linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
|
||||
*
|
||||
* logits := xs.Apply(linear)
|
||||
* loss := logits.MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
*
|
||||
* initialLoss := loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
|
||||
*
|
||||
* wantLoss := float64(1.0)
|
||||
*
|
||||
* if initialLoss < wantLoss {
|
||||
* t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
|
||||
* }
|
||||
*
|
||||
* for i := 0; i < 50; i++ {
|
||||
* loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
*
|
||||
* opt.BackwardStep(loss)
|
||||
* fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0}))
|
||||
* }
|
||||
*
|
||||
* loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
|
||||
* finalLoss := loss.Values()[0]
|
||||
* fmt.Printf("Final loss: %v\n", finalLoss)
|
||||
*
|
||||
* if finalLoss > 0.25 {
|
||||
* t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
|
||||
* }
|
||||
* } */
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestOptimizer(t *testing.T) {
|
||||
x := ts.MustArangeStart(ts.IntScalar(1), ts.IntScalar(15), gotch.Float, gotch.CPU).MustView([]int64{-1, 1}, true)
|
||||
// y = x * 0.42 + 1.337
|
||||
y := x.MustMulScalar(ts.FloatScalar(0.42), false).MustAddScalar(ts.FloatScalar(1.337), false)
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
path := vs.Root()
|
||||
|
||||
cfg := &nn.LinearConfig{
|
||||
WsInit: nn.NewConstInit(0.0),
|
||||
BsInit: nn.NewConstInit(0.0),
|
||||
Bias: true,
|
||||
}
|
||||
model := nn.NewLinear(path, 1, 1, cfg)
|
||||
|
||||
lr := 1e-2
|
||||
opt, err := nn.DefaultSGDConfig().Build(vs, lr)
|
||||
if err != nil {
|
||||
t.Errorf("Failed building SGD optimizer")
|
||||
}
|
||||
|
||||
initialLoss := x.ApplyT(model, true).MustMseLoss(y, 1, true).Float64Values(true)[0]
|
||||
wantLoss := float64(1.0)
|
||||
if initialLoss < wantLoss {
|
||||
t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
|
||||
}
|
||||
|
||||
// Optimization loop
|
||||
for i := 0; i < 50; i++ {
|
||||
logits := model.ForwardT(x, true)
|
||||
loss := logits.MustMseLoss(y, 1, true)
|
||||
if i%10 == 0 {
|
||||
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0}))
|
||||
}
|
||||
opt.BackwardStep(loss)
|
||||
}
|
||||
|
||||
loss := x.Apply(model).MustMseLoss(y, 1, true)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
loss = x.Apply(model).MustMseLoss(y, 1, true)
|
||||
finalLoss := loss.Float64Values()[0]
|
||||
fmt.Printf("Final loss: %v\n", finalLoss)
|
||||
|
||||
if finalLoss > 0.25 {
|
||||
t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
|
||||
}
|
||||
}
|
||||
|
||||
// see https://github.com/pytorch/pytorch/blob/9b203f667ac096db9f5f5679ac3e3d7931c34d36/test/test_nn.py#L2308
|
||||
func TestClipGradNorm(t *testing.T) {
|
||||
// TODO.
|
||||
// vs := nn.NewVarStore(gotch.CPU)
|
||||
// path := vs.Root()
|
||||
// l := nn.NewLinear(path, 10, 10, nn.DefaultLinearConfig())
|
||||
// maxNorm := 2.0
|
||||
}
|
||||
|
||||
// see https://github.com/pytorch/pytorch/blob/9b203f667ac096db9f5f5679ac3e3d7931c34d36/test/test_nn.py#L2364
|
||||
func TestClipGradValue(t *testing.T) {
|
||||
// TODO
|
||||
}
|
||||
|
|
28
nn/other.go
28
nn/other.go
|
@ -24,6 +24,9 @@ func (d *Dropout) ForwardT(input *ts.Tensor, train bool) (retVal *ts.Tensor) {
|
|||
return ts.MustDropout(input, d.dropoutProb, train)
|
||||
}
|
||||
|
||||
// Parameter:
|
||||
// ==========
|
||||
|
||||
// NewParameter creates a kind of tensor that is considered as a module parameter.
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html
|
||||
func NewParameter(path *Path, name string, x *ts.Tensor, requireGradOpt ...bool) *ts.Tensor {
|
||||
|
@ -32,11 +35,34 @@ func NewParameter(path *Path, name string, x *ts.Tensor, requireGradOpt ...bool)
|
|||
requiredGrad = requireGradOpt[0]
|
||||
}
|
||||
|
||||
param := path.Add(name, x, requiredGrad)
|
||||
param := path.MustAdd(name, x, requiredGrad)
|
||||
|
||||
return param
|
||||
}
|
||||
|
||||
// Buffer:
|
||||
// =======
|
||||
|
||||
// NewBuffer creates new buffer.
|
||||
//
|
||||
// Buffer is different from Parameter as its requiredGrad always false.
|
||||
// - `o.Persistent` param. Default=true. If `true` buffer variable will be saved when `nn.VarStore.Save()` is called.
|
||||
//
|
||||
// Ref.
|
||||
// - https://github.com/pytorch/pytorch/blob/f71eede85a69caed637008e331f5ac5f5b7717ae/torch/nn/modules/module.py#L275
|
||||
// - https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/2
|
||||
func NewBuffer(path *Path, name string, x *ts.Tensor, persistentOpt ...bool) *ts.Tensor {
|
||||
persistent := true
|
||||
if len(persistentOpt) > 0 {
|
||||
persistent = persistentOpt[0]
|
||||
}
|
||||
opts := []AddOpt{
|
||||
WithPersistent(persistent),
|
||||
WithVarType("buffer"),
|
||||
}
|
||||
return path.MustAdd(name, x, false, opts...) // requiredGrad always false. Different from parameter.
|
||||
}
|
||||
|
||||
// Identity:
|
||||
// =========
|
||||
|
||||
|
|
32
nn/rnn.go
32
nn/rnn.go
|
@ -97,26 +97,26 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
}
|
||||
switch numDirections {
|
||||
case 1:
|
||||
wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
wIh := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
|
||||
wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
|
||||
case 2: // bi-directional
|
||||
// forward
|
||||
wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
wIh := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
|
||||
wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
|
||||
// reverse
|
||||
wIhR := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
|
||||
wHhR := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
|
||||
bIhR := vs.Zeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
|
||||
bHhR := vs.Zeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
|
||||
wIhR := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
|
||||
wHhR := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
|
||||
bIhR := vs.MustZeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
|
||||
bHhR := vs.MustZeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
|
||||
flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
|
||||
}
|
||||
}
|
||||
|
@ -234,10 +234,10 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
|
|||
inputDim = hiddenDim * numDirections
|
||||
}
|
||||
|
||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inputDim})
|
||||
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
||||
wIh := vs.MustKaimingUniform("w_ih", []int64{gateDim, inputDim})
|
||||
wHh := vs.MustKaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||
bIh := vs.MustZeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.MustZeros("b_hh", []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
}
|
||||
|
|
|
@ -251,6 +251,7 @@ func BatchAccuracyForLogits(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d got
|
|||
|
||||
logits := m.ForwardT(bImages, false)
|
||||
acc := logits.AccuracyForLogits(bLabels)
|
||||
logits.MustDrop()
|
||||
sumAccuracy += acc.Float64Values()[0] * size
|
||||
sampleCount += size
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ type Embedding struct {
|
|||
// NewEmbedding creates a new Embedding
|
||||
func NewEmbedding(vs *Path, numEmbeddings int64, embeddingDim int64, config *EmbeddingConfig) *Embedding {
|
||||
return &Embedding{
|
||||
Ws: vs.NewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit),
|
||||
Ws: vs.MustNewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit),
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
|
233
nn/varstore.go
233
nn/varstore.go
|
@ -187,6 +187,46 @@ func (vs *VarStore) Load(filepath string) error {
|
|||
v.Tensor.Copy_(currTs)
|
||||
})
|
||||
}
|
||||
|
||||
for _, x := range namedTensors {
|
||||
x.Tensor.MustDrop()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadWeights loads pretrained weights to VarStore.
|
||||
func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
|
||||
var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0)
|
||||
for _, namedTensor := range namedTensors {
|
||||
namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
|
||||
}
|
||||
|
||||
// Match and in-place copy value (update) from newly loaded tensors
|
||||
// to existing named tensors if name is matched. Throw error otherwise.
|
||||
vs.Lock()
|
||||
defer vs.Unlock()
|
||||
|
||||
for name, v := range vs.vars {
|
||||
// missing variable
|
||||
currTs, ok := namedTensorsMap[name]
|
||||
if !ok {
|
||||
err := fmt.Errorf("VarStore.LoadWeights() failed: there's a tensor with name %q in VarStore, but not found in the loaded weights.\n", name)
|
||||
return err
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
sourceShape := currTs.MustSize()
|
||||
destShape := v.Tensor.MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
err := fmt.Errorf("VarStore.LoadWeights() failed. Mismatched shape error for variable name: %v - At store: %v - At source %v\n", name, destShape, sourceShape)
|
||||
return err
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
v.Tensor.Copy_(currTs)
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -242,6 +282,60 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
|
|||
})
|
||||
}
|
||||
|
||||
for _, x := range namedTensors {
|
||||
x.Tensor.MustDrop()
|
||||
}
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
// LoadWeightsPartial loads the VarStore variable values from a file if it exists.
|
||||
//
|
||||
// Weight values for the tensors currently stored in the var-store and the given file get
|
||||
// loaded from the given file. If a variable in the var store is not present in the given file,
|
||||
// it is skipped and its values are not updated. This method should be used if pre-trained
|
||||
// weight for only parts of the model are available.
|
||||
// Note that the set of variables stored in the var-store is not changed, only the values
|
||||
// for these tensors are modified.
|
||||
//
|
||||
// Returns a String Vector containing the names of missing variables.
|
||||
func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string, error) {
|
||||
var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0)
|
||||
for _, namedTensor := range namedTensors {
|
||||
namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
|
||||
}
|
||||
|
||||
var missingVariables []string
|
||||
|
||||
// Match and in-place copy value (update) from newly loaded tensors
|
||||
// to existing named tensors if name is matched. Throw error otherwise.
|
||||
vs.Lock()
|
||||
defer vs.Unlock()
|
||||
|
||||
for name, v := range vs.vars {
|
||||
var currTs *ts.Tensor
|
||||
var ok bool
|
||||
|
||||
// missing variable
|
||||
if currTs, ok = namedTensorsMap[name]; !ok {
|
||||
missingVariables = append(missingVariables, name)
|
||||
continue
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
destShape := currTs.MustSize()
|
||||
sourceShape := v.Tensor.MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", name, destShape, sourceShape)
|
||||
missingVariables = append(missingVariables, name)
|
||||
continue
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
v.Tensor.Copy_(currTs)
|
||||
})
|
||||
}
|
||||
|
||||
return missingVariables, nil
|
||||
}
|
||||
|
||||
|
@ -284,7 +378,7 @@ func (vs *VarStore) Unfreeze() error {
|
|||
//
|
||||
// All the variables in this var store have to exist with the same
|
||||
// name in the source var store, otherwise an error is returned.
|
||||
func (vs *VarStore) Copy(src VarStore) error {
|
||||
func (vs *VarStore) Copy(src *VarStore) error {
|
||||
vs.Lock()
|
||||
defer vs.Unlock()
|
||||
src.Lock()
|
||||
|
@ -343,6 +437,34 @@ func (vs *VarStore) Summary() {
|
|||
fmt.Printf("Num of layers: %v\n", len(vars))
|
||||
}
|
||||
|
||||
// ToDType casts all variables in VarStore to specified DType.
|
||||
//
|
||||
// NOTE. only float-like types (Half, Float, Double) can ensure convertible.
|
||||
func (vs *VarStore) ToDType(dtype gotch.DType) {
|
||||
vs.Root().ToDType(dtype)
|
||||
}
|
||||
|
||||
// ToHalf casts all float-like variables in VarStore to `Half` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToHalf() {
|
||||
vs.Root().ToHalf()
|
||||
}
|
||||
|
||||
// ToFloat casts all float-like variables in VarStore to `Float` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToFloat() {
|
||||
vs.Root().ToFloat()
|
||||
}
|
||||
|
||||
// ToDouble casts all float-like variables in VarStore to `Double` dtype.
|
||||
//
|
||||
// NOTE. float-like includes `Half`, `Float` and `Double` dtype.
|
||||
func (vs *VarStore) ToDouble() {
|
||||
vs.Root().ToDouble()
|
||||
}
|
||||
|
||||
// Path methods:
|
||||
// =============
|
||||
|
||||
|
@ -467,6 +589,23 @@ func (p *Path) Add(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) (*
|
|||
return p.add(name, x, trainable, o.VarType, o.Persistent)
|
||||
}
|
||||
|
||||
// MustAdd adds a tensor to a given path.
|
||||
//
|
||||
// Args
|
||||
// - name: intention name of variable in VarStore (if duplicated, it will be added a suffix number)
|
||||
// - x: tensor holding values to keep in VarStore
|
||||
// - trainable: marked whether tensor is trainable.
|
||||
// - o.VarType: variable type, i.e., either "parameter" or "buffer"
|
||||
// - o.Persistent: whether to save this variables when `VarStore.Save()` is called. Only applied to `buffer` type.
|
||||
// Returns a reference to a tensor stored in VarStore.
|
||||
func (p *Path) MustAdd(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) *ts.Tensor {
|
||||
x, err := p.Add(name, x, trainable, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) {
|
||||
path := p.getpath(name)
|
||||
|
||||
|
@ -480,9 +619,73 @@ func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool,
|
|||
}
|
||||
|
||||
func (p *Path) SetGroup(g uint) {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
|
||||
// TODO. set group for individual variables.
|
||||
// TBD. variables of current path only or all sub paths as well?
|
||||
// For now, just set group for all variable at the path
|
||||
path := strings.Join(p.path, SEP)
|
||||
for name, v := range p.varstore.vars {
|
||||
vpaths := strings.Split(name, SEP)
|
||||
vpath := strings.Join(vpaths[:len(vpaths)-1], SEP)
|
||||
if vpath == path {
|
||||
v.Group = g
|
||||
p.varstore.vars[name] = v
|
||||
}
|
||||
}
|
||||
p.group = g
|
||||
}
|
||||
|
||||
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
|
||||
//
|
||||
// NOTE. this method should be used for floating-point conversion, i.e.,
|
||||
// "gotch.Float", "gotch.Half", "gotch.Float16", "gotch.Double".
|
||||
func (p *Path) ToDType(dtype gotch.DType) {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
path := strings.Join(p.path, SEP)
|
||||
for name, v := range p.varstore.vars {
|
||||
if strings.Contains(name, path) {
|
||||
newVar := v
|
||||
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
|
||||
p.varstore.vars[name] = newVar
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// toFloat casts all float-like variables in this current path and sub-paths to specified dtype.
|
||||
func (p *Path) toFloat(dtype gotch.DType) {
|
||||
p.varstore.Lock()
|
||||
defer p.varstore.Unlock()
|
||||
path := strings.Join(p.path, SEP)
|
||||
for name, v := range p.varstore.vars {
|
||||
if strings.Contains(name, path) {
|
||||
dtype := v.Tensor.DType()
|
||||
if dtype == gotch.Half || dtype == gotch.Float || dtype == gotch.Double {
|
||||
newVar := v
|
||||
newVar.Tensor = v.Tensor.MustTotype(dtype, true)
|
||||
p.varstore.vars[name] = newVar
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ToHalf casts all variables in current path and subpaths to `Half` precision.
|
||||
func (p *Path) ToHalf() {
|
||||
p.toFloat(gotch.Half)
|
||||
}
|
||||
|
||||
// ToFloat casts all variables in current path and subpaths to `Float` precision.
|
||||
func (p *Path) ToFloat() {
|
||||
p.toFloat(gotch.Float)
|
||||
}
|
||||
|
||||
// ToDouble casts all variables in current path and subpaths to `Double` precision.
|
||||
func (p *Path) ToDouble() {
|
||||
p.toFloat(gotch.Double)
|
||||
}
|
||||
|
||||
// ZerosNoTrain creates a new variable initialized with zeros.
|
||||
//
|
||||
// The new variable is named according to the name parameter and
|
||||
|
@ -506,6 +709,20 @@ func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tens
|
|||
return out, nil
|
||||
}
|
||||
|
||||
// MustZerosNoTrain creates a new variable initialized with zeros.
|
||||
//
|
||||
// The new variable is named according to the name parameter and
|
||||
// has the specified shape. The variable will not be trainable so
|
||||
// gradients will not be tracked.
|
||||
// The variable uses a float tensor initialized with zeros.
|
||||
func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
|
||||
x, err := p.ZerosNoTrain(name, dims, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// OnesNoTrain creates a new variable initialized with ones.
|
||||
//
|
||||
// The new variable is named according to the name parameter and
|
||||
|
@ -529,6 +746,20 @@ func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tenso
|
|||
return out, nil
|
||||
}
|
||||
|
||||
// MustOnesNoTrain creates a new variable initialized with ones.
|
||||
//
|
||||
// The new variable is named according to the name parameter and
|
||||
// has the specified shape. The variable will not be trainable so
|
||||
// gradients will not be tracked.
|
||||
// The variable uses a float tensor initialized with ones.
|
||||
func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
|
||||
x, err := p.OnesNoTrain(name, dims, opts...)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// NewVar creates a new variable.
|
||||
//
|
||||
// The new variable is named according to the name parameter and
|
||||
|
|
|
@ -439,26 +439,19 @@ func LoadAll(vs *nn.VarStore, modelFile string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
// for tsName, _ := range vs.Vars.NamedVariables {
|
||||
for tsName := range vs.Vars.NamedVariables {
|
||||
// missing variable
|
||||
currTs, ok := weights[tsName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("LoadAll() failed: Cannot find tensor with name: %v in variable store. \n", tsName)
|
||||
return err
|
||||
var namedTensors []ts.NamedTensor
|
||||
for n, x := range weights {
|
||||
namedTs := ts.NamedTensor{
|
||||
Name: n,
|
||||
Tensor: x,
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
sourceShape := currTs.MustSize()
|
||||
destShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
err = fmt.Errorf("LoadAll() failed: Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
|
||||
return err
|
||||
}
|
||||
namedTensors = append(namedTensors, namedTs)
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||
})
|
||||
err = vs.LoadWeights(namedTensors)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, x := range weights {
|
||||
|
@ -477,32 +470,21 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var namedTensors []ts.NamedTensor
|
||||
for n, x := range weights {
|
||||
namedTs := ts.NamedTensor{
|
||||
Name: n,
|
||||
Tensor: x,
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, namedTs)
|
||||
}
|
||||
|
||||
var missingVariables []string
|
||||
|
||||
// Match and in-place copy value (update) from newly loaded tensors
|
||||
// to existing named tensors if name is matched. Throw error otherwise.
|
||||
for tsName := range vs.Vars.NamedVariables {
|
||||
var currTs *ts.Tensor
|
||||
var ok bool
|
||||
|
||||
// missing variable
|
||||
if currTs, ok = weights[tsName]; !ok {
|
||||
missingVariables = append(missingVariables, tsName)
|
||||
continue
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
destShape := currTs.MustSize()
|
||||
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", tsName, destShape, sourceShape)
|
||||
missingVariables = append(missingVariables, tsName)
|
||||
continue
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||
})
|
||||
missingVariables, err = vs.LoadWeightsPartial(namedTensors)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, x := range weights {
|
||||
|
|
|
@ -1101,38 +1101,39 @@ func (ngg *NoGradGuard) Enable() {
|
|||
_ = MustGradSetEnabled(ngg.enabled)
|
||||
}
|
||||
|
||||
// Reduction type is an enum-like type
|
||||
type Reduction int
|
||||
|
||||
const (
|
||||
// Do not reduce
|
||||
ReductionNone Reduction = iota
|
||||
ReductionNone int64 = 0
|
||||
// Mean of losses
|
||||
ReductionMean
|
||||
ReductionMean int64 = 1
|
||||
// Sum of losses
|
||||
ReductionSum
|
||||
ReductionSum int64 = 2
|
||||
// Escape hatch in case new options become available
|
||||
ReductionOther
|
||||
ReductionOther int64 = 3
|
||||
)
|
||||
|
||||
func (r Reduction) ToInt() int {
|
||||
switch r {
|
||||
case ReductionNone:
|
||||
return 0
|
||||
case ReductionMean:
|
||||
return 1
|
||||
case ReductionSum:
|
||||
return 2
|
||||
case ReductionOther:
|
||||
return 3
|
||||
}
|
||||
|
||||
// NOTE. should it be panic here instead of returning -1?
|
||||
return -1
|
||||
}
|
||||
// func (r Reduction) ToInt() int {
|
||||
// switch r {
|
||||
// case ReductionNone:
|
||||
// return 0
|
||||
// case ReductionMean:
|
||||
// return 1
|
||||
// case ReductionSum:
|
||||
// return 2
|
||||
// case ReductionOther:
|
||||
// return 3
|
||||
// }
|
||||
//
|
||||
// // NOTE. should it be panic here instead of returning -1?
|
||||
// return -1
|
||||
// }
|
||||
|
||||
// Float64Values returns values of tensor in a slice of float64.
|
||||
func (ts *Tensor) Float64Values() []float64 {
|
||||
func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
|
||||
del := false
|
||||
if len(delOpt) > 0 {
|
||||
del = delOpt[0]
|
||||
}
|
||||
numel := ts.Numel()
|
||||
vec := make([]float64, numel)
|
||||
|
||||
|
@ -1141,11 +1142,19 @@ func (ts *Tensor) Float64Values() []float64 {
|
|||
float64Ts.MustCopyData(vec, numel)
|
||||
float64Ts.MustDrop()
|
||||
|
||||
if del {
|
||||
ts.MustDrop()
|
||||
}
|
||||
|
||||
return vec
|
||||
}
|
||||
|
||||
// Int64Values returns values of tensor in a slice of int64.
|
||||
func (ts *Tensor) Int64Values() []int64 {
|
||||
func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
|
||||
del := false
|
||||
if len(delOpt) > 0 {
|
||||
del = delOpt[0]
|
||||
}
|
||||
numel := ts.Numel()
|
||||
vec := make([]int64, numel)
|
||||
|
||||
|
@ -1154,6 +1163,10 @@ func (ts *Tensor) Int64Values() []int64 {
|
|||
int64Ts.MustCopyData(vec, numel)
|
||||
int64Ts.MustDrop()
|
||||
|
||||
if del {
|
||||
ts.MustDrop()
|
||||
}
|
||||
|
||||
return vec
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user