added example/jit-train with README.md and updated code
This commit is contained in:
parent
d6fb8d88d8
commit
0c9ab73736
223
example/jit-train/README.md
Normal file
223
example/jit-train/README.md
Normal file
|
@ -0,0 +1,223 @@
|
|||
# Load and Train a PyTorch Model in Go
|
||||
|
||||
This example demonstrates how to load a Python PyTorch model using TorchScript, then train the model in Go.
|
||||
|
||||
- Step 1: convert the Python PyTorch model to TorchScript. Details can be found in the [PyTorch tutorial](https://pytorch.org/tutorials/advanced/cpp_export.html). Below is an example of an MNIST model.
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
import torch.nn.functional as F
|
||||
|
||||
class MNISTModule(Module):
    """Simple CNN for MNIST: 28x28 grayscale images -> 10 class logits."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(5, 5))
        self.maxpool1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(5, 5))
        self.maxpool2 = torch.nn.MaxPool2d(2)
        # 64 channels * 4 * 4 spatial after two conv/pool stages = 1024.
        self.linear1 = torch.nn.Linear(1024, 1024)
        self.dropout = torch.nn.Dropout(0.5)
        self.linear2 = torch.nn.Linear(1024, 10)

    def forward(self, x):
        # Accept flat (N, 784) or image-shaped input; normalize to NCHW.
        x = x.view(-1, 1, 28, 28)
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x).view(-1, 1024)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.dropout(x)
        # BUG FIX: the original applied self.linear1 a second time here,
        # stacking two linear layers with no nonlinearity in between
        # (which collapses to a single linear map). Removed.
        return self.linear2(x)
|
||||
|
||||
traced_script_module = torch.jit.script(MNISTModule())
|
||||
traced_script_module.save("model.pt")
|
||||
|
||||
```
|
||||
|
||||
- Step 2: load the TorchScript model and continue training/fine-tuning in Go. After training, the model can be saved in TorchScript format so that it can be loaded in Go, Python, or any other supported PyTorch binding language.
|
||||
|
||||
```go
|
||||
func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
|
||||
|
||||
file := "./model.pt"
|
||||
vs := nn.NewVarStore(device)
|
||||
trainable, err := nn.TrainableCModuleLoad(vs.Root(), file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Trainable JIT model loaded.\n")
|
||||
|
||||
namedTensors, err := trainable.Inner.NamedParameters()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, x := range namedTensors {
|
||||
fmt.Println(x.Name)
|
||||
}
|
||||
|
||||
trainable.SetTrain()
|
||||
bestAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
|
||||
fmt.Printf("Initial Accuracy: %0.4f\n", bestAccuracy)
|
||||
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, 1e-4)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
|
||||
totalSize := ds.TrainImages.MustSize()[0]
|
||||
samples := int(totalSize)
|
||||
index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)
|
||||
imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)
|
||||
labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)
|
||||
|
||||
batches := samples / batchSize
|
||||
batchIndex := 0
|
||||
var epocLoss *ts.Tensor
|
||||
for i := 0; i < batches; i++ {
|
||||
start := batchIndex * batchSize
|
||||
size := batchSize
|
||||
if samples-start < batchSize {
|
||||
break
|
||||
}
|
||||
batchIndex += 1
|
||||
|
||||
// Indexing
|
||||
narrowIndex := ts.NewNarrow(int64(start), int64(start+size))
|
||||
bImages := imagesTs.Idx(narrowIndex)
|
||||
bLabels := labelsTs.Idx(narrowIndex)
|
||||
|
||||
bImages = bImages.MustTo(vs.Device(), true)
|
||||
bLabels = bLabels.MustTo(vs.Device(), true)
|
||||
|
||||
logits := trainable.ForwardT(bImages, true)
|
||||
loss := logits.CrossEntropyForLogits(bLabels)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
epocLoss = loss.MustShallowClone()
|
||||
epocLoss.Detach_()
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
}
|
||||
|
||||
testAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, vs.Device(), 1024)
|
||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)
|
||||
if testAccuracy > bestAccuracy {
|
||||
bestAccuracy = testAccuracy
|
||||
}
|
||||
|
||||
epocLoss.MustDrop()
|
||||
imagesTs.MustDrop()
|
||||
labelsTs.MustDrop()
|
||||
}
|
||||
|
||||
err = trainable.Save("trained-model.pt")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Completed training. Best accuracy: %0.4f\n", bestAccuracy)
|
||||
}
|
||||
```
|
||||
|
||||
- Further step: the trained model can be loaded and evaluated.
|
||||
|
||||
```go
|
||||
func loadTrainedAndTestAcc(ds *vision.Dataset, device gotch.Device) {
|
||||
vs := nn.NewVarStore(device)
|
||||
m, err := nn.TrainableCModuleLoad(vs.Root(), "./trained-model.pt")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
m.SetEval()
|
||||
acc := nn.BatchAccuracyForLogits(vs, m, ds.TestImages, ds.TestLabels, device, 1024)
|
||||
|
||||
fmt.Printf("Accuracy: %0.4f\n", acc)
|
||||
}
|
||||
```
|
||||
|
||||
See the MNIST example for how to access the MNIST dataset.
|
||||
|
||||
Below is a sample session showing training and evaluation output:
|
||||
|
||||
```bash
|
||||
go run .
|
||||
Trainable JIT model loaded.
|
||||
conv1.weight
|
||||
conv1.bias
|
||||
conv2.weight
|
||||
conv2.bias
|
||||
linear1.weight
|
||||
linear1.bias
|
||||
linear2.weight
|
||||
linear2.bias
|
||||
Initial Accuracy: 0.1122
|
||||
Epoch: 0 Loss: 0.20 Test accuracy: 93.22%
|
||||
Epoch: 1 Loss: 0.21 Test accuracy: 96.14%
|
||||
Epoch: 2 Loss: 0.07 Test accuracy: 97.49%
|
||||
Epoch: 3 Loss: 0.07 Test accuracy: 98.00%
|
||||
Epoch: 4 Loss: 0.04 Test accuracy: 98.17%
|
||||
Epoch: 5 Loss: 0.06 Test accuracy: 98.34%
|
||||
Epoch: 6 Loss: 0.03 Test accuracy: 98.59%
|
||||
Epoch: 7 Loss: 0.08 Test accuracy: 98.62%
|
||||
Epoch: 8 Loss: 0.01 Test accuracy: 98.54%
|
||||
Epoch: 9 Loss: 0.08 Test accuracy: 98.75%
|
||||
Epoch: 10 Loss: 0.07 Test accuracy: 98.88%
|
||||
Epoch: 11 Loss: 0.05 Test accuracy: 98.74%
|
||||
Epoch: 12 Loss: 0.03 Test accuracy: 98.80%
|
||||
Epoch: 13 Loss: 0.02 Test accuracy: 98.91%
|
||||
Epoch: 14 Loss: 0.02 Test accuracy: 98.99%
|
||||
Epoch: 15 Loss: 0.01 Test accuracy: 98.90%
|
||||
Epoch: 16 Loss: 0.02 Test accuracy: 98.90%
|
||||
Epoch: 17 Loss: 0.02 Test accuracy: 98.87%
|
||||
Epoch: 18 Loss: 0.05 Test accuracy: 99.00%
|
||||
Epoch: 19 Loss: 0.03 Test accuracy: 98.96%
|
||||
Epoch: 20 Loss: 0.01 Test accuracy: 98.98%
|
||||
Epoch: 21 Loss: 0.03 Test accuracy: 99.02%
|
||||
Epoch: 22 Loss: 0.02 Test accuracy: 98.95%
|
||||
Epoch: 23 Loss: 0.02 Test accuracy: 98.99%
|
||||
Epoch: 24 Loss: 0.02 Test accuracy: 98.96%
|
||||
Epoch: 25 Loss: 0.01 Test accuracy: 99.15%
|
||||
Epoch: 26 Loss: 0.01 Test accuracy: 98.97%
|
||||
Epoch: 27 Loss: 0.01 Test accuracy: 99.03%
|
||||
Epoch: 28 Loss: 0.03 Test accuracy: 99.09%
|
||||
Epoch: 29 Loss: 0.01 Test accuracy: 99.05%
|
||||
Epoch: 30 Loss: 0.00 Test accuracy: 98.97%
|
||||
Epoch: 31 Loss: 0.00 Test accuracy: 99.01%
|
||||
Epoch: 32 Loss: 0.00 Test accuracy: 99.08%
|
||||
Epoch: 33 Loss: 0.00 Test accuracy: 98.93%
|
||||
Epoch: 34 Loss: 0.01 Test accuracy: 98.86%
|
||||
Epoch: 35 Loss: 0.00 Test accuracy: 98.94%
|
||||
Epoch: 36 Loss: 0.01 Test accuracy: 98.96%
|
||||
Epoch: 37 Loss: 0.00 Test accuracy: 99.01%
|
||||
Epoch: 38 Loss: 0.00 Test accuracy: 99.03%
|
||||
Epoch: 39 Loss: 0.00 Test accuracy: 99.14%
|
||||
Epoch: 40 Loss: 0.00 Test accuracy: 99.06%
|
||||
Epoch: 41 Loss: 0.00 Test accuracy: 99.01%
|
||||
Epoch: 42 Loss: 0.00 Test accuracy: 99.01%
|
||||
Epoch: 43 Loss: 0.01 Test accuracy: 99.01%
|
||||
Epoch: 44 Loss: 0.00 Test accuracy: 98.98%
|
||||
Epoch: 45 Loss: 0.02 Test accuracy: 99.03%
|
||||
Epoch: 46 Loss: 0.00 Test accuracy: 99.14%
|
||||
Epoch: 47 Loss: 0.00 Test accuracy: 99.11%
|
||||
Epoch: 48 Loss: 0.00 Test accuracy: 98.84%
|
||||
Epoch: 49 Loss: 0.00 Test accuracy: 98.93%
|
||||
Completed training. Best accuracy: 0.9915
|
||||
```
|
||||
|
||||
```bash
|
||||
go run . -task=infer
|
||||
Accuracy: 0.9915
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
|
@ -10,18 +11,44 @@ import (
|
|||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
func main() {
|
||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
||||
dataset := &vision.Dataset{
|
||||
TestImages: ds.TestImages.MustView([]int64{-1, 1, 28, 28}, true),
|
||||
TrainImages: ds.TrainImages.MustView([]int64{-1, 1, 28, 28}, true),
|
||||
TestLabels: ds.TestLabels,
|
||||
TrainLabels: ds.TrainLabels,
|
||||
}
|
||||
device := gotch.CudaIfAvailable()
|
||||
// Command-line options shared across tasks; populated by init via the
// flag package and read after flag.Parse() in main.
var (
	task      string // "train" or "infer"
	batchSize int    // mini-batch size for training
	epochs    int    // number of training epochs
	cuda      bool   // whether to use CUDA when available
)
|
||||
|
||||
// runTrainAndSaveModel(dataset, device)
|
||||
loadTrainedAndTestAcc(dataset, device)
|
||||
func init() {
|
||||
flag.StringVar(&task, "task", "train", "specify task to run. Ie. 'train', 'infer'")
|
||||
flag.IntVar(&batchSize, "batch", 256, "Specify batch size.")
|
||||
flag.IntVar(&epochs, "epoch", 50, "Specify number of epochs to train.")
|
||||
flag.BoolVar(&cuda, "cuda", true, "Specify whether using CUDA(default=true) or CPU. ")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
ds := vision.LoadMNISTDir("../../data/mnist")
|
||||
// dataset := &vision.Dataset{
|
||||
// TestImages: ds.TestImages.MustView([]int64{-1, 1, 28, 28}, true),
|
||||
// TrainImages: ds.TrainImages.MustView([]int64{-1, 1, 28, 28}, true),
|
||||
// TestLabels: ds.TestLabels,
|
||||
// TrainLabels: ds.TrainLabels,
|
||||
// }
|
||||
|
||||
var device gotch.Device = gotch.CPU
|
||||
if cuda {
|
||||
device = gotch.CudaIfAvailable()
|
||||
}
|
||||
|
||||
switch task {
|
||||
case "train":
|
||||
runTrainAndSaveModel(ds, device)
|
||||
case "infer":
|
||||
loadTrainedAndTestAcc(ds, device)
|
||||
default:
|
||||
log.Fatalf("Invalid task: %v. Task can be 'train' or 'infer' only. ", task)
|
||||
}
|
||||
}
|
||||
|
||||
func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
|
||||
|
@ -44,16 +71,14 @@ func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
|
|||
}
|
||||
|
||||
trainable.SetTrain()
|
||||
initialAcc := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
|
||||
fmt.Printf("Initial Accuracy: %0.4f\n", initialAcc)
|
||||
bestAccuracy := initialAcc
|
||||
bestAccuracy := nn.BatchAccuracyForLogits(vs, trainable, ds.TestImages, ds.TestLabels, device, 1024)
|
||||
fmt.Printf("Initial Accuracy: %0.4f\n", bestAccuracy)
|
||||
|
||||
opt, err := nn.DefaultAdamConfig().Build(vs, 1e-4)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
batchSize := 128
|
||||
for epoch := 0; epoch < 20; epoch++ {
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
|
||||
totalSize := ds.TrainImages.MustSize()[0]
|
||||
samples := int(totalSize)
|
||||
|
@ -88,8 +113,6 @@ func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
|
|||
epocLoss = loss.MustShallowClone()
|
||||
epocLoss.Detach_()
|
||||
|
||||
// fmt.Printf("completed \t %v batches\t %.2f\n", i, loss.Float64Values()[0])
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
}
|
||||
|
@ -109,6 +132,8 @@ func runTrainAndSaveModel(ds *vision.Dataset, device gotch.Device) {
|
|||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Completed training. Best accuracy: %0.4f\n", bestAccuracy)
|
||||
}
|
||||
|
||||
func loadTrainedAndTestAcc(ds *vision.Dataset, device gotch.Device) {
|
||||
|
|
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user