// gotch/vision/densenet.go

// DenseNet implementation.
//
// See "Densely Connected Convolutional Networks", Huang et al. 2016.
// https://arxiv.org/abs/1608.06993
package vision

import (
	"fmt"

	"git.andr3h3nriqu3s.com/andr3/gotch/nn"
	"git.andr3h3nriqu3s.com/andr3/gotch/ts"
)
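
// dnConv2d creates a 2D convolution with bias disabled, the configuration
// shared by every convolution in this DenseNet implementation.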
func dnConv2d(p *nn.Path, cIn, cOut, ksize, padding, stride int64) *nn.Conv2D {
	config := nn.DefaultConv2DConfig()
	config.Stride = []int64{stride, stride}
	config.Padding = []int64{padding, padding}
	config.Bias = false
	return nn.NewConv2D(p, cIn, cOut, ksize, config)
}
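
// denseLayer is the bottleneck layer of a dense block: BatchNorm-ReLU-Conv1x1
// followed by BatchNorm-ReLU-Conv3x3. Its output is concatenated with its
// input along the channel dimension.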
type denseLayer struct {
	Conv1 *nn.Conv2D
	Bn1   *nn.BatchNorm
	Conv2 *nn.Conv2D
	Bn2   *nn.BatchNorm
}

func (l *denseLayer) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
	// BatchNorm-ReLU-Conv1x1 bottleneck followed by BatchNorm-ReLU-Conv3x3;
	// intermediate tensors are dropped explicitly as soon as they are no
	// longer needed.
	ys1 := xs.ApplyT(l.Bn1, train)
	ys2 := ys1.MustRelu(true)
	ys3 := ys2.Apply(l.Conv1)
	ys2.MustDrop()
	ys4 := ys3.ApplyT(l.Bn2, train)
	ys3.MustDrop()
	ys5 := ys4.MustRelu(true)
	ys := ys5.Apply(l.Conv2)
	ys5.MustDrop()

	// Dense connectivity: concatenate the input with the newly computed
	// features along the channel dimension.
	res := ts.MustCat([]*ts.Tensor{xs, ys}, 1)
	ys.MustDrop()
	return res
}
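
// newDenseLayer creates a bottleneck dense layer: cIn input channels are first
// reduced to bnSize*growth by a 1x1 convolution, then mapped to growth output
// channels by a 3x3 convolution.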
func newDenseLayer(p *nn.Path, cIn, bnSize, growth int64) ts.ModuleT {
	cInter := bnSize * growth
	bn1 := nn.BatchNorm2D(p.Sub("norm1"), cIn, nn.DefaultBatchNormConfig())
	conv1 := dnConv2d(p.Sub("conv1"), cIn, cInter, 1, 0, 1)
	bn2 := nn.BatchNorm2D(p.Sub("norm2"), cInter, nn.DefaultBatchNormConfig())
	conv2 := dnConv2d(p.Sub("conv2"), cInter, growth, 3, 1, 1)
	return &denseLayer{
		Bn1:   bn1,
		Conv1: conv1,
		Bn2:   bn2,
		Conv2: conv2,
	}
}
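
// denseBlock stacks nlayers dense layers. Because every layer concatenates its
// input with growth new channels, layer i receives cIn+i*growth input channels.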
func denseBlock(p *nn.Path, cIn, bnSize, growth, nlayers int64) ts.ModuleT {
	seq := nn.SeqT()
	for i := 0; i < int(nlayers); i++ {
		seq.Add(newDenseLayer(p.Sub(fmt.Sprintf("denselayer%v", 1+i)), cIn+(int64(i)*growth), bnSize, growth))
	}
	return seq
}
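
// transition connects consecutive dense blocks: BatchNorm-ReLU, a 1x1
// convolution reducing the channel count to cOut, and 2x2 average pooling
// that halves the spatial resolution.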
func transition(p *nn.Path, cIn, cOut int64) ts.ModuleT {
	seq := nn.SeqT()
	seq.Add(nn.BatchNorm2D(p.Sub("norm"), cIn, nn.DefaultBatchNormConfig()))
	seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
		return xs.MustRelu(false)
	}))
	seq.Add(dnConv2d(p.Sub("conv"), cIn, cOut, 1, 0, 1))
	seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
		return xs.AvgPool2DDefault(2, false)
	}))
	return seq
}
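
// densenet assembles the full network: a convolutional stem, the dense blocks
// described by blockConfig separated by transition layers, and a classifier
// head. cIn is the number of stem output channels, bnSize the bottleneck
// multiplier and growth the per-layer channel growth rate.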
func densenet(p *nn.Path, cIn, bnSize int64, growth int64, blockConfig []int64, cOut int64) ts.ModuleT {
	fp := p.Sub("features")
	seq := nn.SeqT()

	// Stem: 7x7 stride-2 convolution, BatchNorm, ReLU and 3x3 stride-2 max pooling.
	seq.Add(dnConv2d(fp.Sub("conv0"), 3, cIn, 7, 3, 2))
	seq.Add(nn.BatchNorm2D(fp.Sub("norm0"), cIn, nn.DefaultBatchNormConfig()))
	seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
		tmp := xs.MustRelu(false)
		return tmp.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
	}))

	// Dense blocks, each (except the last) followed by a transition layer
	// that halves the number of feature maps.
	nfeat := cIn
	for i, nlayers := range blockConfig {
		seq.Add(denseBlock(fp.Sub(fmt.Sprintf("denseblock%v", 1+i)), nfeat, bnSize, growth, nlayers))
		nfeat += nlayers * growth
		if i+1 != len(blockConfig) {
			seq.Add(transition(fp.Sub(fmt.Sprintf("transition%v", 1+i)), nfeat, nfeat/2))
			nfeat = nfeat / 2
		}
	}

	// Head: final BatchNorm-ReLU, 7x7 pooling, flattening and a linear classifier.
	seq.Add(nn.BatchNorm2D(fp.Sub("norm5"), nfeat, nn.DefaultBatchNormConfig()))
	seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
		tmp1 := xs.MustRelu(false)
		tmp2 := tmp1.MustAvgPool2d([]int64{7, 7}, []int64{1, 1}, []int64{0, 0}, false, true, []int64{1}, true)
		res := tmp2.FlatView()
		tmp2.MustDrop()
		return res
	}))
	seq.Add(nn.NewLinear(p.Sub("classifier"), nfeat, cOut, nn.DefaultLinearConfig()))
	return seq
}
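
// DenseNet121 builds a DenseNet-121 classifier producing nclasses logits.
//
// A minimal usage sketch, assuming the usual gotch VarStore API and an NCHW
// float tensor images of shape [N, 3, 224, 224] (variable names are
// illustrative):
//
//	vs := nn.NewVarStore(gotch.CPU)
//	net := vision.DenseNet121(vs.Root(), 1000)
//	logits := net.ForwardT(images, false) // false = inference mode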
func DenseNet121(p *nn.Path, nclasses int64) ts.ModuleT {
	// path, cIn, bnSize, growth, blockConfig, cOut
	return densenet(p, 64, 4, 32, []int64{6, 12, 24, 16}, nclasses)
}
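
// DenseNet161 builds a DenseNet-161 classifier; it uses a wider stem (96
// initial features) and a larger growth rate (48) than the other variants.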
func DenseNet161(p *nn.Path, nclasses int64) ts.ModuleT {
	return densenet(p, 96, 4, 48, []int64{6, 12, 36, 24}, nclasses)
}
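
// DenseNet169 builds a DenseNet-169 classifier producing nclasses logits.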
func DenseNet169(p *nn.Path, nclasses int64) ts.ModuleT {
	return densenet(p, 64, 4, 32, []int64{6, 12, 32, 32}, nclasses)
}
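
// DenseNet201 builds a DenseNet-201 classifier producing nclasses logits.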
func DenseNet201(p *nn.Path, nclasses int64) ts.ModuleT {
	return densenet(p, 64, 4, 32, []int64{6, 12, 48, 32}, nclasses)
}