varstore reworked and update
parent d95eaba5b3 · commit 5a6fac51f3
dtype.go (5 changes)
@@ -3,6 +3,7 @@ package gotch
 import (
 	"fmt"
 	"log"
+
 	// "log"
 	"reflect"
 )
@@ -31,6 +32,7 @@ var (
 	Int    DType = DType{reflect.TypeOf(int32(1))}   // 3
 	Int64  DType = DType{reflect.TypeOf(int64(1))}   // 4
 	// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
+	Half   DType = DType{reflect.TypeOf(float32(1))} // 5
 	Float  DType = DType{reflect.TypeOf(float32(1))} // 6
 	Double DType = DType{reflect.TypeOf(float64(1))} // 7
 	// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
@@ -45,6 +47,7 @@ var dtypeGoType = map[DType]reflect.Type{
 	Int16:  reflect.TypeOf(int16(1)),
 	Int:    reflect.TypeOf(int32(1)),
 	Int64:  reflect.TypeOf(int64(1)),
+	Half:   reflect.TypeOf(float32(1)),
 	Float:  reflect.TypeOf(float32(1)),
 	Double: reflect.TypeOf(float64(1)),
 	Bool:   reflect.TypeOf(true),
@@ -87,6 +90,7 @@ var dtypeCInt = map[DType]CInt{
 	Int16:  2,
 	Int:    3,
 	Int64:  4,
+	Half:   5,
 	Float:  6,
 	Double: 7,
 	Bool:   11,
@@ -137,6 +141,7 @@ var dtypeSize = map[DType]uint{
 	Int16:  2,
 	Int:    4,
 	Int64:  8,
+	Half:   4, // Should it be?
 	Float:  4,
 	Double: 8,
 	Bool:   1,
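Note on the size table: the new Half entry is backed by Go's float32, which is presumably why dtypeSize maps it to 4 bytes even though a C-side FP16 element is 2 bytes; the committed "Should it be?" comment flags the same question. Below is a minimal, hedged sketch (not part of the commit) of casting a tensor to the new Half dtype, using only calls that appear elsewhere in this diff (MustOfSlice, MustTotype, MustSize, MustDrop):

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	x := ts.MustOfSlice([]float64{1, 2, 3}) // a Double tensor
	h := x.MustTotype(gotch.Half, true)     // cast to the new Half dtype; `true` drops x
	fmt.Printf("half tensor shape: %v\n", h.MustSize())
	h.MustDrop()
}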
@@ -42,7 +42,7 @@ func sample(data *ts.TextData, lstm *nn.LSTM, linear *nn.Linear, device gotch.De
 		input.MustDrop()
 		inputView.MustDrop()
 
-		forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueeze1(0, true).MustSoftmax(-1, gotch.Float, true)
+		forwardTs := linear.Forward(state.(*nn.LSTMState).H()).MustSqueezeDim(0, true).MustSoftmax(-1, gotch.Float, true)
 		sampledY := forwardTs.MustMultinomial(1, false, true)
 		lastLabel = sampledY.Int64Values()[0]
 		sampledY.MustDrop()
@@ -73,7 +73,7 @@ func fastResnet(p *nn.Path) *nn.SequentialT {
 
 	seq.Add(nn.NewLinear(p.Sub("linear"), 512, 10, nn.DefaultLinearConfig()))
 	seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
-		return xs.MustMul1(ts.FloatScalar(0.125), false)
+		return xs.MustMulScalar(ts.FloatScalar(0.125), false)
 	}))
 
 	return seq
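This rename is applied across the examples in this commit: the numeric-suffix tensor methods give way to descriptive names (MustMul1 → MustMulScalar, MustDiv1 → MustDivScalar, MustAdd1 → MustAddScalar, MustMax1 → MustMaximum, MustSqueeze1 → MustSqueezeDim). A hedged sketch of the new spellings; every call here appears verbatim somewhere in this commit, only the values are illustrative:

package main

import (
	"fmt"

	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	x := ts.MustOfSlice([]float64{1, 2, 3})
	y := x.MustMulScalar(ts.FloatScalar(0.125), false) // was x.MustMul1(...)
	z := y.MustAddScalar(ts.FloatScalar(1.0), true)    // was y.MustAdd1(...); `true` drops y
	w := z.MustDivScalar(ts.FloatScalar(2.0), true)    // was z.MustDiv1(...); `true` drops z
	fmt.Printf("result: %v\n", w.Float64Values())
	w.MustDrop()
	x.MustDrop()
}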
@@ -71,8 +71,13 @@ func runCNN1() {
 
 	var ds *vision.Dataset
 	ds = vision.LoadMNISTDir(MnistDirNN)
-	testImages := ds.TestImages
-	testLabels := ds.TestLabels
+	// ds.TrainImages [60000, 784]
+	// ds.TrainLabels [60000, 784]
+	testImages := ds.TestImages // [10000, 784]
+	testLabels := ds.TestLabels // [10000, 784]
+
+	fmt.Printf("testImages: %v\n", testImages.MustSize())
+	fmt.Printf("testLabels: %v\n", testLabels.MustSize())
 
 	device := gotch.CudaIfAvailable()
 	vs := nn.NewVarStore(device)
@@ -87,16 +92,17 @@ func runCNN1() {
 	startTime := time.Now()
 
 	for epoch := 0; epoch < epochsCNN; epoch++ {
 
 		totalSize := ds.TrainImages.MustSize()[0]
 		samples := int(totalSize)
+		// Shuffling
 		index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
 		imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
 		labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
+		index.MustDrop()
 
 		batches := samples / batchSize
 		batchIndex := 0
-		var epocLoss *ts.Tensor
+		var epocLoss float64
 		for i := 0; i < batches; i++ {
 			start := batchIndex * batchSize
 			size := batchSize
@@ -106,37 +112,33 @@ func runCNN1() {
 			batchIndex += 1
 
 			// Indexing
-			narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
-			bImages := imagesTs.Idx(narrowIndex)
-			bLabels := labelsTs.Idx(narrowIndex)
+			bImages := imagesTs.MustNarrow(0, int64(start), int64(size), false)
+			bLabels := labelsTs.MustNarrow(0, int64(start), int64(size), false)
 
 			bImages = bImages.MustTo(vs.Device(), true)
 			bLabels = bLabels.MustTo(vs.Device(), true)
 
 			logits := net.ForwardT(bImages, true)
+			bImages.MustDrop()
 			loss := logits.CrossEntropyForLogits(bLabels)
+			logits.MustDrop()
+			bLabels.MustDrop()
 
-			// loss = loss.MustSetRequiresGrad(true, false)
+			loss = loss.MustSetRequiresGrad(true, true)
 			opt.BackwardStep(loss)
 
-			epocLoss = loss.MustShallowClone()
-			epocLoss.Detach_()
+			epocLoss = loss.Float64Values()[0]
+			loss.MustDrop()
 
-			// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])
-
-			bImages.MustDrop()
-			bLabels.MustDrop()
 		}
 
-		// vs.Freeze()
-		testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
-		// vs.Unfreeze()
-		fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
-		if testAccuracy > bestAccuracy {
-			bestAccuracy = testAccuracy
-		}
+		ts.NoGrad(func() {
+			testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
+			fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
+			if testAccuracy > bestAccuracy {
+				bestAccuracy = testAccuracy
+			}
+		})
 
-		epocLoss.MustDrop()
 		imagesTs.MustDrop()
 		labelsTs.MustDrop()
 	}
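Two patterns change in this training loop: the epoch loss is kept as a plain float64 so each batch's loss tensor can be dropped immediately, and evaluation now runs inside ts.NoGrad instead of the old freeze/unfreeze calls. A hedged sketch of the evaluation half, with the model and tensors passed in as in the surrounding example:

func evaluate(vs *nn.VarStore, net ts.ModuleT, testImages, testLabels *ts.Tensor, epoch int, epocLoss float64) {
	// No gradient tracking is needed (or wanted) while measuring accuracy.
	ts.NoGrad(func() {
		testAccuracy := nn.BatchAccuracyForLogits(vs, net, testImages, testLabels, vs.Device(), 1024)
		fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss, testAccuracy*100.0)
	})
}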
@@ -49,7 +49,7 @@ func gramMatrix(m *ts.Tensor) *ts.Tensor {
 	gram := mview.MustMatmul(mviewT, true)
 	mviewT.MustDrop()
 
-	return gram.MustDiv1(ts.IntScalar(a*b*c*d), true)
+	return gram.MustDivScalar(ts.IntScalar(a*b*c*d), true)
 }
 
 func styleLoss(m1 *ts.Tensor, m2 *ts.Tensor) *ts.Tensor {
@@ -138,7 +138,7 @@ func main() {
 
 	vs := nn.NewVarStore(device)
 	path := vs.Root()
-	inputVar := path.VarCopy("img", contentImg)
+	inputVar := path.MustVarCopy("img", contentImg)
 	opt, err := nn.DefaultAdamConfig().Build(vs, LearningRate)
 	if err != nil {
 		log.Fatal(err)
@@ -168,7 +168,7 @@ func main() {
 		t.MustDrop()
 	}
 
-	lossMul := sLoss.MustMul1(styleWeight, true)
+	lossMul := sLoss.MustMulScalar(styleWeight, true)
 	loss := lossMul.MustAdd(cLoss, true)
 	opt.BackwardStep(loss)
 
@@ -248,8 +248,8 @@ func conv(vs *nn.Path, index uint, p int64, b *Block) (retVal1 int64, retVal2 in
 
 	var res *ts.Tensor
 	if leaky {
-		tmp2Mul := tmp2.MustMul1(ts.FloatScalar(0.1), false)
-		res = tmp2.MustMax1(tmp2Mul, true)
+		tmp2Mul := tmp2.MustMulScalar(ts.FloatScalar(0.1), false)
+		res = tmp2.MustMaximum(tmp2Mul, true)
 		tmp2Mul.MustDrop()
 	} else {
 		res = tmp2
@@ -434,7 +434,7 @@ func detect(xs *ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) *
 	})
 
 	sliceApplyAndSet(xsTs, 0, 4, func(xs *ts.Tensor) *ts.Tensor {
-		return xs.MustMul1(ts.IntScalar(stride), false)
+		return xs.MustMulScalar(ts.IntScalar(stride), false)
 	})
 
 	// TODO: delete all middle tensors.
@@ -3,14 +3,15 @@ package main
 import (
 	"flag"
 	"fmt"
-	"github.com/sugarme/gotch"
-	"github.com/sugarme/gotch/nn"
-	ts "github.com/sugarme/gotch/tensor"
-	"github.com/sugarme/gotch/vision"
 	"log"
 	"math"
 	"path/filepath"
 	"sort"
 
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+	"github.com/sugarme/gotch/vision"
 )
 
 const (
@@ -273,7 +274,7 @@ func main() {
 
 	imgTmp1 := imageTs.MustUnsqueeze(0, true)
 	imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
-	img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
+	img := imgTmp2.MustDivScalar(ts.FloatScalar(255.0), true)
 	predictTmp := model.ForwardT(img, false)
 
 	predictions := predictTmp.MustSqueeze(true)
@@ -41,10 +41,10 @@ type BatchNorm struct {
 func NewBatchNorm(vs *Path, nd uint, outDim int64, config *BatchNormConfig) *BatchNorm {
 	return &BatchNorm{
 		config:      config,
-		RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}),
-		RunningVar:  vs.OnesNoTrain("running_var", []int64{outDim}),
-		Ws:          vs.NewVar("weight", []int64{outDim}, config.WsInit),
-		Bs:          vs.NewVar("bias", []int64{outDim}, config.BsInit),
+		RunningMean: vs.MustZerosNoTrain("running_mean", []int64{outDim}),
+		RunningVar:  vs.MustOnesNoTrain("running_var", []int64{outDim}),
+		Ws:          vs.MustNewVar("weight", []int64{outDim}, config.WsInit),
+		Bs:          vs.MustNewVar("bias", []int64{outDim}, config.BsInit),
 	}
 }
 
@@ -73,10 +73,10 @@ func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
 
 	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 	weightSize = append(weightSize, ksizes...)
-	ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 
 	if cfg.Bias {
-		bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 	}
 
 	return &ConvTranspose1D{
@@ -104,11 +104,11 @@ func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
 	)
 
 	if cfg.Bias {
-		bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 	}
 	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 	weightSize = append(weightSize, ksizes...)
-	ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 
 	return &ConvTranspose2D{
 		Ws: ws,
@@ -134,11 +134,11 @@ func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
 	)
 
 	if cfg.Bias {
-		bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 	}
 	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 	weightSize = append(weightSize, ksizes...)
-	ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 
 	return &ConvTranspose3D{
 		Ws: ws,
nn/conv.go (24 changes)
@@ -289,11 +289,11 @@ func NewConv1D(vs *Path, inDim, outDim, k int64, cfg *Conv1DConfig) *Conv1D {
 		bs *ts.Tensor = ts.NewTensor()
 	)
 	if cfg.Bias {
-		bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 	}
 	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 	weightSize = append(weightSize, k)
-	ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 
 	return &Conv1D{
 		Ws: ws,
@@ -316,11 +316,11 @@ func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg *Conv2DConfig) *Conv2
 		bs *ts.Tensor = ts.NewTensor()
 	)
 	if cfg.Bias {
-		bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 	}
 	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 	weightSize = append(weightSize, k, k)
-	ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 
 	return &Conv2D{
 		Ws: ws,
@@ -343,11 +343,11 @@ func NewConv3D(vs *Path, inDim, outDim, k int64, cfg *Conv3DConfig) *Conv3D {
 		bs *ts.Tensor = ts.NewTensor()
 	)
 	if cfg.Bias {
-		bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 	}
 	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 	weightSize = append(weightSize, k, k, k)
-	ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 
 	return &Conv3D{
 		Ws: ws,
@@ -418,11 +418,11 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
 	case len(ksizes) == 1 && configT.String() == "*nn.Conv1DConfig":
 		cfg := config.(*Conv1DConfig)
 		if cfg.Bias {
-			bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+			bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 		}
 		weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 		weightSize = append(weightSize, ksizes...)
-		ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+		ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 		return &Conv1D{
 			Ws: ws,
 			Bs: bs,
@@ -431,11 +431,11 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
 	case len(ksizes) == 2 && configT.String() == "*nn.Conv2DConfig":
 		cfg := config.(*Conv2DConfig)
 		if cfg.Bias {
-			bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+			bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 		}
 		weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 		weightSize = append(weightSize, ksizes...)
-		ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+		ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 		return &Conv2D{
 			Ws: ws,
 			Bs: bs,
@@ -444,11 +444,11 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
 	case len(ksizes) == 3 && configT.String() == "*nn.Conv3DConfig":
 		cfg := config.(*Conv3DConfig)
 		if cfg.Bias {
-			bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+			bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
 		}
 		weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
 		weightSize = append(weightSize, ksizes...)
-		ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+		ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)
 		return &Conv3D{
 			Ws: ws,
 			Bs: bs,
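All the constructor changes above follow one convention: variable-creation helpers on *Path now come in pairs, an error-returning form and a panicking Must* form, and the layers call the Must* form (MustNewVar, MustZerosNoTrain, MustOnesNoTrain). Only the Must* side is visible in this diff, so the hedged sketch below sticks to it:

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	root := vs.Root()

	// MustNewVar returns the tensor directly and dies on failure,
	// exactly as the layer constructors in this commit use it.
	w := root.MustNewVar("weight", []int64{10, 5}, nn.NewConstInit(0.0))
	b := root.MustNewVar("bias", []int64{10}, nn.NewConstInit(0.0))
	_, _ = w, b
}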
@@ -36,7 +36,7 @@ func TrainableCModuleLoad(p *Path, file string) (*TrainableCModule, error) {
 		// NOTE: return is a newly created and added tensor in varstore.
 		// This tensor is different from input named tensor.
 		// If not using, just ignore it. Drop it, will drop tensor at varstore.
-		_ = p.Add(name, namedTensor.Tensor, requiresGrad)
+		_ = p.MustAdd(name, namedTensor.Tensor, requiresGrad)
 
 		// Clean-up named tensors.
 		namedTensor.Tensor.MustDrop()
@@ -62,7 +62,7 @@ func TrainableCModuleLoadData(p *Path, stream io.Reader) (*TrainableCModule, err
 		// NOTE: return is a newly created and added tensor in varstore.
 		// This tensor is different from input named tensor.
 		// If not using, just ignore it. Drop it, will drop tensor at varstore.
-		_ = p.Add(name, namedTensor.Tensor, requiresGrad)
+		_ = p.MustAdd(name, namedTensor.Tensor, requiresGrad)
 
 		// Clean-up named tensors.
 		namedTensor.Tensor.MustDrop()
@@ -39,8 +39,8 @@ func NewLayerNorm(vs *Path, normalizedShape []int64, config *LayerNormConfig) *L
 		bs *ts.Tensor
 	)
 	if config.ElementwiseAffine {
-		ws = vs.NewVar("weight", normalizedShape, config.WsInit)
-		bs = vs.NewVar("bias", normalizedShape, config.BsInit)
+		ws = vs.MustNewVar("weight", normalizedShape, config.WsInit)
+		bs = vs.MustNewVar("bias", normalizedShape, config.BsInit)
 	}
 
 	return &LayerNorm{config, ws, bs, normalizedShape}
@@ -49,14 +49,14 @@ func NewLinear(vs *Path, inDim, outDim int64, c *LinearConfig) *Linear {
 		case c.BsInit == nil:
 			bound := 1.0 / math.Sqrt(float64(inDim))
 			bsInit := NewUniformInit(-bound, bound)
-			bs = vs.NewVar("bias", []int64{outDim}, bsInit)
+			bs = vs.MustNewVar("bias", []int64{outDim}, bsInit)
 		case c.BsInit != nil:
-			bs = vs.NewVar("bias", []int64{outDim}, c.BsInit)
+			bs = vs.MustNewVar("bias", []int64{outDim}, c.BsInit)
 		}
 	}
 
 	return &Linear{
-		Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
+		Ws: vs.MustNewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
 		Bs: bs,
 	}
 }
nn/optimizer.go (295 changes)
@@ -5,14 +5,18 @@ package nn
 import (
 	"fmt"
 	"log"
+	"math"
 
+	"github.com/sugarme/gotch"
 	ts "github.com/sugarme/gotch/tensor"
 )
 
 // Optimizer is a struct object to run gradient descent.
 type Optimizer struct {
-	opt                  *ts.COptimizer
-	variablesInOptimizer uint8
+	varstore *VarStore
+	opt      *ts.COptimizer
+	// variablesInOptimizer uint8
+	variablesInOptimizer map[string]struct{}
 	config               interface{}
 	stepCount            int
 }
@@ -34,25 +38,27 @@ type OptimizerConfig interface {
 }
 
 // defaultBuild is `default` Build method for OptimizerConfig interface
-func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (retVal *Optimizer, err error) {
+func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (*Optimizer, error) {
 	opt, err := config.buildCOpt(lr)
 	if err != nil {
-		return retVal, err
+		return nil, err
 	}
 
-	if len(vs.Vars.TrainableVariables) > 0 {
-		for _, v := range vs.Vars.TrainableVariables {
+	names := make(map[string]struct{})
+	for name, v := range vs.vars {
+		if v.Trainable {
 			if err = opt.AddParameter(v.Tensor, v.Group); err != nil {
 				err = fmt.Errorf("Optimizer defaultBuild - AddParameter failed: %w\n", err)
 				return nil, err
 			}
 		}
+		names[name] = struct{}{}
 	}
 
 	return &Optimizer{
-		opt: opt,
-		// variables: vs.Vars,
-		variablesInOptimizer: uint8(len(vs.Vars.TrainableVariables)),
+		varstore:             vs,
+		opt:                  opt,
+		variablesInOptimizer: names,
 		config:               config,
 		stepCount:            0,
 	}, nil
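Build's call sites are unchanged, but it now records every variable under its var-store name rather than keeping a bare uint8 count, so variables added to the store later can be detected by addMissingVariables below. A hedged sketch (nn.DefaultAdamConfig().Build appears verbatim in the style-transfer hunk above):

package main

import (
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CudaIfAvailable())
	// ... build a model on vs.Root() so the store holds trainable variables ...
	opt, err := nn.DefaultAdamConfig().Build(vs, 1e-3)
	if err != nil {
		log.Fatal(err)
	}
	_ = opt
}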
@@ -215,51 +221,79 @@ func (c *RMSPropConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
 
 // Optimizer methods:
 // ==================
 
 func (opt *Optimizer) addMissingVariables() {
-	// missingVariables := len(opt.variables.TrainableVariables) - int(opt.variablesInOptimizer)
-	//
-	// if missingVariables > 0 {
-	// var tensors []ts.Tensor
-	// for _, t := range opt.variables.TrainableVariables[opt.variablesInOptimizer:] {
-	// tensor := t.MustShallowClone()
-	// tensor.Detach_()
-	// tensors = append(tensors, tensor)
-	// }
-	//
-	// opt.opt.AddParameters(tensors)
-	// opt.variablesInOptimizer = uint8(len(opt.variables.TrainableVariables))
-	// }
+	type param struct {
+		tensor *ts.Tensor
+		group  uint
+	}
+	trainables := make(map[string]param)
+	for name, v := range opt.varstore.vars {
+		if v.Trainable {
+			trainables[name] = param{tensor: v.Tensor, group: v.Group}
+		}
+	}
+	missingVariables := len(trainables) - len(opt.variablesInOptimizer)
+	if missingVariables > 0 {
+		log.Println("INFO: Optimizer.addMissingVariables()...")
+		for name, x := range trainables {
+			if _, ok := opt.variablesInOptimizer[name]; !ok {
+				opt.opt.AddParameter(x.tensor, x.group)
+				opt.variablesInOptimizer[name] = struct{}{}
+			}
+		}
+	}
 }
 
 // ZeroGrad zeroes the gradient for the tensors tracked by this optimizer.
-func (opt *Optimizer) ZeroGrad() {
-	opt.addMissingVariables()
+func (opt *Optimizer) ZeroGrad() error {
 	if err := opt.opt.ZeroGrad(); err != nil {
-		log.Fatalf("Optimizer - ZeroGrad method call error: %v\n", err)
+		err = fmt.Errorf("Optimizer.ZeroGrad() failed: %w\n", err)
+		return err
+	}
+	return nil
+}
+
+// MustZeroGrad zeroes the gradient for the tensors tracked by this optimizer.
+func (opt *Optimizer) MustZeroGrad() {
+	err := opt.ZeroGrad()
+	if err != nil {
+		log.Fatal(err)
 	}
 }
 
 // Clips gradient value at some specified maximum value.
 func (opt *Optimizer) ClipGradValue(max float64) {
-	// opt.variables.mutex.Lock()
-	// defer opt.variables.mutex.Unlock()
+	opt.varstore.Lock()
+	defer opt.varstore.Unlock()
 
-	// for _, tensor := range opt.variables.TrainableVariables {
-	// tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
-	// }
+	for _, v := range opt.varstore.vars {
+		if v.Trainable {
+			// v.Tensor.MustGrad().Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
+			gradTs := v.Tensor.MustGrad(false)
+			gradTs.Clamp_(ts.FloatScalar(-max), ts.FloatScalar(max))
+		}
+	}
 }
 
 // Step performs an optimization step, updating the tracked tensors based on their gradients.
-func (opt *Optimizer) Step() {
-	opt.addMissingVariables()
+func (opt *Optimizer) Step() error {
 	err := opt.opt.Step()
 	if err != nil {
-		log.Fatalf("Optimizer - Step method call error: %v\n", err)
+		err = fmt.Errorf("Optimizer.Step() failed: %w\n", err)
+		return err
 	}
 	opt.stepCount += 1
+
+	return nil
+}
+
+// MustStep performs an optimization step, updating the tracked tensors based on their gradients.
+func (opt *Optimizer) MustStep() {
+	err := opt.Step()
+	if err != nil {
+		log.Fatal(err)
+	}
 }
 
 // ResetStepCount set step count to zero.
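ZeroGrad and Step now return errors instead of calling log.Fatalf internally, each with a Must* twin that restores the old die-on-error behavior. A hedged sketch of one manual optimization step under the split API:

func trainStep(opt *nn.Optimizer, loss *ts.Tensor) error {
	if err := opt.ZeroGrad(); err != nil {
		return err
	}
	loss.MustBackward()
	if err := opt.Step(); err != nil {
		return err
	}
	return nil
	// Or, matching the pre-commit behavior:
	// opt.MustZeroGrad(); loss.MustBackward(); opt.MustStep()
}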
@@ -273,51 +307,208 @@ func (opt *Optimizer) StepCount() int {
 }
 
 // BackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
-func (opt *Optimizer) BackwardStep(loss *ts.Tensor) {
-	opt.addMissingVariables()
+func (opt *Optimizer) BackwardStep(loss *ts.Tensor) error {
 	err := opt.opt.ZeroGrad()
 	if err != nil {
-		log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
+		err = fmt.Errorf("Optimizer.BackwardStep() failed: %w\n", err)
+		return err
 	}
 
 	loss.MustBackward()
 	err = opt.opt.Step()
 	if err != nil {
-		log.Fatalf("Optimizer - BackwardStep method call - Step() error: %v\n", err)
+		err = fmt.Errorf("Optimizer.BackwardStep() failed: %w\n", err)
+		return err
+	}
+
+	return nil
+}
+
+// MustBackwardStep applies a backward step pass, update the gradients, and performs an optimization step.
+func (opt *Optimizer) MustBackwardStep(loss *ts.Tensor) {
+	err := opt.BackwardStep(loss)
+	if err != nil {
+		log.Fatal(err)
 	}
 }
 
 // BackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step.
 //
 // The gradients are clipped based on `max` before being applied.
-func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) {
-	opt.addMissingVariables()
+func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) error {
 	err := opt.opt.ZeroGrad()
 	if err != nil {
-		log.Fatalf("Optimizer - BackwardStepClip method call - ZeroGrad error: %v\n", err)
+		err = fmt.Errorf("Optimizer.BackwardStepClip() failed: %w\n", err)
+		return err
 	}
 	loss.MustBackward()
 	opt.ClipGradValue(max)
 	err = opt.opt.Step()
 	if err != nil {
-		log.Fatalf("Optimizer - BackwardStepClip method call - Step() error: %v\n", err)
+		err = fmt.Errorf("Optimizer.BackwardStepClip() failed: %w\n", err)
+		return err
+	}
+	return nil
+}
+
+// MustBackwardStepClip applies a backward step pass, update the gradients, and performs an optimization step.
+//
+// The gradients are clipped based on `max` before being applied.
+func (opt *Optimizer) MustBackwardStepClip(loss *ts.Tensor, max float64) {
+	err := opt.BackwardStepClip(loss, max)
+	if err != nil {
+		log.Fatal(err)
 	}
 }
 
-/// TODO. Clips gradient L2 norm over all trainable parameters.
+type ClipOpts struct {
+	NormType         float64
+	ErrorIfNonFinite bool
+}
+
+type ClipOpt func(*ClipOpts)
+
+func defaultClipOpts() *ClipOpts {
+	return &ClipOpts{
+		NormType:         2.0,
+		ErrorIfNonFinite: false, // will switch to "true" in the future.
+	}
+}
+
+func WithNormType(v float64) ClipOpt {
+	return func(o *ClipOpts) {
+		o.NormType = v
+	}
+}
+
+func WithErrorIfNonFinite(v bool) ClipOpt {
+	return func(o *ClipOpts) {
+		o.ErrorIfNonFinite = v
+	}
+}
+
+/// Clips gradient L2 norm over all trainable parameters.
 //
 // The norm is computed over all gradients together, as if they were
 // concatenated into a single vector.
-func (opt *Optimizer) ClipGradNorm(max float64) {
-	// TODO.
-	log.Fatalf("Not implemented yet!")
+//
+/// Args:
+// - max: max norm of the gradient
+// - o.NormType. Type of the used p-norm, can be "inf" for infinity norm. Default= 2.0
+// - o.ErrorIfNonFinite bool. If true, throw error if total norm of the gradients from paramters is "nan", "inf" or "-inf". Default=false
+// Returns: total norm of the parameters (viewed as a single vector)
+// ref. https://github.com/pytorch/pytorch/blob/cb4aeff7d8e4c70bb638cf159878c5204d0cc2da/torch/nn/utils/clip_grad.py#L59
+func (opt *Optimizer) ClipGradNorm(max float64, opts ...ClipOpt) error {
+	o := defaultClipOpts()
+	for _, option := range opts {
+		option(o)
+	}
+
+	opt.varstore.Lock()
+	defer opt.varstore.Unlock()
+	parameters := opt.varstore.TrainableVariables()
+	if len(parameters) == 0 {
+		// return ts.MustOfSlice([]float64{0.0}), nil
+		return nil
+	}
+
+	var (
+		norms     []ts.Tensor
+		totalNorm *ts.Tensor
+	)
+
+	device := opt.varstore.device
+	if o.NormType == math.Inf(1) {
+		for _, v := range opt.varstore.vars {
+			n := v.Tensor.MustGrad(false).MustDetach(true).MustAbs(true).MustMax(true).MustTo(device, true)
+			norms = append(norms, *n)
+		}
+		// total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+		totalNorm = ts.MustStack(norms, 0).MustMax(true)
+	} else {
+		for _, v := range opt.varstore.vars {
+			// x := v.Tensor.MustGrad(false).MustNorm(true)
+
+			// NOTE. tensor.Norm() is going to be deprecated. So use linalg_norm
+			// Ref. https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm
+			x := v.Tensor.MustGrad(false).MustDetach(true).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+			norms = append(norms, *x)
+		}
+	}
+
+	// totalNorm = ts.MustStack(norms, 0).MustNorm(true).MustAddScalar(ts.FloatScalar(1e-6), true)
+	// total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+	totalNorm = ts.MustStack(norms, 0).MustLinalgNorm(ts.FloatScalar(o.NormType), nil, false, gotch.Float, true)
+	for _, x := range norms {
+		x.MustDrop()
+	}
+
+	totalNormVal := totalNorm.Float64Values(true)[0]
+	// if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+	if o.ErrorIfNonFinite && (math.IsNaN(totalNormVal) || math.IsInf(totalNormVal, 1)) {
+		err := fmt.Errorf("The total norm of order (%v) for gradients from 'parameters' is non-finite, so it cannot be clipped. To disable this error and scale the gradients by the non-finite norm anyway, set option.ErrorIfNonFinite= false", o.NormType)
+		return err
+	}
+
+	// clip_coef = max_norm / (total_norm + 1e-6)
+	// clipCoefTs := ts.TensorFrom([]float64{max}).MustDiv(totalNorm, true)
+	clipCoef := max / (totalNormVal + 1e-6)
+	// NOTE: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+	// avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+	// when the gradients do not reside in CPU memory.
+	// clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+	if clipCoef > 1.0 {
+		clipCoef = 1.0
+	}
+	for _, v := range opt.varstore.vars {
+		if v.Trainable {
+			// p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))
+			// v.Tensor.MustGrad(false).MustDetach(true).MustMulScalar_(ts.FloatScalar(clipCoef))
+			v.Tensor.MustGrad(false).MustMulScalar_(ts.FloatScalar(clipCoef))
+		}
+	}
+
+	return nil
 }
 
-// TODO. Applies a backward step pass, update the gradients, and performs an optimization step.
+// BackwardStepClipNorm applies a backward step pass, update the gradients, and performs an optimization step.
 //
 // The gradients L2 norm is clipped based on `max`.
-func (opt *Optimizer) BackwardStepClipNorm(loss *ts.Tensor, max float64) {
-	// TODO.
-	log.Fatalf("Not implemented yet!")
+func (opt *Optimizer) BackwardStepClipNorm(loss *ts.Tensor, max float64, opts ...ClipOpt) error {
+	err := opt.opt.ZeroGrad()
+	if err != nil {
+		err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
+		return err
+	}
+	err = loss.Backward()
+	if err != nil {
+		err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
+		return err
+	}
+
+	err = opt.ClipGradNorm(max, opts...)
+	if err != nil {
+		err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
+		return err
+	}
+
+	err = opt.Step()
+	if err != nil {
+		err := fmt.Errorf("Optimizer.BackwardStepClipNorm() failed: %w\n", err)
+		return err
+	}
+
+	return nil
+}
+
+// MustBackwardStepClipNorm applies a backward step pass, update the gradients, and performs an optimization step.
+//
+// The gradients L2 norm is clipped based on `max`.
+func (opt *Optimizer) MustBackwardStepClipNorm(loss *ts.Tensor, max float64, opts ...ClipOpt) {
+	err := opt.BackwardStepClipNorm(loss, max, opts...)
+	if err != nil {
+		log.Fatal(err)
+	}
 }
 
 // SetLR sets the optimizer learning rate.
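ClipGradNorm and BackwardStepClipNorm are now real implementations (a port of PyTorch's clip_grad_norm_), configured through the functional options above. A hedged usage sketch; every identifier appears in the hunk above:

func clippedStep(opt *nn.Optimizer, loss *ts.Tensor) error {
	// Clip the global L2 norm of all trainable gradients at 2.0 and
	// error out if the total norm is NaN or +Inf.
	return opt.BackwardStepClipNorm(loss, 2.0,
		nn.WithNormType(2.0),
		nn.WithErrorIfNonFinite(true),
	)
}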
@@ -1,69 +1,73 @@
 package nn_test
 
-/*
- * import (
- * // "reflect"
- * "fmt"
- * "log"
- * "testing"
- *
- * "github.com/sugarme/gotch"
- * "github.com/sugarme/gotch/nn"
- * ts "github.com/sugarme/gotch/tensor"
- * )
- *
- * func TestOptimizer(t *testing.T) {
- *
- * var data []float32
- * for i := 0; i < 15; i++ {
- * data = append(data, float32(i))
- * }
- * xs, err := ts.NewTensorFromData(data, []int64{int64(len(data)), 1})
- * if err != nil {
- * log.Fatal(err)
- * }
- *
- * ys := xs.MustMul1(ts.FloatScalar(0.42), false).MustAdd1(ts.FloatScalar(1.337), false)
- *
- * vs := nn.NewVarStore(gotch.CPU)
- *
- * optCfg := nn.DefaultSGDConfig()
- * opt, err := optCfg.Build(vs, 1e-2)
- * if err != nil {
- * t.Errorf("Failed building SGD optimizer")
- * }
- *
- * cfg := nn.LinearConfig{
- * WsInit: nn.NewConstInit(0.0),
- * BsInit: nn.NewConstInit(0.0),
- * Bias: true,
- * }
- *
- * linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
- *
- * logits := xs.Apply(linear)
- * loss := logits.MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
- *
- * initialLoss := loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0})
- *
- * wantLoss := float64(1.0)
- *
- * if initialLoss < wantLoss {
- * t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
- * }
- *
- * for i := 0; i < 50; i++ {
- * loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
- *
- * opt.BackwardStep(loss)
- * fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0}))
- * }
- *
- * loss = xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt(), true)
- * finalLoss := loss.Values()[0]
- * fmt.Printf("Final loss: %v\n", finalLoss)
- *
- * if finalLoss > 0.25 {
- * t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
- * }
- * } */
+import (
+	"fmt"
+	"testing"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+func TestOptimizer(t *testing.T) {
+	x := ts.MustArangeStart(ts.IntScalar(1), ts.IntScalar(15), gotch.Float, gotch.CPU).MustView([]int64{-1, 1}, true)
+	// y = x * 0.42 + 1.337
+	y := x.MustMulScalar(ts.FloatScalar(0.42), false).MustAddScalar(ts.FloatScalar(1.337), false)
+
+	vs := nn.NewVarStore(gotch.CPU)
+	path := vs.Root()
+
+	cfg := &nn.LinearConfig{
+		WsInit: nn.NewConstInit(0.0),
+		BsInit: nn.NewConstInit(0.0),
+		Bias:   true,
+	}
+	model := nn.NewLinear(path, 1, 1, cfg)
+
+	lr := 1e-2
+	opt, err := nn.DefaultSGDConfig().Build(vs, lr)
+	if err != nil {
+		t.Errorf("Failed building SGD optimizer")
+	}
+
+	initialLoss := x.ApplyT(model, true).MustMseLoss(y, 1, true).Float64Values(true)[0]
+	wantLoss := float64(1.0)
+	if initialLoss < wantLoss {
+		t.Errorf("Expect initial loss > %v, got %v", wantLoss, initialLoss)
+	}
+
+	// Optimization loop
+	for i := 0; i < 50; i++ {
+		logits := model.ForwardT(x, true)
+		loss := logits.MustMseLoss(y, 1, true)
+		if i%10 == 0 {
+			fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}, false).MustFloat64Value([]int64{0}))
+		}
+		opt.BackwardStep(loss)
+	}
+
+	loss := x.Apply(model).MustMseLoss(y, 1, true)
+	opt.BackwardStep(loss)
+
+	loss = x.Apply(model).MustMseLoss(y, 1, true)
+	finalLoss := loss.Float64Values()[0]
+	fmt.Printf("Final loss: %v\n", finalLoss)
+
+	if finalLoss > 0.25 {
+		t.Errorf("Expect initial loss < 0.25, got %v", finalLoss)
+	}
+}
+
+// see https://github.com/pytorch/pytorch/blob/9b203f667ac096db9f5f5679ac3e3d7931c34d36/test/test_nn.py#L2308
+func TestClipGradNorm(t *testing.T) {
+	// TODO.
+	// vs := nn.NewVarStore(gotch.CPU)
+	// path := vs.Root()
+	// l := nn.NewLinear(path, 10, 10, nn.DefaultLinearConfig())
+	// maxNorm := 2.0
+}
+
+// see https://github.com/pytorch/pytorch/blob/9b203f667ac096db9f5f5679ac3e3d7931c34d36/test/test_nn.py#L2364
+func TestClipGradValue(t *testing.T) {
+	// TODO
+}
nn/other.go (28 changes)
@@ -24,6 +24,9 @@ func (d *Dropout) ForwardT(input *ts.Tensor, train bool) (retVal *ts.Tensor) {
 	return ts.MustDropout(input, d.dropoutProb, train)
 }
 
+// Parameter:
+// ==========
+
 // NewParameter creates a kind of tensor that is considered as a module parameter.
 // Ref. https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html
 func NewParameter(path *Path, name string, x *ts.Tensor, requireGradOpt ...bool) *ts.Tensor {
@@ -32,11 +35,34 @@ func NewParameter(path *Path, name string, x *ts.Tensor, requireGradOpt ...bool)
 		requiredGrad = requireGradOpt[0]
 	}
 
-	param := path.Add(name, x, requiredGrad)
+	param := path.MustAdd(name, x, requiredGrad)
 
 	return param
 }
 
+// Buffer:
+// =======
+
+// NewBuffer creates new buffer.
+//
+// Buffer is different from Parameter as its requiredGrad always false.
+// - `o.Persistent` param. Default=true. If `true` buffer variable will be saved when `nn.VarStore.Save()` is called.
+//
+// Ref.
+// - https://github.com/pytorch/pytorch/blob/f71eede85a69caed637008e331f5ac5f5b7717ae/torch/nn/modules/module.py#L275
+// - https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/2
+func NewBuffer(path *Path, name string, x *ts.Tensor, persistentOpt ...bool) *ts.Tensor {
+	persistent := true
+	if len(persistentOpt) > 0 {
+		persistent = persistentOpt[0]
+	}
+	opts := []AddOpt{
+		WithPersistent(persistent),
+		WithVarType("buffer"),
+	}
+	return path.MustAdd(name, x, false, opts...) // requiredGrad always false. Different from parameter.
+}
+
 // Identity:
 // =========
 
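A hedged sketch of the parameter/buffer split introduced here; NewParameter, NewBuffer, WithPersistent, and WithVarType all appear in this diff, and the tensor values are illustrative:

func registerVars(vs *nn.VarStore) {
	root := vs.Root()
	// Trainable parameter.
	w := nn.NewParameter(root, "weight", ts.MustOfSlice([]float64{0, 0, 0}))
	// Buffer: never trainable; saved by VarStore.Save() by default.
	rm := nn.NewBuffer(root, "running_mean", ts.MustOfSlice([]float64{0, 0, 0}))
	// Non-persistent buffer: kept in the store but skipped on Save().
	tmp := nn.NewBuffer(root, "tmp", ts.MustOfSlice([]float64{0}), false)
	_, _, _ = w, rm, tmp
}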
|
nn/rnn.go (32 changes)
@@ -97,26 +97,26 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
 		}
 		switch numDirections {
 		case 1:
-			wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
-			wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
-			bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
-			bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
+			wIh := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
+			wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
+			bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
+			bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
 
 			flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
 
 		case 2: // bi-directional
 			// forward
-			wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
-			wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
-			bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
-			bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
+			wIh := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
+			wHh := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
+			bIh := vs.MustZeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
+			bHh := vs.MustZeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
 			flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
 
 			// reverse
-			wIhR := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
-			wHhR := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
-			bIhR := vs.Zeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
-			bHhR := vs.Zeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
+			wIhR := vs.MustKaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
+			wHhR := vs.MustKaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
+			bIhR := vs.MustZeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
+			bHhR := vs.MustZeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
 			flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
 		}
 	}
@@ -234,10 +234,10 @@ func NewGRU(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) (retVal *GRU) {
 			inputDim = hiddenDim * numDirections
 		}
 
-		wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inputDim})
-		wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
-		bIh := vs.Zeros("b_ih", []int64{gateDim})
-		bHh := vs.Zeros("b_hh", []int64{gateDim})
+		wIh := vs.MustKaimingUniform("w_ih", []int64{gateDim, inputDim})
+		wHh := vs.MustKaimingUniform("w_hh", []int64{gateDim, hiddenDim})
+		bIh := vs.MustZeros("b_ih", []int64{gateDim})
+		bHh := vs.MustZeros("b_hh", []int64{gateDim})
 
 		flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
 	}
@@ -251,6 +251,7 @@ func BatchAccuracyForLogits(vs *VarStore, m ts.ModuleT, xs, ys *ts.Tensor, d got
 
 		logits := m.ForwardT(bImages, false)
 		acc := logits.AccuracyForLogits(bLabels)
+		logits.MustDrop()
 		sumAccuracy += acc.Float64Values()[0] * size
 		sampleCount += size
 
@@ -35,7 +35,7 @@ type Embedding struct {
 // NewEmbedding creates a new Embedding
 func NewEmbedding(vs *Path, numEmbeddings int64, embeddingDim int64, config *EmbeddingConfig) *Embedding {
 	return &Embedding{
-		Ws:     vs.NewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit),
+		Ws:     vs.MustNewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit),
 		config: config,
 	}
 }
nn/varstore.go (233 changes)
@@ -187,6 +187,46 @@ func (vs *VarStore) Load(filepath string) error {
 			v.Tensor.Copy_(currTs)
 		})
 	}
+
+	for _, x := range namedTensors {
+		x.Tensor.MustDrop()
+	}
+
+	return nil
+}
+
+// LoadWeights loads pretrained weights to VarStore.
+func (vs *VarStore) LoadWeights(namedTensors []ts.NamedTensor) error {
+	var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0)
+	for _, namedTensor := range namedTensors {
+		namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
+	}
+
+	// Match and in-place copy value (update) from newly loaded tensors
+	// to existing named tensors if name is matched. Throw error otherwise.
+	vs.Lock()
+	defer vs.Unlock()
+
+	for name, v := range vs.vars {
+		// missing variable
+		currTs, ok := namedTensorsMap[name]
+		if !ok {
+			err := fmt.Errorf("VarStore.LoadWeights() failed: there's a tensor with name %q in VarStore, but not found in the loaded weights.\n", name)
+			return err
+		}
+
+		// mismatched shape
+		sourceShape := currTs.MustSize()
+		destShape := v.Tensor.MustSize()
+		if !reflect.DeepEqual(destShape, sourceShape) {
+			err := fmt.Errorf("VarStore.LoadWeights() failed. Mismatched shape error for variable name: %v - At store: %v - At source %v\n", name, destShape, sourceShape)
+			return err
+		}
+
+		ts.NoGrad(func() {
+			v.Tensor.Copy_(currTs)
+		})
+	}
 	return nil
 }
 
@@ -242,6 +282,60 @@ func (vs *VarStore) LoadPartial(filepath string) ([]string, error) {
 		})
 	}
+
+	for _, x := range namedTensors {
+		x.Tensor.MustDrop()
+	}
+
+	return missingVariables, nil
+}
+
+// LoadWeightsPartial loads the VarStore variable values from a file if it exists.
+//
+// Weight values for the tensors currently stored in the var-store and the given file get
+// loaded from the given file. If a variable in the var store is not present in the given file,
+// it is skipped and its values are not updated. This method should be used if pre-trained
+// weight for only parts of the model are available.
+// Note that the set of variables stored in the var-store is not changed, only the values
+// for these tensors are modified.
+//
+// Returns a String Vector containing the names of missing variables.
+func (vs *VarStore) LoadWeightsPartial(namedTensors []ts.NamedTensor) ([]string, error) {
+	var namedTensorsMap map[string]*ts.Tensor = make(map[string]*ts.Tensor, 0)
+	for _, namedTensor := range namedTensors {
+		namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
+	}
+
+	var missingVariables []string
+
+	// Match and in-place copy value (update) from newly loaded tensors
+	// to existing named tensors if name is matched. Throw error otherwise.
+	vs.Lock()
+	defer vs.Unlock()
+
+	for name, v := range vs.vars {
+		var currTs *ts.Tensor
+		var ok bool
+
+		// missing variable
+		if currTs, ok = namedTensorsMap[name]; !ok {
+			missingVariables = append(missingVariables, name)
+			continue
+		}
+
+		// mismatched shape
+		destShape := currTs.MustSize()
+		sourceShape := v.Tensor.MustSize()
+		if !reflect.DeepEqual(destShape, sourceShape) {
+			fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", name, destShape, sourceShape)
+			missingVariables = append(missingVariables, name)
+			continue
+		}
+
+		ts.NoGrad(func() {
+			v.Tensor.Copy_(currTs)
+		})
+	}
+
 	return missingVariables, nil
 }
 
|
||||||
//
|
//
|
||||||
// All the variables in this var store have to exist with the same
|
// All the variables in this var store have to exist with the same
|
||||||
// name in the source var store, otherwise an error is returned.
|
// name in the source var store, otherwise an error is returned.
|
||||||
func (vs *VarStore) Copy(src VarStore) error {
|
func (vs *VarStore) Copy(src *VarStore) error {
|
||||||
vs.Lock()
|
vs.Lock()
|
||||||
defer vs.Unlock()
|
defer vs.Unlock()
|
||||||
src.Lock()
|
src.Lock()
|
||||||
|
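Copy now takes the source store by pointer. A hedged sketch:

func cloneWeights(dst, src *nn.VarStore) error {
	// Every variable in dst must exist under the same name in src.
	return dst.Copy(src) // before this commit the call was dst.Copy(*src)
}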

@@ -343,6 +437,34 @@ func (vs *VarStore) Summary() {
    fmt.Printf("Num of layers: %v\n", len(vars))
}

// ToDType casts all variables in the VarStore to the specified DType.
//
// NOTE. Only float-like types (Half, Float, Double) are guaranteed to be convertible.
func (vs *VarStore) ToDType(dtype gotch.DType) {
    vs.Root().ToDType(dtype)
}

// ToHalf casts all float-like variables in the VarStore to `Half` dtype.
//
// NOTE. Float-like includes the `Half`, `Float` and `Double` dtypes.
func (vs *VarStore) ToHalf() {
    vs.Root().ToHalf()
}

// ToFloat casts all float-like variables in the VarStore to `Float` dtype.
//
// NOTE. Float-like includes the `Half`, `Float` and `Double` dtypes.
func (vs *VarStore) ToFloat() {
    vs.Root().ToFloat()
}

// ToDouble casts all float-like variables in the VarStore to `Double` dtype.
//
// NOTE. Float-like includes the `Half`, `Float` and `Double` dtypes.
func (vs *VarStore) ToDouble() {
    vs.Root().ToDouble()
}

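An illustrative round trip through the new casting helpers (the inference step is a placeholder; `vs` is an assumed `*nn.VarStore`):

    vs.ToHalf()              // cast float-like variables to gotch.Half
    // ... run half-precision inference ...
    vs.ToFloat()             // back to single precision
    vs.ToDType(gotch.Double) // or pick the target dtype explicitly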
// Path methods:
// =============

@@ -467,6 +589,23 @@ func (p *Path) Add(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) (*
    return p.add(name, x, trainable, o.VarType, o.Persistent)
}

// MustAdd adds a tensor to a given path.
//
// Args
// - name: intended name of the variable in the VarStore (if duplicated, a numeric suffix is appended)
// - x: tensor holding the values to keep in the VarStore
// - trainable: marks whether the tensor is trainable
// - o.VarType: variable type, i.e., either "parameter" or "buffer"
// - o.Persistent: whether to save this variable when `VarStore.Save()` is called. Only applies to the `buffer` type.
// Returns a reference to the tensor stored in the VarStore.
func (p *Path) MustAdd(name string, x *ts.Tensor, trainable bool, opts ...AddOpt) *ts.Tensor {
    x, err := p.Add(name, x, trainable, opts...)
    if err != nil {
        log.Fatal(err)
    }
    return x
}

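A sketch of registering a variable through `MustAdd` (illustrative; `vs`, `device`, and the tensor constructor arguments are assumptions, not part of this commit):

    p := vs.Root().Sub("encoder")
    x := ts.MustZeros([]int64{64, 128}, gotch.Float, device)
    w := p.MustAdd("weight", x, true) // trainable; duplicated names get a numeric suffix
    _ = w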
func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool, opts ...AddOpt) (*ts.Tensor, error) {
    path := p.getpath(name)

@@ -480,9 +619,73 @@ func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool,
}

func (p *Path) SetGroup(g uint) {
    p.varstore.Lock()
    defer p.varstore.Unlock()

    // TODO. set group for individual variables.
    // TBD. variables of the current path only, or all sub-paths as well?
    // For now, just set the group for all variables at this path.
    path := strings.Join(p.path, SEP)
    for name, v := range p.varstore.vars {
        vpaths := strings.Split(name, SEP)
        vpath := strings.Join(vpaths[:len(vpaths)-1], SEP)
        if vpath == path {
            v.Group = g
            p.varstore.vars[name] = v
        }
    }
    p.group = g
}

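An illustrative use of parameter groups, e.g. to let an optimizer treat a head differently from the backbone (the sub-path names are assumptions):

    vs.Root().Sub("backbone").SetGroup(0)
    vs.Root().Sub("head").SetGroup(1)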
// ToDType casts all variables in this path and its sub-paths to the specified dtype.
//
// NOTE. This method should only be used for floating-point conversion, i.e.,
// "gotch.Half", "gotch.Float", "gotch.Double".
func (p *Path) ToDType(dtype gotch.DType) {
    p.varstore.Lock()
    defer p.varstore.Unlock()
    path := strings.Join(p.path, SEP)
    for name, v := range p.varstore.vars {
        if strings.Contains(name, path) {
            newVar := v
            newVar.Tensor = v.Tensor.MustTotype(dtype, true)
            p.varstore.vars[name] = newVar
        }
    }
}

// toFloat casts all float-like variables in the current path and its sub-paths
// to the specified dtype, leaving non-float variables untouched.
func (p *Path) toFloat(dtype gotch.DType) {
    p.varstore.Lock()
    defer p.varstore.Unlock()
    path := strings.Join(p.path, SEP)
    for name, v := range p.varstore.vars {
        if strings.Contains(name, path) {
            // NOTE. do not shadow the target dtype here: compare against the
            // variable's current dtype, but convert to the requested one.
            currDType := v.Tensor.DType()
            if currDType == gotch.Half || currDType == gotch.Float || currDType == gotch.Double {
                newVar := v
                newVar.Tensor = v.Tensor.MustTotype(dtype, true)
                p.varstore.vars[name] = newVar
            }
        }
    }
}

// ToHalf casts all float-like variables in the current path and sub-paths to `Half` precision.
func (p *Path) ToHalf() {
    p.toFloat(gotch.Half)
}

// ToFloat casts all float-like variables in the current path and sub-paths to `Float` precision.
func (p *Path) ToFloat() {
    p.toFloat(gotch.Float)
}

// ToDouble casts all float-like variables in the current path and sub-paths to `Double` precision.
func (p *Path) ToDouble() {
    p.toFloat(gotch.Double)
}

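Because the cast helpers also exist on `Path`, a single sub-tree can be converted while the rest of the store keeps its precision (sub-path name assumed):

    vs.Root().Sub("encoder").ToHalf() // only "encoder.*" variables become Half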
// ZerosNoTrain creates a new variable initialized with zeros.
//
// The new variable is named according to the name parameter and

@@ -506,6 +709,20 @@ func (p *Path) ZerosNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tens
    return out, nil
}

// MustZerosNoTrain creates a new variable initialized with zeros.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable will not be trainable, so
// gradients will not be tracked.
// The variable uses a float tensor initialized with zeros.
func (p *Path) MustZerosNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
    x, err := p.ZerosNoTrain(name, dims, opts...)
    if err != nil {
        log.Fatal(err)
    }
    return x
}

// OnesNoTrain creates a new variable initialized with ones.
//
// The new variable is named according to the name parameter and

@@ -529,6 +746,20 @@ func (p *Path) OnesNoTrain(name string, dims []int64, opts ...AddOpt) (*ts.Tenso
    return out, nil
}

// MustOnesNoTrain creates a new variable initialized with ones.
//
// The new variable is named according to the name parameter and
// has the specified shape. The variable will not be trainable, so
// gradients will not be tracked.
// The variable uses a float tensor initialized with ones.
func (p *Path) MustOnesNoTrain(name string, dims []int64, opts ...AddOpt) *ts.Tensor {
    x, err := p.OnesNoTrain(name, dims, opts...)
    if err != nil {
        log.Fatal(err)
    }
    return x
}

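A sketch of the non-trainable constructors, e.g. for batch-norm style running statistics (the buffer names are assumptions; the signatures come from the code above):

    runningMean := p.MustZerosNoTrain("running_mean", []int64{64})
    runningVar := p.MustOnesNoTrain("running_var", []int64{64})
    _, _ = runningMean, runningVar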
// NewVar creates a new variable.
//
// The new variable is named according to the name parameter and

@@ -439,26 +439,19 @@ func LoadAll(vs *nn.VarStore, modelFile string) error {
        return err
    }

-   // for tsName, _ := range vs.Vars.NamedVariables {
-   for tsName := range vs.Vars.NamedVariables {
-       // missing variable
-       currTs, ok := weights[tsName]
-       if !ok {
-           err = fmt.Errorf("LoadAll() failed: Cannot find tensor with name: %v in variable store. \n", tsName)
-           return err
-       }
-
-       // mismatched shape
-       sourceShape := currTs.MustSize()
-       destShape := vs.Vars.NamedVariables[tsName].MustSize()
-       if !reflect.DeepEqual(destShape, sourceShape) {
-           err = fmt.Errorf("LoadAll() failed: Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
-           return err
-       }
-
-       ts.NoGrad(func() {
-           vs.Vars.NamedVariables[tsName].Copy_(currTs)
-       })
-   }
+   var namedTensors []ts.NamedTensor
+   for n, x := range weights {
+       namedTs := ts.NamedTensor{
+           Name:   n,
+           Tensor: x,
+       }
+
+       namedTensors = append(namedTensors, namedTs)
+   }
+
+   err = vs.LoadWeights(namedTensors)
+   if err != nil {
+       return err
+   }

    for _, x := range weights {

@@ -477,32 +470,21 @@ func LoadPartial(vs *nn.VarStore, modelFile string) ([]string, error) {
        return nil, err
    }

+   var namedTensors []ts.NamedTensor
+   for n, x := range weights {
+       namedTs := ts.NamedTensor{
+           Name:   n,
+           Tensor: x,
+       }
+
+       namedTensors = append(namedTensors, namedTs)
+   }
+
    var missingVariables []string

-   // Match and in-place copy value (update) from newly loaded tensors
-   // to existing named tensors if name is matched. Throw error otherwise.
-   for tsName := range vs.Vars.NamedVariables {
-       var currTs *ts.Tensor
-       var ok bool
-
-       // missing variable
-       if currTs, ok = weights[tsName]; !ok {
-           missingVariables = append(missingVariables, tsName)
-           continue
-       }
-
-       // mismatched shape
-       destShape := currTs.MustSize()
-       sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
-       if !reflect.DeepEqual(destShape, sourceShape) {
-           fmt.Printf("WARNING: Mismatched shape error for variable name: %v - At store: %v - At source %v. Skip loading this weight...\n", tsName, destShape, sourceShape)
-           missingVariables = append(missingVariables, tsName)
-           continue
-       }
-
-       ts.NoGrad(func() {
-           vs.Vars.NamedVariables[tsName].Copy_(currTs)
-       })
-   }
+   missingVariables, err = vs.LoadWeightsPartial(namedTensors)
+   if err != nil {
+       return nil, err
+   }

    for _, x := range weights {
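
After this refactor both helpers simply delegate to the VarStore methods above. An illustrative call site (the package name `pickle` and the file names are assumptions):

    if err := pickle.LoadAll(vs, "model.pt"); err != nil {
        log.Fatal(err)
    }
    missing, err := pickle.LoadPartial(vs, "pretrained.pt")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("weights not found in checkpoint: %v\n", missing)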

@@ -1101,38 +1101,39 @@ func (ngg *NoGradGuard) Enable() {
    _ = MustGradSetEnabled(ngg.enabled)
}

-// Reduction type is an enum-like type
-type Reduction int
-
const (
    // Do not reduce
-   ReductionNone Reduction = iota
+   ReductionNone int64 = 0
    // Mean of losses
-   ReductionMean
+   ReductionMean int64 = 1
    // Sum of losses
-   ReductionSum
+   ReductionSum int64 = 2
    // Escape hatch in case new options become available
-   ReductionOther
+   ReductionOther int64 = 3
)

-func (r Reduction) ToInt() int {
-   switch r {
-   case ReductionNone:
-       return 0
-   case ReductionMean:
-       return 1
-   case ReductionSum:
-       return 2
-   case ReductionOther:
-       return 3
-   }
-
-   // NOTE. should it be panic here instead of returning -1?
-   return -1
-}
+// func (r Reduction) ToInt() int {
+//     switch r {
+//     case ReductionNone:
+//         return 0
+//     case ReductionMean:
+//         return 1
+//     case ReductionSum:
+//         return 2
+//     case ReductionOther:
+//         return 3
+//     }
+//
+//     // NOTE. should it be panic here instead of returning -1?
+//     return -1
+// }

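Because the constants are now plain `int64` values, call sites no longer need a `ToInt()` conversion. A trivial sketch (illustrative only; any generated op that accepts a reduction flag takes the value directly):

    red := ts.ReductionMean // already int64(1)
    _ = red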
// Float64Values returns values of tensor in a slice of float64.
-func (ts *Tensor) Float64Values() []float64 {
+func (ts *Tensor) Float64Values(delOpt ...bool) []float64 {
+   del := false
+   if len(delOpt) > 0 {
+       del = delOpt[0]
+   }
    numel := ts.Numel()
    vec := make([]float64, numel)

@@ -1141,11 +1142,19 @@ func (ts *Tensor) Float64Values() []float64 {
    float64Ts.MustCopyData(vec, numel)
    float64Ts.MustDrop()

+   if del {
+       ts.MustDrop()
+   }
+
    return vec
}

// Int64Values returns values of tensor in a slice of int64.
-func (ts *Tensor) Int64Values() []int64 {
+func (ts *Tensor) Int64Values(delOpt ...bool) []int64 {
+   del := false
+   if len(delOpt) > 0 {
+       del = delOpt[0]
+   }
    numel := ts.Numel()
    vec := make([]int64, numel)

@@ -1154,6 +1163,10 func (ts *Tensor) Int64Values() []int64 {
    int64Ts.MustCopyData(vec, numel)
    int64Ts.MustDrop()

+   if del {
+       ts.MustDrop()
+   }
+
    return vec
}
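
The new optional flag lets callers consume a temporary tensor in one call. A small sketch (the input values are illustrative):

    t := ts.MustOfSlice([]float64{1, 2, 3})
    vals := t.Float64Values(true) // copies the values out, then drops t
    fmt.Println(vals)             // [1 2 3]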