added example/jit-train with README.md and updated code

sugarme 2021-01-02 20:34:05 +11:00
parent d6fb8d88d8
commit 0c9ab73736
4 changed files with 266 additions and 18 deletions

example/jit-train/README.md Normal file

@@ -0,0 +1,223 @@
# Load and Train a PyTorch Model in Go

This example demonstrates how to load a PyTorch model exported as TorchScript and then train it in Go.

- Step 1: Convert the Python PyTorch model to TorchScript. Details can be found in the [PyTorch tutorial](https://pytorch.org/tutorials/advanced/cpp_export.html). Below is an example MNIST model, scripted with `torch.jit.script` (for a static model like this, tracing with `torch.jit.trace` would work as well).
```python
import torch
from torch.nn import Module
import torch.nn.functional as F


class MNISTModule(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(5, 5))
        self.maxpool1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(5, 5))
        self.maxpool2 = torch.nn.MaxPool2d(2)
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.dropout = torch.nn.Dropout(0.5)
        self.linear2 = torch.nn.Linear(1024, 10)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x).view(-1, 1024)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout(x)
        return self.linear2(x)


script_module = torch.jit.script(MNISTModule())
script_module.save("model.pt")
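
# Optional sanity check (an addition, not part of the original example):
# reload the scripted module and push a dummy batch through it.
loaded = torch.jit.load("model.pt")
print(loaded(torch.rand(2, 1, 28, 28)).shape)  # expect: torch.Size([2, 10])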
```
- Step 2: Load the TorchScript model and continue training/fine-tuning it in Go. After training, the model can be saved in TorchScript format again, so it can be loaded in Go, Python, or any other language with PyTorch bindings.
```go
func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
	file := "./model.pt"
	vs := nn.NewVarStore(device)

	trainable, err := nn.TrainableCModuleLoad(vs.Root(), file)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Trainable JIT model loaded.\n")

	// List the trainable parameters of the loaded module.
	namedTensors, err := trainable.Inner.NamedParameters()
	if err != nil {
		log.Fatal(err)
	}
	for _, x := range namedTensors {
		fmt.Println(x.Name)
	}

	trainable.SetTrain()
	bestAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
	fmt.Printf("Initial Accuracy: %0.4f\n", bestAccuracy)

	opt, err := nn.DefaultAdamConfig().Build(vs, 1e-4)
	if err != nil {
		log.Fatal(err)
	}

	// epochs and batchSize are package-level variables set via command-line flags.
	for epoch := 0; epoch < epochs; epoch++ {
		totalSize := ds.TrainImages.MustSize()[0]
		samples := int(totalSize)

		// Shuffle the training set for this epoch.
		index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
		imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
		labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)

		batches := samples / batchSize
		batchIndex := 0
		var epocLoss *ts.Tensor
		for i := 0; i < batches; i++ {
			start := batchIndex * batchSize
			size := batchSize
			if samples-start < batchSize {
				break
			}
			batchIndex += 1

			// Slice out the current mini-batch and move it to the training device.
			narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
			bImages := imagesTs.Idx(narrowIndex)
			bLabels := labelsTs.Idx(narrowIndex)
			bImages = bImages.MustTo(vs.Device(), true)
			bLabels = bLabels.MustTo(vs.Device(), true)

			logits := trainable.ForwardT(bImages, true)
			loss := logits.CrossEntropyForLogits(bLabels)
			opt.BackwardStep(loss)

			epocLoss = loss.MustShallowClone()
			epocLoss.Detach_()

			bImages.MustDrop()
			bLabels.MustDrop()
		}

		testAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
		fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
		if testAccuracy > bestAccuracy {
			bestAccuracy = testAccuracy
		}

		epocLoss.MustDrop()
		imagesTs.MustDrop()
		labelsTs.MustDrop()
	}

	// Save the fine-tuned model back to TorchScript format.
	err = trainable.Save("trained-model.pt")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("Completed training. Best accuracy: %0.4f\n", bestAccuracy)
}
```
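
A note on memory: gotch tensors wrap C-allocated memory that Go's garbage collector does not manage, which is why the loop above frees intermediate tensors explicitly, either via `MustDrop()` or via the final `true` ("delete the input") argument on calls such as `MustTo`.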
- Further step: the trained model can be loaded back and evaluated.
```go
func loadTrainedAndTestAcc(ds *vision.Dataset, device gotch.Device) {
	vs := nn.NewVarStore(device)

	m, err := nn.TrainableCModuleLoad(vs.Root(), "./trained-model.pt")
	if err != nil {
		log.Fatal(err)
	}

	// Switch to evaluation mode (disables dropout).
	m.SetEval()

	acc := nn.BatchAccuracyForLogits(vs, m, ds.TestImages, ds.TestLabels, device, 1024)
	fmt.Printf("Accuracy: %0.4f\n", acc)
}
```
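
As an optional sketch (not part of this example's code), a single test image can be classified with the loaded module. The slicing mirrors the `ts.NewNarrow` indexing used in the training loop above, and the argmax is computed on the Go side to avoid relying on tensor helpers that may vary between gotch versions:

```go
// Sketch: classify the first test image using the module `m` loaded above.
img := ds.TestImages.Idx(ts.NewNarrow(0, 1)).MustTo(device, true)
logits := m.ForwardT(img, false) // train=false: evaluation mode

// Pick the index of the largest logit (the predicted digit).
vals := logits.Float64Values()
best := 0
for i, v := range vals {
	if v > vals[best] {
		best = i
	}
}
fmt.Printf("Predicted digit: %v\n", best)

img.MustDrop()
logits.MustDrop()
```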
See the MNIST example for how to access the MNIST dataset. Below is a sample session showing training and evaluation output:
```bash
go run .
Trainable JIT model loaded.
conv1.weight
conv1.bias
conv2.weight
conv2.bias
linear1.weight
linear1.bias
linear2.weight
linear2.bias
Initial Accuracy: 0.1122
Epoch: 0 Loss: 0.20 Test accuracy: 93.22%
Epoch: 1 Loss: 0.21 Test accuracy: 96.14%
Epoch: 2 Loss: 0.07 Test accuracy: 97.49%
Epoch: 3 Loss: 0.07 Test accuracy: 98.00%
Epoch: 4 Loss: 0.04 Test accuracy: 98.17%
Epoch: 5 Loss: 0.06 Test accuracy: 98.34%
Epoch: 6 Loss: 0.03 Test accuracy: 98.59%
Epoch: 7 Loss: 0.08 Test accuracy: 98.62%
Epoch: 8 Loss: 0.01 Test accuracy: 98.54%
Epoch: 9 Loss: 0.08 Test accuracy: 98.75%
Epoch: 10 Loss: 0.07 Test accuracy: 98.88%
Epoch: 11 Loss: 0.05 Test accuracy: 98.74%
Epoch: 12 Loss: 0.03 Test accuracy: 98.80%
Epoch: 13 Loss: 0.02 Test accuracy: 98.91%
Epoch: 14 Loss: 0.02 Test accuracy: 98.99%
Epoch: 15 Loss: 0.01 Test accuracy: 98.90%
Epoch: 16 Loss: 0.02 Test accuracy: 98.90%
Epoch: 17 Loss: 0.02 Test accuracy: 98.87%
Epoch: 18 Loss: 0.05 Test accuracy: 99.00%
Epoch: 19 Loss: 0.03 Test accuracy: 98.96%
Epoch: 20 Loss: 0.01 Test accuracy: 98.98%
Epoch: 21 Loss: 0.03 Test accuracy: 99.02%
Epoch: 22 Loss: 0.02 Test accuracy: 98.95%
Epoch: 23 Loss: 0.02 Test accuracy: 98.99%
Epoch: 24 Loss: 0.02 Test accuracy: 98.96%
Epoch: 25 Loss: 0.01 Test accuracy: 99.15%
Epoch: 26 Loss: 0.01 Test accuracy: 98.97%
Epoch: 27 Loss: 0.01 Test accuracy: 99.03%
Epoch: 28 Loss: 0.03 Test accuracy: 99.09%
Epoch: 29 Loss: 0.01 Test accuracy: 99.05%
Epoch: 30 Loss: 0.00 Test accuracy: 98.97%
Epoch: 31 Loss: 0.00 Test accuracy: 99.01%
Epoch: 32 Loss: 0.00 Test accuracy: 99.08%
Epoch: 33 Loss: 0.00 Test accuracy: 98.93%
Epoch: 34 Loss: 0.01 Test accuracy: 98.86%
Epoch: 35 Loss: 0.00 Test accuracy: 98.94%
Epoch: 36 Loss: 0.01 Test accuracy: 98.96%
Epoch: 37 Loss: 0.00 Test accuracy: 99.01%
Epoch: 38 Loss: 0.00 Test accuracy: 99.03%
Epoch: 39 Loss: 0.00 Test accuracy: 99.14%
Epoch: 40 Loss: 0.00 Test accuracy: 99.06%
Epoch: 41 Loss: 0.00 Test accuracy: 99.01%
Epoch: 42 Loss: 0.00 Test accuracy: 99.01%
Epoch: 43 Loss: 0.01 Test accuracy: 99.01%
Epoch: 44 Loss: 0.00 Test accuracy: 98.98%
Epoch: 45 Loss: 0.02 Test accuracy: 99.03%
Epoch: 46 Loss: 0.00 Test accuracy: 99.14%
Epoch: 47 Loss: 0.00 Test accuracy: 99.11%
Epoch: 48 Loss: 0.00 Test accuracy: 98.84%
Epoch: 49 Loss: 0.00 Test accuracy: 98.93%
Completed training. Best accuracy: 0.9915
```
```bash
go run . -task=infer
Accuracy: 0.9915
```
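
The example also accepts `-batch`, `-epoch`, and `-cuda` flags (defined in `main.go`) for overriding the defaults, e.g.:

```bash
go run . -task=train -batch=128 -epoch=20 -cuda=false
```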

example/jit-train/main.go

@@ -1,6 +1,7 @@
 package main
 
 import (
+	"flag"
 	"fmt"
 	"log"
@@ -10,18 +11,44 @@ import (
 	"github.com/sugarme/gotch/vision"
 )
 
-func main() {
-	ds := vision.LoadMNISTDir("../../data/mnist")
-	dataset := &vision.Dataset{
-		TestImages:  ds.TestImages.MustView([]int64{-1, 1, 28, 28}, true),
-		TrainImages: ds.TrainImages.MustView([]int64{-1, 1, 28, 28}, true),
-		TestLabels:  ds.TestLabels,
-		TrainLabels: ds.TrainLabels,
-	}
-	device := gotch.CudaIfAvailable()
-
-	// runTrainAndSaveModel(dataset, device)
-	loadTrainedAndTestAcc(dataset, device)
-}
+var (
+	task      string
+	batchSize int
+	epochs    int
+	cuda      bool
+)
+
+func init() {
+	flag.StringVar(&task, "task", "train", "specify task to run, i.e. 'train' or 'infer'")
+	flag.IntVar(&batchSize, "batch", 256, "specify batch size")
+	flag.IntVar(&epochs, "epoch", 50, "specify number of epochs to train")
+	flag.BoolVar(&cuda, "cuda", true, "specify whether to use CUDA (default) or CPU")
+}
+
+func main() {
+	flag.Parse()
+
+	ds := vision.LoadMNISTDir("../../data/mnist")
+
+	var device gotch.Device = gotch.CPU
+	if cuda {
+		device = gotch.CudaIfAvailable()
+	}
+
+	switch task {
+	case "train":
+		runTrainAndSaveModel(ds, device)
+	case "infer":
+		loadTrainedAndTestAcc(ds, device)
+	default:
+		log.Fatalf("Invalid task: %v. Task can be 'train' or 'infer' only.", task)
+	}
 }
 
 func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
@@ -44,16 +71,14 @@ func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
 	}
 	trainable.SetTrain()
 
-	initialAcc := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
-	fmt.Printf("Initial Accuracy: %0.4f\n", initialAcc)
-	bestAccuracy := initialAcc
+	bestAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
+	fmt.Printf("Initial Accuracy: %0.4f\n", bestAccuracy)
 
 	opt, err := nn.DefaultAdamConfig().Build(vs, 1e-4)
 	if err != nil {
 		log.Fatal(err)
 	}
 
-	batchSize := 128
-	for epoch := 0; epoch < 20; epoch++ {
+	for epoch := 0; epoch < epochs; epoch++ {
 		totalSize := ds.TrainImages.MustSize()[0]
 		samples := int(totalSize)
@@ -88,8 +113,6 @@ func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
 			epocLoss = loss.MustShallowClone()
 			epocLoss.Detach_()
 
-			// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])
-
 			bImages.MustDrop()
 			bLabels.MustDrop()
 		}
@@ -109,6 +132,8 @@ func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
 	if err != nil {
 		log.Fatal(err)
 	}
+
+	fmt.Printf("Completed training. Best accuracy: %0.4f\n", bestAccuracy)
 }
 
 func loadTrainedAndTestAcc(ds *vision.Dataset, device gotch.Device) {

Binary file not shown.

Binary file not shown.