WIP(example/transfer-leanring)

This commit is contained in:
sugarme 2020-07-01 17:03:34 +10:00
parent 8dee081115
commit ec5be6716f
2 changed files with 80 additions and 22 deletions

View File

@ -0,0 +1,64 @@
package main
// This example illustrates how to use transfer learning to fine tune a pre-trained
// imagenet model on another dataset.
import (
"flag"
"fmt"
"log"
"path/filepath"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
"github.com/sugarme/gotch/vision"
)
var (
datasetDir string
weights string
)
func init() {
flag.StringVar(&datasetDir, "dataset", "../../data/hymenoptera-data", "full path to dataset directory")
flag.StringVar(&weights, "weights", "../../data/pretrained/resnet18.pt", "resnet18 pretrained weights file")
}
func main() {
flag.Parse()
// Load the dataset and resize it to the usual imagenet dimension of 224x224.
imageNet := vision.NewImageNet()
datasetPath, err := filepath.Abs(datasetDir)
if err != nil {
log.Fatal(err)
}
dataset, err := imageNet.LoadFromDir(datasetPath)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Dataset: %v\n", dataset)
fmt.Printf("Train shape: %v\n", dataset.TrainImages.MustSize())
fmt.Printf("Train shape: %v\n", dataset.TrainLabels.MustSize())
fmt.Printf("Test shape: %v\n", dataset.TestImages.MustSize())
fmt.Printf("Test shape: %v\n", dataset.TestLabels.MustSize())
// Create the model and load the weights from the file.
vs := nn.NewVarStore(gotch.CPU)
net := vision.ResNet18NoFinalLayer(vs.Root())
for k, _ := range vs.Vars.NamedVariables {
fmt.Printf("First variable name: %v\n", k)
}
panic("Stop")
err = vs.Load(weights)
if err != nil {
log.Fatal(err)
}
panic(net)
}

View File

@ -4,7 +4,7 @@ import (
"fmt"
"io/ioutil"
"log"
"os"
// "os"
"path/filepath"
"reflect"
"sync"
@ -142,7 +142,7 @@ func (in ImageNet) hasSuffix(path string) (retVal bool) {
ext := filepath.Ext(path)
switch ext {
case "jpg", "jpeg", "png", "JPG", "JPEG", "PNG":
case ".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG":
return true
default:
return false
@ -152,18 +152,18 @@ func (in ImageNet) hasSuffix(path string) (retVal bool) {
func (in ImageNet) loadImageFromDir(dir string) (retVal ts.Tensor, err error) {
var images []ts.Tensor
base := filepath.Dir(dir)
files, err := ioutil.ReadDir(dir)
if err != nil {
err = fmt.Errorf("ImageNet - loadImageFromDir method call: %v", err)
return retVal, err
}
for _, file := range files {
if !in.hasSuffix(file.Name()) {
continue
}
img, err := in.LoadImageAndResize224(fmt.Sprintf("%v/%v", base, file.Name()))
img, err := in.LoadImageAndResize224(fmt.Sprintf("%v/%v", dir, file.Name()))
if err != nil {
err = fmt.Errorf("ImageNet - loadImageFromDir method call: %v", err)
return retVal, err
@ -188,14 +188,19 @@ func (in ImageNet) loadImageFromDir(dir string) (retVal ts.Tensor, err error) {
// The ImageNet normalization is applied, image are resized to 224x224.
func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
trainPath := fmt.Sprintf("%v/train", path)
validPath := fmt.Sprintf("%v/val", path)
absPath, err := filepath.Abs(path)
if err != nil {
log.Fatal(err)
}
trainPath := fmt.Sprintf("%v/train", absPath)
validPath := fmt.Sprintf("%v/val", absPath)
var classes []string
var subDirs []os.FileInfo
// var subDirs []os.FileInfo
subs, err := ioutil.ReadDir(path)
subs, err := ioutil.ReadDir(validPath)
if err != nil {
err := fmt.Errorf("ImageNet - LoadFromDir method call: %v\n", err)
return retVal, err
@ -205,19 +210,7 @@ func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
if !sub.IsDir() {
continue
}
subDirs = append(subDirs, sub)
}
for _, subDir := range subDirs {
files, err := ioutil.ReadDir(subDir.Name())
if err != nil {
err := fmt.Errorf("ImageNet - LoadFromDir method call: %v\n", err)
return retVal, err
}
for _, file := range files {
classes = append(classes, file.Name())
}
classes = append(classes, sub.Name())
}
fmt.Printf("Classess: %v\n", classes)
@ -236,9 +229,10 @@ func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
trainDir := fmt.Sprintf("%v/%v", trainPath, labelDir)
trainTs, err := in.loadImageFromDir(trainDir)
if err != nil {
err := fmt.Errorf("ImageNet - LoadFromDir method call - Err at classes interating: %v\n", err)
err := fmt.Errorf("ImageNet - LoadFromDir method call - Err at classes iterating: %v\n", err)
return retVal, err
}
ntrainTs := trainTs.MustSize()[0]
trainImages = append(trainImages, trainTs)