feat(example/pretrained-models)
This commit is contained in:
parent
1862030e1d
commit
8ba286ef30
123
example/pretrained-models/main.go
Normal file
123
example/pretrained-models/main.go
Normal file
|
@ -0,0 +1,123 @@
|
|||
package main
|
||||
|
||||
// This example illustrates how to use pre-trained vision models.
|
||||
// model to get the imagenet label for some image.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
var (
|
||||
model string
|
||||
image string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&model, "model", "../../data/pretrained/resnet18.pt", "Model weights for inference")
|
||||
flag.StringVar(&image, "image", "../../data/pretrained/koala.jpg", "Image file to get imagenet label")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
imagePath, err := filepath.Abs(image)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
modelPath, err := filepath.Abs(model)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
in := vision.NewImageNet()
|
||||
|
||||
// Load the image file and resize it to the usual imagenet dimension of 224x224.
|
||||
imageTs, err := in.LoadImageAndResize224(imagePath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Create the model and load the weights from the file.
|
||||
_, modelFile := filepath.Split(modelPath)
|
||||
modelName := strings.TrimSuffix(modelFile, filepath.Ext(modelFile))
|
||||
|
||||
// Create the model and load the weights from the file.
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
var net ts.ModuleT
|
||||
switch modelName {
|
||||
case "resnet18":
|
||||
net = vision.ResNet18(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("ResNet18 weights loaded.")
|
||||
case "vgg16":
|
||||
net = vision.VGG16(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("VGG16 weights loaded.")
|
||||
case "alexnet":
|
||||
net = vision.AlexNet(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("AlexNet weights loaded.")
|
||||
case "squeezenet-v1_1":
|
||||
net = vision.SqueezeNetV1_1(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("SqueezeNetV1_1 weights loaded.")
|
||||
case "mobilenet-v2":
|
||||
net = vision.MobileNetV2(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("MobileNetV2 weights loaded.")
|
||||
case "inception-v3":
|
||||
net = vision.InceptionV3(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("InceptionV3 weights loaded.")
|
||||
case "efficient-b4":
|
||||
net = vision.EfficientNetB4(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("EfficientNetB4 weights loaded.")
|
||||
default:
|
||||
log.Fatalf("Invalid model name (%v)\n", modelName)
|
||||
}
|
||||
|
||||
// Apply the forward pass of the model to get the logits.
|
||||
input := imageTs.MustUnsqueeze(0, true)
|
||||
logits := net.ForwardT(input, false)
|
||||
|
||||
// Convert to probability
|
||||
pval := logits.MustSoftmax(-1, gotch.Float, true)
|
||||
|
||||
// Print the top 5 categories for this image.
|
||||
top5 := in.Top(pval, int64(5))
|
||||
|
||||
for _, i := range top5 {
|
||||
fmt.Printf("%-80v %5.2f%%\n", i.Label, i.Pvalue*100)
|
||||
}
|
||||
}
|
|
@ -104,5 +104,7 @@ func AlexNet(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
|||
return res
|
||||
}))
|
||||
|
||||
seq.Add(classifier(p.Sub("classifier"), nclasses))
|
||||
|
||||
return seq
|
||||
}
|
||||
|
|
|
@ -119,7 +119,6 @@ func MobileNetV2(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
|||
|
||||
tmp2 := tmp1.MustMean1([]int64{2}, false, gotch.Float, true)
|
||||
tmp3 := tmp2.MustMean1([]int64{2}, false, gotch.Float, true)
|
||||
tmp2.MustDrop()
|
||||
|
||||
res := tmp3.ApplyT(classifier, train)
|
||||
tmp3.MustDrop()
|
||||
|
|
Loading…
Reference in New Issue
Block a user