WIP(example/transfer-leanring)
This commit is contained in:
parent
8dee081115
commit
ec5be6716f
64
example/transfer-learning/main.go
Normal file
64
example/transfer-learning/main.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package main
|
||||
|
||||
// This example illustrates how to use transfer learning to fine tune a pre-trained
|
||||
// imagenet model on another dataset.
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
var (
|
||||
datasetDir string
|
||||
weights string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&datasetDir, "dataset", "../../data/hymenoptera-data", "full path to dataset directory")
|
||||
flag.StringVar(&weights, "weights", "../../data/pretrained/resnet18.pt", "resnet18 pretrained weights file")
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
// Load the dataset and resize it to the usual imagenet dimension of 224x224.
|
||||
imageNet := vision.NewImageNet()
|
||||
datasetPath, err := filepath.Abs(datasetDir)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
dataset, err := imageNet.LoadFromDir(datasetPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Dataset: %v\n", dataset)
|
||||
fmt.Printf("Train shape: %v\n", dataset.TrainImages.MustSize())
|
||||
fmt.Printf("Train shape: %v\n", dataset.TrainLabels.MustSize())
|
||||
fmt.Printf("Test shape: %v\n", dataset.TestImages.MustSize())
|
||||
fmt.Printf("Test shape: %v\n", dataset.TestLabels.MustSize())
|
||||
|
||||
// Create the model and load the weights from the file.
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
net := vision.ResNet18NoFinalLayer(vs.Root())
|
||||
|
||||
for k, _ := range vs.Vars.NamedVariables {
|
||||
fmt.Printf("First variable name: %v\n", k)
|
||||
}
|
||||
|
||||
panic("Stop")
|
||||
|
||||
err = vs.Load(weights)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
panic(net)
|
||||
|
||||
}
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
// "os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
@ -142,7 +142,7 @@ func (in ImageNet) hasSuffix(path string) (retVal bool) {
|
|||
ext := filepath.Ext(path)
|
||||
|
||||
switch ext {
|
||||
case "jpg", "jpeg", "png", "JPG", "JPEG", "PNG":
|
||||
case ".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
@ -152,18 +152,18 @@ func (in ImageNet) hasSuffix(path string) (retVal bool) {
|
|||
func (in ImageNet) loadImageFromDir(dir string) (retVal ts.Tensor, err error) {
|
||||
var images []ts.Tensor
|
||||
|
||||
base := filepath.Dir(dir)
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ImageNet - loadImageFromDir method call: %v", err)
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if !in.hasSuffix(file.Name()) {
|
||||
continue
|
||||
}
|
||||
|
||||
img, err := in.LoadImageAndResize224(fmt.Sprintf("%v/%v", base, file.Name()))
|
||||
img, err := in.LoadImageAndResize224(fmt.Sprintf("%v/%v", dir, file.Name()))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("ImageNet - loadImageFromDir method call: %v", err)
|
||||
return retVal, err
|
||||
|
@ -188,14 +188,19 @@ func (in ImageNet) loadImageFromDir(dir string) (retVal ts.Tensor, err error) {
|
|||
// The ImageNet normalization is applied, image are resized to 224x224.
|
||||
func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
|
||||
|
||||
trainPath := fmt.Sprintf("%v/train", path)
|
||||
validPath := fmt.Sprintf("%v/val", path)
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
trainPath := fmt.Sprintf("%v/train", absPath)
|
||||
validPath := fmt.Sprintf("%v/val", absPath)
|
||||
|
||||
var classes []string
|
||||
|
||||
var subDirs []os.FileInfo
|
||||
// var subDirs []os.FileInfo
|
||||
|
||||
subs, err := ioutil.ReadDir(path)
|
||||
subs, err := ioutil.ReadDir(validPath)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ImageNet - LoadFromDir method call: %v\n", err)
|
||||
return retVal, err
|
||||
|
@ -205,19 +210,7 @@ func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
|
|||
if !sub.IsDir() {
|
||||
continue
|
||||
}
|
||||
subDirs = append(subDirs, sub)
|
||||
}
|
||||
|
||||
for _, subDir := range subDirs {
|
||||
files, err := ioutil.ReadDir(subDir.Name())
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ImageNet - LoadFromDir method call: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
classes = append(classes, file.Name())
|
||||
}
|
||||
classes = append(classes, sub.Name())
|
||||
}
|
||||
|
||||
fmt.Printf("Classess: %v\n", classes)
|
||||
|
@ -236,9 +229,10 @@ func (in ImageNet) LoadFromDir(path string) (retVal Dataset, err error) {
|
|||
trainDir := fmt.Sprintf("%v/%v", trainPath, labelDir)
|
||||
trainTs, err := in.loadImageFromDir(trainDir)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ImageNet - LoadFromDir method call - Err at classes interating: %v\n", err)
|
||||
err := fmt.Errorf("ImageNet - LoadFromDir method call - Err at classes iterating: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
ntrainTs := trainTs.MustSize()[0]
|
||||
trainImages = append(trainImages, trainTs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user