fix(nn/sequential, linear): fix memory being retained because intermediate tensors were never deleted
This commit is contained in:
parent
9d31337b4f
commit
ed709027c0
|
@ -41,7 +41,7 @@ func runLinear() {
|
|||
})
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
testAccuracy := testLogits.MustArgmax(-1, false, true).MustEq1(ds.TestLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Values()[0], testAccuracy*100)
|
||||
|
||||
|
|
|
@ -28,9 +28,9 @@ func netInit(vs nn.Path) ts.Module {
|
|||
|
||||
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
|
||||
|
||||
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu(true)
|
||||
}))
|
||||
// n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
// return xs.MustRelu(true)
|
||||
// }))
|
||||
|
||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
|
||||
// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))
|
||||
|
@ -40,7 +40,8 @@ func netInit(vs nn.Path) ts.Module {
|
|||
|
||||
func train(trainX, trainY, testX, testY ts.Tensor, m ts.Module, opt nn.Optimizer, epoch int) {
|
||||
|
||||
loss := m.Forward(trainX).CrossEntropyForLogits(trainY)
|
||||
logits := m.Forward(trainX)
|
||||
loss := logits.CrossEntropyForLogits(trainY)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
|
|
205
example/seq/main.go
Normal file
205
example/seq/main.go
Normal file
|
@ -0,0 +1,205 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// noSeq()
|
||||
withSeq()
|
||||
// noSeq2Layers()
|
||||
|
||||
// seqNoVarStore()
|
||||
}
|
||||
|
||||
func noSeq() {
|
||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
||||
|
||||
wsInit := nn.NewKaimingUniformInit()
|
||||
ws := wsInit.InitTensor([]int64{10, 784}, gotch.CPU).MustT(true)
|
||||
|
||||
bound := 1.0 / math.Sqrt(float64(784))
|
||||
bsInit := nn.NewUniformInit(-bound, bound)
|
||||
bs := bsInit.InitTensor([]int64{10}, gotch.CPU)
|
||||
|
||||
for i := 0; i < 2000; i++ {
|
||||
mul := ds.TrainImages.MustMatMul(ws, false)
|
||||
logits := mul.MustAdd(bs, true)
|
||||
loss := logits.AccuracyForLogits(ds.TrainLabels)
|
||||
|
||||
fmt.Printf("Epoch %v\t Loss: %.3f\n", i, loss.Values()[0])
|
||||
loss.MustDrop()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func withSeq() {
|
||||
seq := nn.Seq()
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
// seq.Add(nn.NewLinear(vs.Root(), 784, 10, *nn.DefaultLinearConfig()))
|
||||
seq.Add(nn.NewLinear(vs.Root(), 784, 128, *nn.DefaultLinearConfig()))
|
||||
seq.Add(nn.NewLinear(vs.Root(), 128, 10, *nn.DefaultLinearConfig()))
|
||||
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, 1e-2)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
||||
|
||||
for i := 0; i < 2000; i++ {
|
||||
logits := seq.Forward(ds.TrainImages)
|
||||
loss := logits.CrossEntropyForLogits(ds.TrainLabels)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
testLogits := seq.Forward(ds.TestImages)
|
||||
testAccuracy := testLogits.AccuracyForLogits(ds.TestLabels)
|
||||
|
||||
fmt.Printf("Epoch: %v \t Loss: %.3f \t Test accuracy: %.2f%%\n", i, loss.Values()[0], testAccuracy.Values()[0]*100)
|
||||
|
||||
loss.MustDrop()
|
||||
testAccuracy.MustDrop()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func noSeq2Layers() {
|
||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
||||
|
||||
wsInit := nn.NewKaimingUniformInit()
|
||||
ws1 := wsInit.InitTensor([]int64{1024, 784}, gotch.CPU).MustT(true)
|
||||
ws2 := wsInit.InitTensor([]int64{10, 1024}, gotch.CPU).MustT(true)
|
||||
|
||||
bound1 := 1.0 / math.Sqrt(float64(784))
|
||||
bsInit1 := nn.NewUniformInit(-bound1, bound1)
|
||||
bs1 := bsInit1.InitTensor([]int64{1024}, gotch.CPU)
|
||||
|
||||
bound2 := 1.0 / math.Sqrt(float64(1024))
|
||||
bsInit2 := nn.NewUniformInit(-bound2, bound2)
|
||||
bs2 := bsInit2.InitTensor([]int64{10}, gotch.CPU)
|
||||
|
||||
for i := 0; i < 2000; i++ {
|
||||
mul1 := ds.TrainImages.MustMatMul(ws1, false)
|
||||
out1 := mul1.MustAdd(bs1, true)
|
||||
|
||||
mul2 := out1.MustMatMul(ws2, true)
|
||||
logits := mul2.MustAdd(bs2, true)
|
||||
|
||||
loss := logits.AccuracyForLogits(ds.TrainLabels)
|
||||
|
||||
fmt.Printf("Epoch %v\t Loss: %.3f\n", i, loss.Values()[0])
|
||||
loss.MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
func seqNoVarStore() {
|
||||
|
||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
||||
|
||||
wsInit := nn.NewKaimingUniformInit()
|
||||
ws1 := wsInit.InitTensor([]int64{1024, 784}, gotch.CPU).MustT(true)
|
||||
ws2 := wsInit.InitTensor([]int64{10, 1024}, gotch.CPU).MustT(true)
|
||||
|
||||
bound1 := 1.0 / math.Sqrt(float64(784))
|
||||
bsInit1 := nn.NewUniformInit(-bound1, bound1)
|
||||
bs1 := bsInit1.InitTensor([]int64{1024}, gotch.CPU)
|
||||
|
||||
bound2 := 1.0 / math.Sqrt(float64(1024))
|
||||
bsInit2 := nn.NewUniformInit(-bound2, bound2)
|
||||
bs2 := bsInit2.InitTensor([]int64{10}, gotch.CPU)
|
||||
|
||||
l1 := Linear{&ws1, &bs1}
|
||||
l2 := Linear{&ws2, &bs2}
|
||||
|
||||
seq := Seq()
|
||||
seq.Add(l1)
|
||||
seq.Add(l2)
|
||||
// seq.Add1(l1)
|
||||
// seq.Add2(l2)
|
||||
|
||||
for i := 0; i < 2000; i++ {
|
||||
logits := seq.Forward(ds.TrainImages)
|
||||
|
||||
logits.MustDrop()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
type Linear struct {
|
||||
Ws *ts.Tensor
|
||||
Bs *ts.Tensor
|
||||
}
|
||||
|
||||
func (l Linear) Forward(xs ts.Tensor) ts.Tensor {
|
||||
mul := xs.MustMatMul(*l.Ws, false)
|
||||
return mul.MustAdd(*l.Bs, true)
|
||||
}
|
||||
|
||||
type Sequential struct {
|
||||
layers []ts.Module
|
||||
l1 ts.Module
|
||||
l2 ts.Module
|
||||
}
|
||||
|
||||
func Seq() Sequential {
|
||||
return Sequential{layers: make([]ts.Module, 0)}
|
||||
}
|
||||
|
||||
// Len returns number of sub-layers embedded in this layer
|
||||
func (s *Sequential) Len() (retVal int64) {
|
||||
return int64(len(s.layers))
|
||||
}
|
||||
|
||||
// IsEmpty returns true if this layer does not have any sub-layers.
|
||||
func (s *Sequential) IsEmpty() (retVal bool) {
|
||||
return len(s.layers) == 0
|
||||
}
|
||||
|
||||
// Add appends a layer after all the current layers.
|
||||
func (s *Sequential) Add(l ts.Module) {
|
||||
|
||||
s.layers = append(s.layers, l)
|
||||
}
|
||||
|
||||
func (s *Sequential) Add1(l ts.Module) {
|
||||
s.l1 = l
|
||||
}
|
||||
|
||||
func (s *Sequential) Add2(l ts.Module) {
|
||||
s.l2 = l
|
||||
}
|
||||
|
||||
func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
if s.IsEmpty() {
|
||||
return xs.MustShallowClone()
|
||||
}
|
||||
|
||||
// forward sequentially
|
||||
outs := make([]ts.Tensor, len(s.layers))
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
if i == 0 {
|
||||
outs[0] = s.layers[i].Forward(xs)
|
||||
defer outs[0].MustDrop()
|
||||
} else if i == len(s.layers)-1 {
|
||||
return s.layers[i].Forward(outs[i-1])
|
||||
} else {
|
||||
outs[i+1] = s.layers[i].Forward(outs[i-1])
|
||||
defer outs[i+1].MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
// out1 := s.l1.Forward(xs)
|
||||
// defer out1.MustDrop()
|
||||
//
|
||||
// return s.l2.Forward(out1)
|
||||
|
||||
}
|
|
@ -56,7 +56,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
|
|||
}
|
||||
|
||||
return &Linear{
|
||||
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit),
|
||||
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
||||
Bs: bs,
|
||||
}
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
|
|||
// 1 1 1
|
||||
// 1 1 1 ]
|
||||
func (l *Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
// TODO: measure memory leak here.
|
||||
mul := xs.MustMatMul(l.Ws.MustT(false), false)
|
||||
|
||||
mul := xs.MustMatMul(l.Ws, false)
|
||||
return mul.MustAdd(l.Bs, true)
|
||||
}
|
||||
|
|
|
@ -72,18 +72,28 @@ func WithUint8(n uint8) func() uint8 {
|
|||
|
||||
// Implement Module interface for Sequential:
|
||||
// ==========================================
|
||||
func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
|
||||
// Forward implements Module interface for Sequential
|
||||
func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
if s.IsEmpty() {
|
||||
return xs.MustShallowClone()
|
||||
}
|
||||
|
||||
// forward sequentially
|
||||
outs := make([]ts.Tensor, len(s.layers))
|
||||
for i := 0; i < len(s.layers); i++ {
|
||||
// xs = s.layers[i].Forward(xs)
|
||||
xs = xs.Apply(s.layers[i])
|
||||
if i == 0 {
|
||||
outs[0] = s.layers[i].Forward(xs)
|
||||
defer outs[0].MustDrop()
|
||||
} else if i == len(s.layers)-1 {
|
||||
return s.layers[i].Forward(outs[i-1])
|
||||
} else {
|
||||
outs[i+1] = s.layers[i].Forward(outs[i-1])
|
||||
defer outs[i+1].MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
return xs
|
||||
return
|
||||
}
|
||||
|
||||
// SequentialT is a sequential layer combining new layers with support for a training mode.
|
||||
|
|
|
@ -8,13 +8,18 @@ import (
|
|||
|
||||
// CrossEntropyForLogits computes the cross-entropy loss based on some logits and targets.
|
||||
func (ts Tensor) CrossEntropyForLogits(targets Tensor) (retVal Tensor) {
|
||||
return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
|
||||
// return ts.MustLogSoftmax(-1, gotch.Float.CInt(), true).MustNllLoss(targets, true)
|
||||
|
||||
logSm := ts.MustLogSoftmax(-1, gotch.Float.CInt(), true)
|
||||
return logSm.MustNllLoss(targets, true)
|
||||
}
|
||||
|
||||
// AccuracyForLogits returns the average accuracy for some given logits assuming that
|
||||
// targets represent ground-truth.
|
||||
func (ts Tensor) AccuracyForLogits(targets Tensor) (retVal Tensor) {
|
||||
return ts.MustArgmax(-1, false, true).MustEq1(targets).MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
||||
argmax := ts.MustArgmax(-1, false, true)
|
||||
eq1 := argmax.MustEq1(targets, true)
|
||||
return eq1.MustTotype(gotch.Float, true).MustMean(gotch.Float.CInt(), true)
|
||||
}
|
||||
|
||||
func (ts Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal Tensor) {
|
||||
|
|
|
@ -295,12 +295,14 @@ func (ts Tensor) Device() (retVal gotch.Device, err error) {
|
|||
return device.OfCInt(int32(cInt)), nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
||||
func (ts Tensor) Eq1(other Tensor, del bool) (retVal Tensor, err error) {
|
||||
|
||||
// Get a C null pointer
|
||||
// https://stackoverflow.com/a/2022369
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
|
@ -311,8 +313,8 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
|||
|
||||
}
|
||||
|
||||
func (ts Tensor) MustEq1(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Eq1(other)
|
||||
func (ts Tensor) MustEq1(other Tensor, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Eq1(other, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user