feat(vision/vgg): added vgg models

sugarme 2020-07-03 09:50:02 +10:00
parent 3c115ee79f
commit 4b91f15865

vision/vgg.go (new file, 149 lines)

@@ -0,0 +1,149 @@
package vision

// VGG models

import (
	"fmt"

	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)
// NOTE: each element of a layer configuration lists the output channel counts
// of the convolutions in one block; each block is followed by a single
// max-pool layer.

// layersA is configuration A of the VGG paper (used by VGG-11).
func layersA() (retVal [][]int64) {
	return [][]int64{
		{64},
		{128},
		{256, 256},
		{512, 512},
		{512, 512},
	}
}

// layersB is configuration B (used by VGG-13).
func layersB() (retVal [][]int64) {
	return [][]int64{
		{64, 64},
		{128, 128},
		{256, 256},
		{512, 512},
		{512, 512},
	}
}

// layersD is configuration D (used by VGG-16).
func layersD() (retVal [][]int64) {
	return [][]int64{
		{64, 64},
		{128, 128},
		{256, 256, 256},
		{512, 512, 512},
		{512, 512, 512},
	}
}

// layersE is configuration E (used by VGG-19).
func layersE() (retVal [][]int64) {
	return [][]int64{
		{64, 64},
		{128, 128},
		{256, 256, 256, 256},
		{512, 512, 512, 512},
		{512, 512, 512, 512},
	}
}
// vggConv2d builds a 3x3 convolution with stride 1 and padding 1, so the
// spatial resolution is preserved.
func vggConv2d(path nn.Path, cIn, cOut int64) (retVal nn.Conv2D) {
	config := nn.DefaultConv2DConfig()
	config.Stride = []int64{1, 1}
	config.Padding = []int64{1, 1}

	return nn.NewConv2D(&path, cIn, cOut, 3, config)
}
// vgg builds the full network: a convolutional feature extractor under the
// "features" sub-path followed by a three-layer classifier under the
// "classifier" sub-path. Sub-modules are named after their index in the
// sequence so that variable names match the layout used by the reference
// torchvision implementation.
func vgg(path nn.Path, config [][]int64, nclasses int64, batchNorm bool) nn.SequentialT {
	c := path.Sub("classifier")
	seq := nn.SeqT()
	f := path.Sub("features")

	var cIn int64 = 3
	for _, channels := range config {
		for _, cOut := range channels {
			l := seq.Len()
			seq.Add(vggConv2d(f.Sub(fmt.Sprintf("%v", l)), cIn, cOut))

			if batchNorm {
				bnLen := seq.Len()
				seq.Add(nn.BatchNorm2D(f.Sub(fmt.Sprintf("%v", bnLen)), cOut, nn.DefaultBatchNormConfig()))
			}

			seq.AddFn(func(xs ts.Tensor) ts.Tensor {
				return xs.MustRelu(false)
			})

			cIn = cOut
		} // end of inner for loop

		// Each block ends with a 2x2 max-pool that halves the resolution.
		seq.AddFn(func(xs ts.Tensor) ts.Tensor {
			return xs.MaxPool2DDefault(2, false)
		})
	} // end of outer for loop

	// Classifier: flatten, then FC-4096 -> ReLU -> dropout (twice), then
	// a final FC layer producing nclasses logits.
	seq.AddFn(func(xs ts.Tensor) ts.Tensor {
		return xs.FlatView()
	})

	seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("0")), 512*7*7, 4096, *nn.DefaultLinearConfig()))
	seq.AddFn(func(xs ts.Tensor) ts.Tensor {
		return xs.MustRelu(false)
	})
	seq.AddFnT(func(xs ts.Tensor, train bool) ts.Tensor {
		return xs.Dropout(0.5, train)
	})

	seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("3")), 4096, 4096, *nn.DefaultLinearConfig()))
	seq.AddFn(func(xs ts.Tensor) ts.Tensor {
		return xs.MustRelu(false)
	})
	seq.AddFnT(func(xs ts.Tensor, train bool) ts.Tensor {
		return xs.Dropout(0.5, train)
	})

	seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("6")), 4096, nclasses, *nn.DefaultLinearConfig()))

	return seq
}
// VGG11 creates a VGG-11 model (configuration A).
func VGG11(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersA(), nclasses, false)
}

// VGG11BN creates a VGG-11 model with batch normalization.
func VGG11BN(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersA(), nclasses, true)
}

// VGG13 creates a VGG-13 model (configuration B).
func VGG13(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersB(), nclasses, false)
}

// VGG13BN creates a VGG-13 model with batch normalization.
func VGG13BN(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersB(), nclasses, true)
}

// VGG16 creates a VGG-16 model (configuration D).
func VGG16(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersD(), nclasses, false)
}

// VGG16BN creates a VGG-16 model with batch normalization.
func VGG16BN(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersD(), nclasses, true)
}

// VGG19 creates a VGG-19 model (configuration E).
func VGG19(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersE(), nclasses, false)
}

// VGG19BN creates a VGG-19 model with batch normalization.
func VGG19BN(path nn.Path, nclasses int64) (retVal nn.SequentialT) {
	return vgg(path, layersE(), nclasses, true)
}