From 20f338014d224e0b9e4f66d2037c378e6920bed0 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 17 Jun 2020 16:24:27 +1000 Subject: [PATCH] fix(tensor/NoGrad): added a hacky way to get loss func updated --- README.md | 19 +++++++++++++++++++ example/mnist/linear.go | 41 +++++++++++++++-------------------------- nn/varstore.go | 40 +++++----------------------------------- tensor/tensor.go | 24 ++++++++++++++++++++---- 4 files changed, 59 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 84fe980..c539083 100644 --- a/README.md +++ b/README.md @@ -48,8 +48,27 @@ - Other examples can be found at `example` folder +### 3. Notes on running examples + +- Clean the Cgo cache to get a fresh build, as [mentioned here](https://github.com/golang/go/issues/24355) + +```bash + # Either + go clean -cache -testcache . + go run [EXAMPLE FILES] + + # Or + go build -a + + go run [EXAMPLE FILES] + +``` + + ## Acknowledgement - This projects has been inspired and used many concepts from [tch-rs](https://github.com/LaurentMazare/tch-rs) Libtorch Rust binding. 
+ + diff --git a/example/mnist/linear.go b/example/mnist/linear.go index b29e792..3e2cf97 100644 --- a/example/mnist/linear.go +++ b/example/mnist/linear.go @@ -13,10 +13,8 @@ const ( Label int64 = 10 MnistDir string = "../../data/mnist" - // epochs = 500 - // batchSize = 256 epochs = 200 - batchSize = 60000 + batchSize = 256 ) func runLinear() { @@ -44,6 +42,7 @@ func runLinear() { * * batches := samples / batchSize * batchIndex := 0 + * var loss ts.Tensor * for i := 0; i < batches; i++ { * start := batchIndex * batchSize * size := batchSize @@ -55,20 +54,20 @@ func runLinear() { * * // Indexing * narrowIndex := ts.NewNarrow(int64(start), int64(start+size)) - * // bImages := ds.TrainImages.Idx(narrowIndex) - * // bLabels := ds.TrainLabels.Idx(narrowIndex) - * bImages := imagesTs.Idx(narrowIndex) - * bLabels := labelsTs.Idx(narrowIndex) + * bImages := ds.TrainImages.Idx(narrowIndex) + * bLabels := ds.TrainLabels.Idx(narrowIndex) + * // bImages := imagesTs.Idx(narrowIndex) + * // bLabels := labelsTs.Idx(narrowIndex) * * logits := bImages.MustMm(ws).MustAdd(bs) - * // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels) - * loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels) + * loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true) * * ws.ZeroGrad() * bs.ZeroGrad() - * loss.Backward() + * loss.MustBackward() * - * bs.MustGrad().Print() + * // TODO: why `loss` need to print out to get updated? 
+ * fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0)) * * ts.NoGrad(func() { * ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))) @@ -81,31 +80,21 @@ func runLinear() { * */ logits := ds.TrainImages.MustMm(ws).MustAdd(bs) - // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true) - loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels) - // loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true) + loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true) ws.ZeroGrad() bs.ZeroGrad() - // loss.MustBackward() - loss.Backward() - - // TODO: why `loss` need to print out to get updated? - fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0)) - // fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1)) + loss.MustBackward() ts.NoGrad(func() { ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))) bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))) }) - // fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1)) - // fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1)) - testLogits := ds.TestImages.MustMm(ws).MustAdd(bs) testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0}) - // testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0}) - // - fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100) + + lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0}) + fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100) } } diff --git a/nn/varstore.go 
b/nn/varstore.go index 4c51ba8..655d222 100644 --- a/nn/varstore.go +++ b/nn/varstore.go @@ -160,15 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) { return err } - retValErr, err := ts.NoGrad(func() { + ts.NoGrad(func() { ts.Copy_(currTs, namedTs.Tensor) }) - if err != nil { - return err - } - if retValErr != nil { - return retValErr.(error) - } } return nil @@ -207,15 +201,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) { } // It's matched. Now, copy in-place the loaded tensor value to var-store - retValErr, err := ts.NoGrad(func() { + ts.NoGrad(func() { ts.Copy_(currTs, namedTs.Tensor) }) - if err != nil { - return nil, err - } - if retValErr != nil { - return nil, retValErr.(error) - } } return missingVariables, nil @@ -278,15 +266,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) { if err != nil { return err } - retValErr, err := ts.NoGrad(func() { + ts.NoGrad(func() { ts.Copy_(v, srcDevTs) }) - if err != nil { - return err - } - if retValErr != nil { - return retValErr.(error) - } } return nil @@ -548,15 +530,9 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) { } v := p.Zeros(name, size) - retValErr, err := ts.NoGrad(func() { + ts.NoGrad(func() { ts.Copy_(v, t) }) - if err != nil { - log.Fatal(err) - } - if retValErr != nil { - log.Fatal(retValErr) - } return v } @@ -615,15 +591,9 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) { } v := e.OrZeros(size) - retValErr, err := ts.NoGrad(func() { + ts.NoGrad(func() { ts.Copy_(v, tensor) }) - if err != nil { - log.Fatal(err) - } - if retValErr != nil { - log.Fatal(retValErr) - } return v } diff --git a/tensor/tensor.go b/tensor/tensor.go index 8ac6790..271cbfb 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -376,6 +376,15 @@ func (ts Tensor) RequiresGrad() (retVal bool, err error) { return retVal, nil } +func (ts Tensor) MustRequiresGrad() (retVal bool) { + retVal, err := ts.RequiresGrad() + if err != nil { + 
log.Fatal(err) + } + + return retVal +} + // DataPtr returns the address of the first element of this tensor. func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) { @@ -899,7 +908,15 @@ func MustGradSetEnabled(b bool) (retVal bool) { } // NoGrad runs a closure without keeping track of gradients. -func NoGrad(fn interface{}) (retVal interface{}, err error) { +func NoGrad(fn interface{}) { + + // TODO: This is weird but somehow we need to trigger C++ print + // to get loss function updated. Probably it is related to + // C++ cache clearing. + // Next step would be creating a Go func that trigger C++ cache clean + // instead of this ugly hacky way. + newTs := NewTensor() + newTs.Drop() // Switch off Grad prev := MustGradSetEnabled(false) @@ -907,16 +924,15 @@ func NoGrad(fn interface{}) (retVal interface{}, err error) { // Analyze input as function. If not, throw error f, err := NewFunc(fn) if err != nil { - return retVal, nil + log.Fatal(err) } // invokes the function - retVal = f.Invoke() + f.Invoke() // Switch on Grad _ = MustGradSetEnabled(prev) - return retVal, nil } // NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.