fix(tensor/NoGrad): added a hacky way to get loss func updated
This commit is contained in:
parent
613cd93443
commit
20f338014d
19
README.md
19
README.md
|
@ -48,8 +48,27 @@
|
|||
- Other examples can be found in the `example` folder
|
||||
|
||||
|
||||
### 3. Notes on running examples
|
||||
|
||||
- Clean the Cgo cache to get a fresh build, as [mentioned here](https://github.com/golang/go/issues/24355)
|
||||
|
||||
```bash
|
||||
# Either
|
||||
go clean -cache -testcache .
|
||||
go run [EXAMPLE FILES]
|
||||
|
||||
# Or
|
||||
go build -a
|
||||
|
||||
go run [EXAMPLE FILES]
|
||||
|
||||
```
|
||||
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
- This project has been inspired by, and uses many concepts from, the [tch-rs](https://github.com/LaurentMazare/tch-rs)
|
||||
Libtorch Rust binding.
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -13,10 +13,8 @@ const (
|
|||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
// epochs = 500
|
||||
// batchSize = 256
|
||||
epochs = 200
|
||||
batchSize = 60000
|
||||
batchSize = 256
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
|
@ -44,6 +42,7 @@ func runLinear() {
|
|||
*
|
||||
* batches := samples / batchSize
|
||||
* batchIndex := 0
|
||||
* var loss ts.Tensor
|
||||
* for i := 0; i < batches; i++ {
|
||||
* start := batchIndex * batchSize
|
||||
* size := batchSize
|
||||
|
@ -55,20 +54,20 @@ func runLinear() {
|
|||
*
|
||||
* // Indexing
|
||||
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||
* // bImages := ds.TrainImages.Idx(narrowIndex)
|
||||
* // bLabels := ds.TrainLabels.Idx(narrowIndex)
|
||||
* bImages := imagesTs.Idx(narrowIndex)
|
||||
* bLabels := labelsTs.Idx(narrowIndex)
|
||||
* bImages := ds.TrainImages.Idx(narrowIndex)
|
||||
* bLabels := ds.TrainLabels.Idx(narrowIndex)
|
||||
* // bImages := imagesTs.Idx(narrowIndex)
|
||||
* // bLabels := labelsTs.Idx(narrowIndex)
|
||||
*
|
||||
* logits := bImages.MustMm(ws).MustAdd(bs)
|
||||
* // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
||||
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
||||
* loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true)
|
||||
*
|
||||
* ws.ZeroGrad()
|
||||
* bs.ZeroGrad()
|
||||
* loss.Backward()
|
||||
* loss.MustBackward()
|
||||
*
|
||||
* bs.MustGrad().Print()
|
||||
* // TODO: why `loss` need to print out to get updated?
|
||||
* fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
|
||||
*
|
||||
* ts.NoGrad(func() {
|
||||
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
|
@ -81,31 +80,21 @@ func runLinear() {
|
|||
* */
|
||||
|
||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
// loss.MustBackward()
|
||||
loss.Backward()
|
||||
|
||||
// TODO: why does `loss` need to be printed out to get updated?
|
||||
fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
|
||||
// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
|
||||
loss.MustBackward()
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
|
||||
// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
|
||||
// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))
|
||||
|
||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
//
|
||||
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
|
||||
|
||||
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -160,15 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.NoGrad(func() {
|
||||
ts.Copy_(currTs, namedTs.Tensor)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if retValErr != nil {
|
||||
return retValErr.(error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -207,15 +201,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
|||
}
|
||||
|
||||
// It's matched. Now, copy in-place the loaded tensor value to var-store
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.NoGrad(func() {
|
||||
ts.Copy_(currTs, namedTs.Tensor)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if retValErr != nil {
|
||||
return nil, retValErr.(error)
|
||||
}
|
||||
}
|
||||
|
||||
return missingVariables, nil
|
||||
|
@ -278,15 +266,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.NoGrad(func() {
|
||||
ts.Copy_(v, srcDevTs)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if retValErr != nil {
|
||||
return retValErr.(error)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -548,15 +530,9 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
|||
}
|
||||
v := p.Zeros(name, size)
|
||||
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.NoGrad(func() {
|
||||
ts.Copy_(v, t)
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if retValErr != nil {
|
||||
log.Fatal(retValErr)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
@ -615,15 +591,9 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
|
|||
}
|
||||
v := e.OrZeros(size)
|
||||
|
||||
retValErr, err := ts.NoGrad(func() {
|
||||
ts.NoGrad(func() {
|
||||
ts.Copy_(v, tensor)
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if retValErr != nil {
|
||||
log.Fatal(retValErr)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
|
|
@ -376,6 +376,15 @@ func (ts Tensor) RequiresGrad() (retVal bool, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustRequiresGrad() (retVal bool) {
|
||||
retVal, err := ts.RequiresGrad()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// DataPtr returns the address of the first element of this tensor.
|
||||
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
|
||||
|
||||
|
@ -899,7 +908,15 @@ func MustGradSetEnabled(b bool) (retVal bool) {
|
|||
}
|
||||
|
||||
// NoGrad runs a closure without keeping track of gradients.
|
||||
func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
||||
func NoGrad(fn interface{}) {
|
||||
|
||||
// TODO: This is weird but somehow we need to trigger C++ print
|
||||
// to get loss function updated. Probably it is related to
|
||||
// C++ cache clearing.
|
||||
// The next step would be creating a Go func that triggers a C++ cache clean
|
||||
// instead of this ugly hacky way.
|
||||
newTs := NewTensor()
|
||||
newTs.Drop()
|
||||
|
||||
// Switch off Grad
|
||||
prev := MustGradSetEnabled(false)
|
||||
|
@ -907,16 +924,15 @@ func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
|||
// Analyze input as function. If not, throw error
|
||||
f, err := NewFunc(fn)
|
||||
if err != nil {
|
||||
return retVal, nil
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// invokes the function
|
||||
retVal = f.Invoke()
|
||||
f.Invoke()
|
||||
|
||||
// Switch on Grad
|
||||
_ = MustGradSetEnabled(prev)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
|
||||
|
|
Loading…
Reference in New Issue
Block a user