132 lines
3.6 KiB
Go
132 lines
3.6 KiB
Go
package vision
|
|
|
|
// DenseNet implementation.
//
// See "Densely Connected Convolutional Networks", Huang et al., 2016.
// https://arxiv.org/abs/1608.06993
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/sugarme/gotch/nn"
|
|
ts "github.com/sugarme/gotch/tensor"
|
|
)
|
|
|
|
func dnConv2d(p nn.Path, cIn, cOut, ksize, padding, stride int64) (retVal nn.Conv2D) {
|
|
config := nn.DefaultConv2DConfig()
|
|
config.Stride = []int64{stride, stride}
|
|
config.Padding = []int64{padding, padding}
|
|
config.Bias = false
|
|
|
|
return nn.NewConv2D(p, cIn, cOut, ksize, config)
|
|
}
|
|
|
|
func denseLayer(p nn.Path, cIn, bnSize, growth int64) (retVal ts.ModuleT) {
|
|
cInter := bnSize * growth
|
|
bn1 := nn.BatchNorm2D(p.Sub("norm1"), cIn, nn.DefaultBatchNormConfig())
|
|
conv1 := dnConv2d(p.Sub("conv1"), cIn, cInter, 1, 0, 1)
|
|
bn2 := nn.BatchNorm2D(p.Sub("norm2"), cInter, nn.DefaultBatchNormConfig())
|
|
conv2 := dnConv2d(p.Sub("conv2"), cInter, growth, 3, 1, 1)
|
|
|
|
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
|
ys1 := xs.ApplyT(bn1, train)
|
|
ys2 := ys1.MustRelu(true)
|
|
ys3 := ys2.Apply(conv1)
|
|
ys2.MustDrop()
|
|
ys4 := ys3.ApplyT(bn2, train)
|
|
ys3.MustDrop()
|
|
ys5 := ys4.MustRelu(true)
|
|
ys := ys5.Apply(conv2)
|
|
ys5.MustDrop()
|
|
|
|
res := ts.MustCat([]ts.Tensor{xs, ys}, 1)
|
|
ys.MustDrop()
|
|
|
|
return res
|
|
})
|
|
}
|
|
|
|
func denseBlock(p nn.Path, cIn, bnSize, growth, nlayers int64) (retVal ts.ModuleT) {
|
|
seq := nn.SeqT()
|
|
|
|
for i := 0; i < int(nlayers); i++ {
|
|
seq.Add(denseLayer(p.Sub(fmt.Sprintf("denselayer%v", 1+i)), cIn+int64(i)*growth, bnSize, growth))
|
|
}
|
|
|
|
return seq
|
|
}
|
|
|
|
func transition(p nn.Path, cIn, cOut int64) (retVal ts.ModuleT) {
|
|
seq := nn.SeqT()
|
|
|
|
seq.Add(nn.BatchNorm2D(p.Sub("norm"), cIn, nn.DefaultBatchNormConfig()))
|
|
|
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
|
return xs.MustRelu(false)
|
|
}))
|
|
|
|
seq.Add(dnConv2d(p.Sub("conv"), cIn, cOut, 1, 0, 1))
|
|
|
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
|
return xs.AvgPool2DDefault(2, false)
|
|
}))
|
|
|
|
return seq
|
|
}
|
|
|
|
func densenet(p nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth int64) (retVal ts.ModuleT) {
|
|
fp := p.Sub("features")
|
|
seq := nn.SeqT()
|
|
|
|
seq.Add(dnConv2d(fp.Sub("conv0"), 3, cIn, 7, 3, 2))
|
|
|
|
seq.Add(nn.BatchNorm2D(fp.Sub("norm0"), cIn, nn.DefaultBatchNormConfig()))
|
|
|
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
|
tmp := xs.MustRelu(false)
|
|
return tmp.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
|
}))
|
|
|
|
nfeat := cIn
|
|
|
|
for i, nlayers := range blockConfig {
|
|
seq.Add(denseBlock(fp.Sub(fmt.Sprintf("densebloc%v", 1+i)), nfeat, bnSize, growth, nlayers))
|
|
|
|
nfeat += nlayers * growth
|
|
|
|
if i+1 != len(blockConfig) {
|
|
seq.Add(transition(fp.Sub(fmt.Sprintf("transition%v", 1+i)), nfeat, nfeat/2))
|
|
}
|
|
}
|
|
|
|
seq.Add(nn.BatchNorm2D(fp.Sub("norm5"), nfeat, nn.DefaultBatchNormConfig()))
|
|
|
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
|
tmp1 := xs.MustRelu(false)
|
|
tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, 1, true)
|
|
res := tmp2.FlatView()
|
|
tmp2.MustDrop()
|
|
return res
|
|
}))
|
|
|
|
seq.Add(nn.NewLinear(p.Sub("classifier"), nfeat, cOut, nn.DefaultLinearConfig()))
|
|
|
|
return seq
|
|
}
|
|
|
|
func DenseNet121(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
|
return densenet(p, 64, 4, 32, []int64{6, 12, 24, 16}, nclasses)
|
|
}
|
|
|
|
func DenseNet161(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
|
return densenet(p, 96, 4, 48, []int64{6, 12, 36, 24}, nclasses)
|
|
}
|
|
|
|
func DenseNet169(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
|
return densenet(p, 64, 4, 32, []int64{6, 12, 32, 32}, nclasses)
|
|
}
|
|
|
|
func DenseNet201(p nn.Path, nclasses int64) (retVal ts.ModuleT) {
|
|
return densenet(p, 64, 4, 32, []int64{6, 12, 48, 32}, nclasses)
|
|
}
|