fix(tensor/NoGrad): add a hacky way to get the loss function updated

sugarme 2020-06-17 16:24:27 +10:00
parent 613cd93443
commit 20f338014d
4 changed files with 59 additions and 65 deletions
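In user-facing terms, the change is that `ts.NoGrad` no longer returns `(retVal interface{}, err error)`; it simply runs the given closure with gradient mode switched off and fails fatally if the argument is not a function. A minimal before/after sketch of a call site follows; it is reconstructed from the hunks below and uses only identifiers that appear there (`currTs` and `namedTs` are assumed to be the tensors from the var-store loading code).

```go
// Before this commit: NoGrad returned (interface{}, error), so callers
// had to unpack and check both values.
//
//	retValErr, err := ts.NoGrad(func() {
//		ts.Copy_(currTs, namedTs.Tensor)
//	})
//	if err != nil { /* handle err */ }
//	if retValErr != nil { /* handle retValErr.(error) */ }

// After this commit: the closure simply runs with gradients disabled.
ts.NoGrad(func() {
	ts.Copy_(currTs, namedTs.Tensor)
})
```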

@@ -48,8 +48,27 @@
- Other examples can be found in the `example` folder
### 3. Notes on running examples
- Clean the Cgo cache to get a fresh build, as [mentioned here](https://github.com/golang/go/issues/24355):
```bash
# Either
go clean -cache -testcache .
go run [EXAMPLE FILES]
# Or
go build -a
go run [EXAMPLE FILES]
```
## Acknowledgement
- This project is inspired by, and borrows many concepts from, [tch-rs](https://github.com/LaurentMazare/tch-rs), a Rust binding for Libtorch.

@@ -13,10 +13,8 @@ const (
Label int64 = 10
MnistDir string = "../../data/mnist"
// epochs = 500
// batchSize = 256
epochs = 200
batchSize = 60000
batchSize = 256
)
func runLinear() {
@@ -44,6 +42,7 @@ func runLinear() {
*
* batches := samples / batchSize
* batchIndex := 0
* var loss ts.Tensor
* for i := 0; i < batches; i++ {
* start := batchIndex * batchSize
* size := batchSize
@@ -55,20 +54,20 @@ func runLinear() {
*
* // Indexing
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
* // bImages := ds.TrainImages.Idx(narrowIndex)
* // bLabels := ds.TrainLabels.Idx(narrowIndex)
* bImages := imagesTs.Idx(narrowIndex)
* bLabels := labelsTs.Idx(narrowIndex)
* bImages := ds.TrainImages.Idx(narrowIndex)
* bLabels := ds.TrainLabels.Idx(narrowIndex)
* // bImages := imagesTs.Idx(narrowIndex)
* // bLabels := labelsTs.Idx(narrowIndex)
*
* logits := bImages.MustMm(ws).MustAdd(bs)
* // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
* loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true)
*
* ws.ZeroGrad()
* bs.ZeroGrad()
* loss.Backward()
* loss.MustBackward()
*
* bs.MustGrad().Print()
* // TODO: why does `loss` need to be printed out to get updated?
* fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
*
* ts.NoGrad(func() {
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
@@ -81,31 +80,21 @@ func runLinear() {
* */
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
ws.ZeroGrad()
bs.ZeroGrad()
// loss.MustBackward()
loss.Backward()
// TODO: why does `loss` need to be printed out to get updated?
fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
loss.MustBackward()
ts.NoGrad(func() {
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})
// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
//
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
}
}
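Since the removed and added lines are interleaved above, here is a sketch of how the per-epoch step in `runLinear` reads after this commit. It is reconstructed from the hunk and assumes `ds`, `ws`, `bs`, `dtype`, and `epoch` as defined elsewhere in the example.

```go
// Forward pass and NLL loss; requires_grad is forced on the loss as part
// of the workaround this commit introduces.
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)

ws.ZeroGrad()
bs.ZeroGrad()
loss.MustBackward()

// Manual gradient-descent update with gradient tracking switched off.
ts.NoGrad(func() {
	ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
	bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})

// Evaluate on the test set and report.
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
```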

@@ -160,15 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) {
return err
}
retValErr, err := ts.NoGrad(func() {
ts.NoGrad(func() {
ts.Copy_(currTs, namedTs.Tensor)
})
if err != nil {
return err
}
if retValErr != nil {
return retValErr.(error)
}
}
return nil
@@ -207,15 +201,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
}
// It's matched. Now, copy in-place the loaded tensor value to var-store
retValErr, err := ts.NoGrad(func() {
ts.NoGrad(func() {
ts.Copy_(currTs, namedTs.Tensor)
})
if err != nil {
return nil, err
}
if retValErr != nil {
return nil, retValErr.(error)
}
}
return missingVariables, nil
@@ -278,15 +266,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
if err != nil {
return err
}
retValErr, err := ts.NoGrad(func() {
ts.NoGrad(func() {
ts.Copy_(v, srcDevTs)
})
if err != nil {
return err
}
if retValErr != nil {
return retValErr.(error)
}
}
return nil
@@ -548,15 +530,9 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
}
v := p.Zeros(name, size)
retValErr, err := ts.NoGrad(func() {
ts.NoGrad(func() {
ts.Copy_(v, t)
})
if err != nil {
log.Fatal(err)
}
if retValErr != nil {
log.Fatal(retValErr)
}
return v
}
@@ -615,15 +591,9 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
}
v := e.OrZeros(size)
retValErr, err := ts.NoGrad(func() {
ts.NoGrad(func() {
ts.Copy_(v, tensor)
})
if err != nil {
log.Fatal(err)
}
if retValErr != nil {
log.Fatal(retValErr)
}
return v
}

@@ -376,6 +376,15 @@ func (ts Tensor) RequiresGrad() (retVal bool, err error) {
return retVal, nil
}
func (ts Tensor) MustRequiresGrad() (retVal bool) {
retVal, err := ts.RequiresGrad()
if err != nil {
log.Fatal(err)
}
return retVal
}
// DataPtr returns the address of the first element of this tensor.
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
@@ -899,7 +908,15 @@ func MustGradSetEnabled(b bool) (retVal bool) {
}
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn interface{}) (retVal interface{}, err error) {
func NoGrad(fn interface{}) {
// TODO: This is weird, but somehow we need to trigger a call into C++
// (here by creating and then dropping a tensor) to get the loss
// function updated. It is probably related to C++ cache clearing.
// The next step would be a Go func that triggers the C++ cache clean
// directly, instead of this ugly hack.
newTs := NewTensor()
newTs.Drop()
// Switch off Grad
prev := MustGradSetEnabled(false)
@@ -907,16 +924,15 @@
// Analyze the input as a function. If it is not one, fail with an error.
f, err := NewFunc(fn)
if err != nil {
return retVal, nil
log.Fatal(err)
}
// invokes the function
retVal = f.Invoke()
f.Invoke()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal, nil
}
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.