fix(tensor/NoGrad): added a hacky way to get loss func updated
This commit is contained in:
parent
613cd93443
commit
20f338014d
19
README.md
19
README.md
|
@ -48,8 +48,27 @@
|
||||||
- Other examples can be found at `example` folder
|
- Other examples can be found at `example` folder
|
||||||
|
|
||||||
|
|
||||||
|
### 3. Notes on running examples
|
||||||
|
|
||||||
|
- Clean the Cgo cache to get a fresh build, as [mentioned here](https://github.com/golang/go/issues/24355)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Either
|
||||||
|
go clean -cache -testcache .
|
||||||
|
go run [EXAMPLE FILES]
|
||||||
|
|
||||||
|
# Or
|
||||||
|
go build -a
|
||||||
|
|
||||||
|
go run [EXAMPLE FILES]
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|
||||||
- This project was inspired by, and uses many concepts from, [tch-rs](https://github.com/LaurentMazare/tch-rs)
|
- This project was inspired by, and uses many concepts from, [tch-rs](https://github.com/LaurentMazare/tch-rs)
|
||||||
Libtorch Rust binding.
|
Libtorch Rust binding.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,10 +13,8 @@ const (
|
||||||
Label int64 = 10
|
Label int64 = 10
|
||||||
MnistDir string = "../../data/mnist"
|
MnistDir string = "../../data/mnist"
|
||||||
|
|
||||||
// epochs = 500
|
|
||||||
// batchSize = 256
|
|
||||||
epochs = 200
|
epochs = 200
|
||||||
batchSize = 60000
|
batchSize = 256
|
||||||
)
|
)
|
||||||
|
|
||||||
func runLinear() {
|
func runLinear() {
|
||||||
|
@ -44,6 +42,7 @@ func runLinear() {
|
||||||
*
|
*
|
||||||
* batches := samples / batchSize
|
* batches := samples / batchSize
|
||||||
* batchIndex := 0
|
* batchIndex := 0
|
||||||
|
* var loss ts.Tensor
|
||||||
* for i := 0; i < batches; i++ {
|
* for i := 0; i < batches; i++ {
|
||||||
* start := batchIndex * batchSize
|
* start := batchIndex * batchSize
|
||||||
* size := batchSize
|
* size := batchSize
|
||||||
|
@ -55,20 +54,20 @@ func runLinear() {
|
||||||
*
|
*
|
||||||
* // Indexing
|
* // Indexing
|
||||||
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
* narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||||
* // bImages := ds.TrainImages.Idx(narrowIndex)
|
* bImages := ds.TrainImages.Idx(narrowIndex)
|
||||||
* // bLabels := ds.TrainLabels.Idx(narrowIndex)
|
* bLabels := ds.TrainLabels.Idx(narrowIndex)
|
||||||
* bImages := imagesTs.Idx(narrowIndex)
|
* // bImages := imagesTs.Idx(narrowIndex)
|
||||||
* bLabels := labelsTs.Idx(narrowIndex)
|
* // bLabels := labelsTs.Idx(narrowIndex)
|
||||||
*
|
*
|
||||||
* logits := bImages.MustMm(ws).MustAdd(bs)
|
* logits := bImages.MustMm(ws).MustAdd(bs)
|
||||||
* // loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
* loss = logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels).MustSetRequiresGrad(true)
|
||||||
* loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(bLabels)
|
|
||||||
*
|
*
|
||||||
* ws.ZeroGrad()
|
* ws.ZeroGrad()
|
||||||
* bs.ZeroGrad()
|
* bs.ZeroGrad()
|
||||||
* loss.Backward()
|
* loss.MustBackward()
|
||||||
*
|
*
|
||||||
* bs.MustGrad().Print()
|
* // TODO: why `loss` need to print out to get updated?
|
||||||
|
* fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
|
||||||
*
|
*
|
||||||
* ts.NoGrad(func() {
|
* ts.NoGrad(func() {
|
||||||
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
* ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
|
@ -81,31 +80,21 @@ func runLinear() {
|
||||||
* */
|
* */
|
||||||
|
|
||||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||||
// loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
||||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
|
||||||
// loss := ds.TrainImages.MustMm(ws).MustAdd(bs).MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels).MustSetRequiresGrad(true)
|
|
||||||
|
|
||||||
ws.ZeroGrad()
|
ws.ZeroGrad()
|
||||||
bs.ZeroGrad()
|
bs.ZeroGrad()
|
||||||
// loss.MustBackward()
|
loss.MustBackward()
|
||||||
loss.Backward()
|
|
||||||
|
|
||||||
// TODO: why `loss` need to print out to get updated?
|
|
||||||
fmt.Printf("loss (epoch %v): %v\n", epoch, loss.MustToString(0))
|
|
||||||
// fmt.Printf("bs grad (epoch %v): %v\n", epoch, bs.MustGrad().MustToString(1))
|
|
||||||
|
|
||||||
ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||||
})
|
})
|
||||||
|
|
||||||
// fmt.Printf("bs(epoch %v): \n%v\n", epoch, bs.MustToString(1))
|
|
||||||
// fmt.Printf("ws mean(epoch %v): \n%v\n", epoch, ws.MustMean(gotch.Float.CInt()).MustToString(1))
|
|
||||||
|
|
||||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||||
// testAccuracy := ds.TestImages.MustMm(ws).MustAdd(bs).MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
|
||||||
//
|
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||||
fmt.Printf("Epoch: %v - Test accuracy: %v\n", epoch, testAccuracy*100)
|
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -160,15 +160,9 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
retValErr, err := ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ts.Copy_(currTs, namedTs.Tensor)
|
ts.Copy_(currTs, namedTs.Tensor)
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if retValErr != nil {
|
|
||||||
return retValErr.(error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -207,15 +201,9 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// It's matched. Now, copy in-place the loaded tensor value to var-store
|
// It's matched. Now, copy in-place the loaded tensor value to var-store
|
||||||
retValErr, err := ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ts.Copy_(currTs, namedTs.Tensor)
|
ts.Copy_(currTs, namedTs.Tensor)
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if retValErr != nil {
|
|
||||||
return nil, retValErr.(error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return missingVariables, nil
|
return missingVariables, nil
|
||||||
|
@ -278,15 +266,9 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
retValErr, err := ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ts.Copy_(v, srcDevTs)
|
ts.Copy_(v, srcDevTs)
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if retValErr != nil {
|
|
||||||
return retValErr.(error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -548,15 +530,9 @@ func (p *Path) VarCopy(name string, t ts.Tensor) (retVal ts.Tensor) {
|
||||||
}
|
}
|
||||||
v := p.Zeros(name, size)
|
v := p.Zeros(name, size)
|
||||||
|
|
||||||
retValErr, err := ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ts.Copy_(v, t)
|
ts.Copy_(v, t)
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
if retValErr != nil {
|
|
||||||
log.Fatal(retValErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
@ -615,15 +591,9 @@ func (e *Entry) OrVarCopy(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||||
}
|
}
|
||||||
v := e.OrZeros(size)
|
v := e.OrZeros(size)
|
||||||
|
|
||||||
retValErr, err := ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
ts.Copy_(v, tensor)
|
ts.Copy_(v, tensor)
|
||||||
})
|
})
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
if retValErr != nil {
|
|
||||||
log.Fatal(retValErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
|
@ -376,6 +376,15 @@ func (ts Tensor) RequiresGrad() (retVal bool, err error) {
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustRequiresGrad() (retVal bool) {
|
||||||
|
retVal, err := ts.RequiresGrad()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
// DataPtr returns the address of the first element of this tensor.
|
// DataPtr returns the address of the first element of this tensor.
|
||||||
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
|
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
|
||||||
|
|
||||||
|
@ -899,7 +908,15 @@ func MustGradSetEnabled(b bool) (retVal bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// NoGrad runs a closure without keeping track of gradients.
|
// NoGrad runs a closure without keeping track of gradients.
|
||||||
func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
func NoGrad(fn interface{}) {
|
||||||
|
|
||||||
|
// TODO: This is weird but somehow we need to trigger C++ print
|
||||||
|
// to get loss function updated. Probably it is related to
|
||||||
|
// C++ cache clearing.
|
||||||
|
// Next step would be creating a Go func that trigger C++ cache clean
|
||||||
|
// instead of this ugly hacky way.
|
||||||
|
newTs := NewTensor()
|
||||||
|
newTs.Drop()
|
||||||
|
|
||||||
// Switch off Grad
|
// Switch off Grad
|
||||||
prev := MustGradSetEnabled(false)
|
prev := MustGradSetEnabled(false)
|
||||||
|
@ -907,16 +924,15 @@ func NoGrad(fn interface{}) (retVal interface{}, err error) {
|
||||||
// Analyze input as function. If not, throw error
|
// Analyze input as function. If not, throw error
|
||||||
f, err := NewFunc(fn)
|
f, err := NewFunc(fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return retVal, nil
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// invokes the function
|
// invokes the function
|
||||||
retVal = f.Invoke()
|
f.Invoke()
|
||||||
|
|
||||||
// Switch on Grad
|
// Switch on Grad
|
||||||
_ = MustGradSetEnabled(prev)
|
_ = MustGradSetEnabled(prev)
|
||||||
|
|
||||||
return retVal, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
|
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user