reworked resnet and densenet, added BCELoss and CrossEntropyLoss, changed DataLoader.Reset()

This commit is contained in:
sugarme 2021-07-14 10:38:11 +10:00
parent 77031295cb
commit c89e4b3ba1
7 changed files with 352 additions and 208 deletions

View File

@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
- Fixed (temporary fix) the huge number of learning-rate groups returned from C at `libtch/tensor.go AtoGetLearningRates`
- Fixed incorrect `nn.AdamWConfig` and some documentation.
- Fixed: reworked `vision.ResNet` and `vision.DenseNet` to correct wrong layers and a memory leak
- Changed `dutil.DataLoader.Reset()` to reshuffle the DataLoader on reset when the optional flag is true
- Changed `dutil.DataLoader.Next()`: removed the special case for batch size == 1 so that items are always returned as a slice (`[]element dtype`), even when batch size is 1 (see the consumption sketch after this changelog excerpt).
- Added `nn.CrossEntropyLoss` and `nn.BCELoss`
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
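
A minimal consumption sketch for the `DataLoader.Next()` change above. The `github.com/sugarme/gotch/dutil` import path and the `[]*ts.Tensor` element type are assumptions for illustration (the diff itself only shows `*ts.Tensor` elements); substitute whatever your `Dataset.Item()` returns.

package example

import (
	"log"

	"github.com/sugarme/gotch/dutil"
	ts "github.com/sugarme/gotch/tensor"
)

// consumeAll drains the DataLoader; Next() now always returns a slice,
// even when the batch size is 1.
func consumeAll(dl *dutil.DataLoader) {
	for dl.HasNext() {
		item, err := dl.Next()
		if err != nil {
			log.Fatal(err)
		}
		batch := item.([]*ts.Tensor) // assumed element type
		for i := range batch {
			// ... use batch[i] here ...
			batch[i].MustDrop()
		}
	}
}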

View File

@ -14,6 +14,7 @@ type DataLoader struct {
indexes []int // order of samples in dataset for iteration.
batchSize int
currIdx int
sampler Sampler
}
func NewDataLoader(data Dataset, s Sampler) (*DataLoader, error) {
@ -40,6 +41,7 @@ func NewDataLoader(data Dataset, s Sampler) (*DataLoader, error) {
indexes: s.Sample(),
batchSize: s.BatchSize(),
currIdx: 0,
sampler: s,
}, nil
}
@ -70,18 +72,7 @@ func (dl *DataLoader) Next() (interface{}, error) {
return nil, err
}
// Non-batching
if dl.batchSize == 1 {
item, err := dl.dataset.Item(dl.currIdx)
if err != nil {
return nil, err
}
dl.currIdx += 1
return item, nil
}
// Batch sampling
// determine element dtype
elem, err := dl.dataset.Item(0)
if err != nil {
return nil, err
@ -98,6 +89,7 @@ func (dl *DataLoader) Next() (interface{}, error) {
elem.(*ts.Tensor).MustDrop()
}
// Get a batch based on batch size
items := reflect.MakeSlice(reflect.SliceOf(elemType), 0, dl.dataset.Len())
nextIndex := dl.currIdx + dl.batchSize
@ -123,7 +115,14 @@ func (dl *DataLoader) HasNext() bool {
}
// Reset resets the index to the start position.
func (dl *DataLoader) Reset() {
func (dl *DataLoader) Reset(shuffleOpt ...bool) {
shuffle := false
if len(shuffleOpt) > 0 {
shuffle = shuffleOpt[0]
}
if shuffle {
dl.indexes = dl.sampler.Sample()
}
dl.currIdx = 0
}
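
A minimal sketch of the new optional reshuffle flag, assuming the `github.com/sugarme/gotch/dutil` import path and a Dataset/Sampler pair built elsewhere; only the DataLoader calls mirror the diff above.

package example

import (
	"log"

	"github.com/sugarme/gotch/dutil"
)

// runTwoEpochs iterates a DataLoader twice; before the second pass,
// Reset(true) asks the sampler to re-sample the index order.
func runTwoEpochs(ds dutil.Dataset, s dutil.Sampler) {
	dl, err := dutil.NewDataLoader(ds, s)
	if err != nil {
		log.Fatal(err)
	}

	for epoch := 0; epoch < 2; epoch++ {
		for dl.HasNext() {
			batch, err := dl.Next()
			if err != nil {
				log.Fatal(err)
			}
			_ = batch // train on the batch here
		}
		// Reset(true) reshuffles; Reset() alone just rewinds to the start.
		dl.Reset(true)
	}
}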

nn/loss.go (new file, 109 lines)
View File

@ -0,0 +1,109 @@
package nn
import (
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
type lossFnOptions struct {
ClassWeights []float64
Reduction int64 // 0: "None", 1: "mean", 2: "sum"
IgnoreIndex int64
PosWeight int64 // index of the weight attributed to positive class. Used in BCELoss
}
type LossFnOption func(*lossFnOptions)
func WithLossFnWeights(vals []float64) LossFnOption {
return func(o *lossFnOptions) {
o.ClassWeights = vals
}
}
func WithLossFnReduction(val int64) LossFnOption {
return func(o *lossFnOptions) {
o.Reduction = val
}
}
func WithLossFnIgnoreIndex(val int64) LossFnOption {
return func(o *lossFnOptions) {
o.IgnoreIndex = val
}
}
func WithLossFnPosWeight(val int64) LossFnOption {
return func(o *lossFnOptions) {
o.PosWeight = val
}
}
func defaultLossFnOptions() *lossFnOptions {
return &lossFnOptions{
ClassWeights: nil,
Reduction: 1, // "mean"
IgnoreIndex: -100,
PosWeight: -1,
}
}
// CrossEntropyLoss calculates cross entropy loss.
// Ref. https://github.com/pytorch/pytorch/blob/15be189f0de4addf4f68d18022500f67617ab05d/torch/nn/functional.py#L2012
// - logits: tensor of shape [B, C, H, W] corresponding to the raw output of the model.
// - target: ground truth tensor of shape [B, 1, H, W]
// - posWeight: scalar representing the weight attributed to the positive class.
// This is especially useful for an imbalanced dataset.
func CrossEntropyLoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
options := defaultLossFnOptions()
for _, o := range opts {
o(options)
}
var ws *ts.Tensor
device := logits.MustDevice()
dtype := logits.DType()
if len(options.ClassWeights) > 0 {
ws = ts.MustOfSlice(options.ClassWeights).MustTotype(dtype, true).MustTo(device, true)
} else {
ws = ts.NewTensor()
}
reduction := options.Reduction
ignoreIndex := options.IgnoreIndex
logSm := logits.MustLogSoftmax(-1, gotch.Float, false)
loss := logSm.MustNllLoss(target, ws, reduction, ignoreIndex, true)
ws.MustDrop()
return loss
}
// BCELoss calculates a binary cross entropy loss.
//
// - logits: tensor of shape [B, C, H, W] corresponding to the raw output of the model.
// - target: ground truth tensor of shape [B, 1, H, W]
// - posWeight: scalar representing the weight attributed to the positive class.
// This is especially useful for an imbalanced dataset.
func BCELoss(logits, target *ts.Tensor, opts ...LossFnOption) *ts.Tensor {
options := defaultLossFnOptions()
for _, o := range opts {
o(options)
}
var ws *ts.Tensor
device := logits.MustDevice()
dtype := logits.DType()
if len(options.ClassWeights) > 0 {
ws = ts.MustOfSlice(options.ClassWeights).MustTotype(dtype, true).MustTo(device, true)
} else {
ws = ts.NewTensor()
}
reduction := options.Reduction
var posWeight *ts.Tensor
if options.PosWeight >= 0 {
posWeight = ts.MustOfSlice([]int64{options.PosWeight})
} else {
posWeight = ts.NewTensor()
}
loss := logits.MustSqueeze(false).MustBinaryCrossEntropyWithLogits(target, ws, posWeight, reduction, true)
return loss
}
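
A hedged usage sketch for the new loss helpers. The shapes ([2, 3] logits, [2] class-index target) and the `MustView`/`Float64Values` tensor helpers are assumptions for illustration; only `nn.CrossEntropyLoss`, `nn.WithLossFnWeights`, `ts.MustOfSlice`, and `MustTotype` come from the file above. `nn.BCELoss` is called the same way, with a float target matching the (squeezed) logits shape.

package example

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

func crossEntropyExample() {
	// 2 samples, 3 classes: raw logits and int64 class indexes.
	logits := ts.MustOfSlice([]float64{1.0, 2.0, 0.5, 0.1, 0.2, 3.0}).MustView([]int64{2, 3}, true).MustTotype(gotch.Float, true)
	target := ts.MustOfSlice([]int64{1, 2})

	// Optional class weights; reduction defaults to "mean".
	loss := nn.CrossEntropyLoss(logits, target, nn.WithLossFnWeights([]float64{1.0, 2.0, 1.0}))
	fmt.Println("cross entropy:", loss.Float64Values())

	loss.MustDrop()
	target.MustDrop()
	logits.MustDrop()
}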

View File

@ -145,16 +145,16 @@ type AdamWConfig struct {
Wd float64
}
// DefaultAdamConfig creates AdamConfig with default values
func DefaultAdamWConfig() *AdamConfig {
return &AdamConfig{
// DefaultAdamWConfig creates AdamWConfig with default values
func DefaultAdamWConfig() *AdamWConfig {
return &AdamWConfig{
Beta1: 0.9,
Beta2: 0.999,
Wd: 0.0,
Wd: 0.01,
}
}
// NewAdamConfig creates AdamConfig with specified values
// NewAdamWConfig creates AdamWConfig with specified values
func NewAdamWConfig(beta1, beta2, wd float64) *AdamWConfig {
return &AdamWConfig{
Beta1: beta1,
@ -163,11 +163,12 @@ func NewAdamWConfig(beta1, beta2, wd float64) *AdamWConfig {
}
}
// Implement OptimizerConfig interface for AdamConfig
// Implement OptimizerConfig interface for AdamWConfig
func (c *AdamWConfig) buildCOpt(lr float64) (*ts.COptimizer, error) {
return ts.AdamW(lr, c.Beta1, c.Beta2, c.Wd)
}
// Build builds AdamW optimizer
func (c *AdamWConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
return defaultBuild(c, vs, lr)
}
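
A short sketch of building an optimizer from the corrected config; the `nn.NewVarStore` call and the learning rate are illustrative, not part of the diff.

package example

import (
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func buildAdamW() *nn.Optimizer {
	vs := nn.NewVarStore(gotch.CPU)

	// Defaults are now Beta1=0.9, Beta2=0.999, Wd=0.01.
	cfg := nn.DefaultAdamWConfig()
	opt, err := cfg.Build(vs, 1e-3)
	if err != nil {
		log.Fatal(err)
	}
	return opt
}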

View File

@ -27,5 +27,3 @@ func (ts *Tensor) AccuracyForLogits(targets *Tensor) (retVal *Tensor) {
func (ts *Tensor) MaxPool2DDefault(ksize int64, del bool) (retVal *Tensor) {
return ts.MustMaxPool2d([]int64{ksize, ksize}, []int64{ksize, ksize}, []int64{0, 0}, []int64{1, 1}, false, del)
}
// TODO: continue

View File

@ -21,36 +21,49 @@ func dnConv2d(p *nn.Path, cIn, cOut, ksize, padding, stride int64) *nn.Conv2D {
return nn.NewConv2D(p, cIn, cOut, ksize, config)
}
func denseLayer(p *nn.Path, cIn, bnSize, growth int64) ts.ModuleT {
type denseLayer struct {
Conv1 *nn.Conv2D
Bn1 *nn.BatchNorm
Conv2 *nn.Conv2D
Bn2 *nn.BatchNorm
}
func (l *denseLayer) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
ys1 := xs.ApplyT(l.Bn1, train)
ys2 := ys1.MustRelu(true)
ys3 := ys2.Apply(l.Conv1)
ys2.MustDrop()
ys4 := ys3.ApplyT(l.Bn2, train)
ys3.MustDrop()
ys5 := ys4.MustRelu(true)
ys := ys5.Apply(l.Conv2)
ys5.MustDrop()
res := ts.MustCat([]ts.Tensor{*xs, *ys}, 1)
ys.MustDrop()
return res
}
func newDenseLayer(p *nn.Path, cIn, bnSize, growth int64) ts.ModuleT {
cInter := bnSize * growth
bn1 := nn.BatchNorm2D(p.Sub("norm1"), cIn, nn.DefaultBatchNormConfig())
conv1 := dnConv2d(p.Sub("conv1"), cIn, cInter, 1, 0, 1)
bn2 := nn.BatchNorm2D(p.Sub("norm2"), cInter, nn.DefaultBatchNormConfig())
conv2 := dnConv2d(p.Sub("conv2"), cInter, growth, 3, 1, 1)
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
ys1 := xs.ApplyT(bn1, train)
ys2 := ys1.MustRelu(true)
ys3 := ys2.Apply(conv1)
ys2.MustDrop()
ys4 := ys3.ApplyT(bn2, train)
ys3.MustDrop()
ys5 := ys4.MustRelu(true)
ys := ys5.Apply(conv2)
ys5.MustDrop()
res := ts.MustCat([]ts.Tensor{*xs, *ys}, 1)
ys.MustDrop()
return res
})
return &denseLayer{
Bn1: bn1,
Conv1: conv1,
Bn2: bn2,
Conv2: conv2,
}
}
func denseBlock(p *nn.Path, cIn, bnSize, growth, nlayers int64) ts.ModuleT {
seq := nn.SeqT()
for i := 0; i < int(nlayers); i++ {
seq.Add(denseLayer(p.Sub(fmt.Sprintf("denselayer%v", 1+i)), cIn+int64(i)*growth, bnSize, growth))
seq.Add(newDenseLayer(p.Sub(fmt.Sprintf("denselayer%v", 1+i)), cIn+(int64(i)*growth), bnSize, growth))
}
return seq
@ -74,7 +87,7 @@ func transition(p *nn.Path, cIn, cOut int64) ts.ModuleT {
return seq
}
func densenet(p *nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth int64) ts.ModuleT {
func densenet(p *nn.Path, cIn, bnSize int64, growth int64, blockConfig []int64, cOut int64) ts.ModuleT {
fp := p.Sub("features")
seq := nn.SeqT()
@ -90,12 +103,13 @@ func densenet(p *nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth i
nfeat := cIn
for i, nlayers := range blockConfig {
seq.Add(denseBlock(fp.Sub(fmt.Sprintf("densebloc%v", 1+i)), nfeat, bnSize, growth, nlayers))
seq.Add(denseBlock(fp.Sub(fmt.Sprintf("denseblock%v", 1+i)), nfeat, bnSize, growth, nlayers))
nfeat += nlayers * growth
if i+1 != len(blockConfig) {
seq.Add(transition(fp.Sub(fmt.Sprintf("transition%v", 1+i)), nfeat, nfeat/2))
nfeat = nfeat / 2
}
}
@ -115,6 +129,7 @@ func densenet(p *nn.Path, cIn, cOut, bnSize int64, blockConfig []int64, growth i
}
func DenseNet121(p *nn.Path, nclasses int64) ts.ModuleT {
// path, cIn, bnSize, growth, blockConfig, cOut
return densenet(p, 64, 4, 32, []int64{6, 12, 24, 16}, nclasses)
}
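
A hedged forward-pass sketch for the reworked DenseNet. The `github.com/sugarme/gotch/vision` import path, `nn.NewVarStore`, `ts.MustZeros`, and the 224x224 input size are assumptions; `DenseNet121` itself is defined above and now composes struct-based dense layers whose intermediate tensors are dropped in ForwardT.

package example

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
	"github.com/sugarme/gotch/vision"
)

func forwardDenseNet121() *ts.Tensor {
	vs := nn.NewVarStore(gotch.CPU)
	net := vision.DenseNet121(vs.Root(), 1000)

	x := ts.MustZeros([]int64{1, 3, 224, 224}, gotch.Float, gotch.CPU)
	logits := net.ForwardT(x, false) // eval mode; shape [1, 1000]
	x.MustDrop()
	return logits
}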

View File

@ -12,122 +12,127 @@ import (
// See "Deep Residual Learning for Image Recognition" He et al. 2015
// https://arxiv.org/abs/1512.03385
func conv2d(path *nn.Path, cIn, cOut, ksize, padding, stride int64) *nn.Conv2D {
config := nn.DefaultConv2DConfig()
config.Stride = []int64{stride, stride}
config.Padding = []int64{padding, padding}
config.Bias = false
func layerZero(p *nn.Path) ts.ModuleT {
conv1 := conv2dNoBias(p.Sub("conv1"), 3, 64, 7, 3, 2)
bn1 := nn.BatchNorm2D(p.Sub("bn1"), 64, nn.DefaultBatchNormConfig())
layer0 := nn.SeqT()
layer0.Add(conv1)
layer0.Add(bn1)
layer0.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustRelu(false)
}))
layer0.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, false)
}))
return nn.NewConv2D(path, cIn, cOut, ksize, config)
}
func downSample(path *nn.Path, cIn, cOut, stride int64) ts.ModuleT {
if stride != 1 || cIn != cOut {
seq := nn.SeqT()
seq.Add(conv2d(path.Sub("0"), cIn, cOut, 1, 0, stride))
seq.Add(nn.BatchNorm2D(path.Sub("1"), cOut, nn.DefaultBatchNormConfig()))
return seq
}
return nn.SeqT()
}
func basicBlock(path *nn.Path, cIn, cOut, stride int64) ts.ModuleT {
conv1 := conv2d(path.Sub("conv1"), cIn, cOut, 3, 1, stride)
bn1 := nn.BatchNorm2D(path.Sub("bn1"), cOut, nn.DefaultBatchNormConfig())
conv2 := conv2d(path.Sub("conv2"), cOut, cOut, 3, 1, 1)
bn2 := nn.BatchNorm2D(path.Sub("bn2"), cOut, nn.DefaultBatchNormConfig())
downsample := downSample(path.Sub("downsample"), cIn, cOut, stride)
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(conv1)
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
c2 := relu.Apply(conv2)
relu.MustDrop()
bn2 := c2.ApplyT(bn2, train)
c2.MustDrop()
dsl := xs.ApplyT(downsample, train)
dslAdd := dsl.MustAdd(bn2, true)
res := dslAdd.MustRelu(true)
return res
})
return layer0
}
func basicLayer(path *nn.Path, cIn, cOut, stride, cnt int64) ts.ModuleT {
layer := nn.SeqT()
layer.Add(basicBlock(path.Sub("0"), cIn, cOut, stride))
layer.Add(newBasicBlock(path.Sub("0"), cIn, cOut, stride))
for blockIndex := 1; blockIndex < int(cnt); blockIndex++ {
layer.Add(basicBlock(path.Sub(fmt.Sprint(blockIndex)), cOut, cOut, 1))
layer.Add(newBasicBlock(path.Sub(fmt.Sprint(blockIndex)), cOut, cOut, 1))
}
return layer
}
func resnet(path *nn.Path, nclasses int64, c1, c2, c3, c4 int64) nn.FuncT {
conv1 := conv2d(path.Sub("conv1"), 3, 64, 7, 3, 2)
bn1 := nn.BatchNorm2D(path.Sub("bn1"), 64, nn.DefaultBatchNormConfig())
layer1 := basicLayer(path.Sub("layer1"), 64, 64, 1, c1)
layer2 := basicLayer(path.Sub("layer2"), 64, 128, 2, c2)
layer3 := basicLayer(path.Sub("layer3"), 128, 256, 2, c3)
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
func conv2d(p *nn.Path, cIn, cOut, ksize, padding, stride int64) *nn.Conv2D {
config := nn.DefaultConv2DConfig()
config.Stride = []int64{stride, stride}
config.Padding = []int64{padding, padding}
return nn.NewConv2D(p, cIn, cOut, ksize, config)
}
func conv2dNoBias(p *nn.Path, cIn, cOut, ksize, padding, stride int64) *nn.Conv2D {
config := nn.DefaultConv2DConfig()
config.Bias = false
config.Stride = []int64{stride, stride}
config.Padding = []int64{padding, padding}
return nn.NewConv2D(p, cIn, cOut, ksize, config)
}
func downSample(path *nn.Path, cIn, cOut, stride int64) ts.ModuleT {
if stride != 1 || cIn != cOut {
seq := nn.SeqT()
seq.Add(conv2dNoBias(path.Sub("0"), cIn, cOut, 1, 0, stride))
seq.Add(nn.BatchNorm2D(path.Sub("1"), cOut, nn.DefaultBatchNormConfig()))
return seq
}
return nn.SeqT()
}
type basicBlock struct {
Conv1 *nn.Conv2D
Bn1 *nn.BatchNorm
Conv2 *nn.Conv2D
Bn2 *nn.BatchNorm
Downsample ts.ModuleT
}
func newBasicBlock(path *nn.Path, cIn, cOut, stride int64) *basicBlock {
conv1 := conv2dNoBias(path.Sub("conv1"), cIn, cOut, 3, 1, stride)
bn1 := nn.BatchNorm2D(path.Sub("bn1"), cOut, nn.DefaultBatchNormConfig())
conv2 := conv2dNoBias(path.Sub("conv2"), cOut, cOut, 3, 1, 1)
bn2 := nn.BatchNorm2D(path.Sub("bn2"), cOut, nn.DefaultBatchNormConfig())
downsample := downSample(path.Sub("downsample"), cIn, cOut, stride)
return &basicBlock{conv1, bn1, conv2, bn2, downsample}
}
func (bb *basicBlock) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
c1 := bb.Conv1.ForwardT(x, train)
bn1Ts := bb.Bn1.ForwardT(c1, train)
c1.MustDrop()
relu := bn1Ts.MustRelu(true)
c2 := bb.Conv2.ForwardT(relu, train)
relu.MustDrop()
bn2Ts := bb.Bn2.ForwardT(c2, train)
c2.MustDrop()
dsl := bb.Downsample.ForwardT(x, train)
dslAdd := dsl.MustAdd(bn2Ts, true)
bn2Ts.MustDrop()
res := dslAdd.MustRelu(true)
return res
}
func resnet(p *nn.Path, nclasses int64, c1, c2, c3, c4 int64) nn.FuncT {
seq := nn.SeqT()
layer0 := layerZero(p)
layer1 := basicLayer(p.Sub("layer1"), 64, 64, 1, c1)
layer2 := basicLayer(p.Sub("layer2"), 64, 128, 2, c2)
layer3 := basicLayer(p.Sub("layer3"), 128, 256, 2, c3)
layer4 := basicLayer(p.Sub("layer4"), 256, 512, 2, c4)
seq.Add(layer0)
seq.Add(layer1)
seq.Add(layer2)
seq.Add(layer3)
seq.Add(layer4)
if nclasses > 0 {
// With final layer
linearConfig := nn.DefaultLinearConfig()
fc := nn.NewLinear(path.Sub("fc"), 512, nclasses, linearConfig)
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(conv1)
xs.MustDrop()
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
l1 := maxpool.ApplyT(layer1, train)
l2 := l1.ApplyT(layer2, train)
l1.MustDrop()
l3 := l2.ApplyT(layer3, train)
l2.MustDrop()
l4 := l3.ApplyT(layer4, train)
l3.MustDrop()
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
l4.MustDrop()
fc := nn.NewLinear(p.Sub("fc"), 512, nclasses, linearConfig)
return nn.NewFuncT(func(x *ts.Tensor, train bool) *ts.Tensor {
output := seq.ForwardT(x, train)
avgpool := output.MustAdaptiveAvgPool2d([]int64{1, 1}, true)
fv := avgpool.FlatView()
avgpool.MustDrop()
retVal := fv.ApplyOpt(ts.WithModule(fc))
fv.MustDrop()
return retVal
})
} else {
// No final layer
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(conv1)
xs.MustDrop()
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
l1 := maxpool.ApplyT(layer1, train)
maxpool.MustDrop()
l2 := l1.ApplyT(layer2, train)
l1.MustDrop()
l3 := l2.ApplyT(layer3, train)
l2.MustDrop()
l4 := l3.ApplyT(layer4, train)
l3.MustDrop()
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
l4.MustDrop()
// no final layer
return nn.NewFuncT(func(x *ts.Tensor, train bool) *ts.Tensor {
output := seq.ForwardT(x, train)
avgpool := output.MustAdaptiveAvgPool2d([]int64{1, 1}, true)
retVal := avgpool.FlatView()
avgpool.MustDrop()
@ -136,26 +141,39 @@ func resnet(path *nn.Path, nclasses int64, c1, c2, c3, c4 int64) nn.FuncT {
}
}
// Creates a ResNet-18 model.
func ResNet18(path *nn.Path, numClasses int64) nn.FuncT {
return resnet(path, numClasses, 2, 2, 2, 2)
type bottleneckBlock struct {
Conv1 *nn.Conv2D
Bn1 *nn.BatchNorm
Conv2 *nn.Conv2D
Bn2 *nn.BatchNorm
Conv3 *nn.Conv2D
Bn3 *nn.BatchNorm
Downsample ts.ModuleT
}
func ResNet18NoFinalLayer(path *nn.Path) nn.FuncT {
return resnet(path, 0, 2, 2, 2, 2)
}
// ForwardT implements ModuleT for bottleneckBlock.
func (b *bottleneckBlock) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(b.Conv1)
bn1 := c1.ApplyT(b.Bn1, train)
c1.MustDrop()
relu1 := bn1.MustRelu(true)
c2 := relu1.Apply(b.Conv2)
relu1.MustDrop()
bn2 := c2.ApplyT(b.Bn2, train)
relu2 := bn2.MustRelu(true)
c3 := relu2.Apply(b.Conv3)
relu2.MustDrop()
bn3 := c3.ApplyT(b.Bn3, train)
func ResNet34(path *nn.Path, numClasses int64) nn.FuncT {
return resnet(path, numClasses, 3, 4, 6, 3)
}
func ResNet34NoFinalLayer(path *nn.Path) nn.FuncT {
return resnet(path, 0, 3, 4, 6, 3)
dsl := xs.ApplyT(b.Downsample, train)
add := dsl.MustAdd(bn3, true)
bn3.MustDrop()
res := add.MustRelu(true)
return res
}
// Bottleneck versions for ResNet 50, 101, and 152.
func bottleneckBlock(path *nn.Path, cIn, cOut, stride, e int64) ts.ModuleT {
func newBottleneckBlock(path *nn.Path, cIn, cOut, stride, e int64) *bottleneckBlock {
eDim := e * cOut
conv1 := conv2d(path.Sub("conv1"), cIn, cOut, 1, 0, 1)
bn1 := nn.BatchNorm2D(path.Sub("bn1"), cOut, nn.DefaultBatchNormConfig())
@ -165,33 +183,22 @@ func bottleneckBlock(path *nn.Path, cIn, cOut, stride, e int64) ts.ModuleT {
bn3 := nn.BatchNorm2D(path.Sub("bn3"), eDim, nn.DefaultBatchNormConfig())
downsample := downSample(path.Sub("downsample"), cIn, eDim, stride)
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(conv1)
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu1 := bn1.MustRelu(true)
c2 := relu1.Apply(conv2)
relu1.MustDrop()
bn2 := c2.ApplyT(bn2, train)
relu2 := bn2.MustRelu(true)
c3 := relu2.Apply(conv3)
relu2.MustDrop()
bn3 := c3.ApplyT(bn3, train)
dsl := xs.ApplyT(downsample, train)
add := dsl.MustAdd(bn3, true)
bn3.MustDrop()
res := add.MustRelu(true)
return res
})
return &bottleneckBlock{
Conv1: conv1,
Bn1: bn1,
Conv2: conv2,
Bn2: bn2,
Conv3: conv3,
Bn3: bn3,
Downsample: downsample,
}
}
func bottleneckLayer(path *nn.Path, cIn, cOut, stride, cnt int64) ts.ModuleT {
layer := nn.SeqT()
layer.Add(bottleneckBlock(path.Sub("0"), cIn, cOut, stride, 4))
layer.Add(newBottleneckBlock(path.Sub("0"), cIn, cOut, stride, 4))
for blockIndex := 1; blockIndex < int(cnt); blockIndex++ {
layer.Add(bottleneckBlock(path.Sub(fmt.Sprint(blockIndex)), (cOut * 4), cOut, 1, 4))
layer.Add(newBottleneckBlock(path.Sub(fmt.Sprint(blockIndex)), (cOut * 4), cOut, 1, 4))
}
return layer
@ -200,55 +207,39 @@ func bottleneckLayer(path *nn.Path, cIn, cOut, stride, cnt int64) ts.ModuleT {
func bottleneckResnet(path *nn.Path, nclasses int64, c1, c2, c3, c4 int64) ts.ModuleT {
conv1 := conv2d(path.Sub("conv1"), 3, 64, 7, 3, 2)
bn1 := nn.BatchNorm2D(path.Sub("bn1"), 64, nn.DefaultBatchNormConfig())
layer1 := bottleneckLayer(path.Sub("layer1"), 64, 64, 1, c1)
layer2 := bottleneckLayer(path.Sub("layer2"), 4*64, 128, 2, c2)
layer3 := bottleneckLayer(path.Sub("layer3"), 4*128, 256, 2, c3)
layer4 := bottleneckLayer(path.Sub("layer4"), 4*256, 512, 2, c4)
if nclasses > 0 {
fc := nn.NewLinear(path.Sub("fc"), 4*512, nclasses, nn.DefaultLinearConfig())
seq := nn.SeqT()
seq.Add(conv1)
seq.Add(bn1)
seq.Add(layer1)
seq.Add(layer2)
seq.Add(layer3)
seq.Add(layer4)
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(conv1)
xs.MustDrop()
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
l1 := maxpool.ApplyT(layer1, train)
l2 := l1.ApplyT(layer2, train)
l1.MustDrop()
l3 := l2.ApplyT(layer3, train)
l2.MustDrop()
l4 := l3.ApplyT(layer4, train)
l3.MustDrop()
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
l4.MustDrop()
if nclasses > 0 {
// With final layer
linearConfig := nn.DefaultLinearConfig()
fc := nn.NewLinear(path.Sub("fc"), 4*512, nclasses, linearConfig)
return nn.NewFuncT(func(x *ts.Tensor, train bool) *ts.Tensor {
output := seq.ForwardT(x, train)
avgpool := output.MustAdaptiveAvgPool2d([]int64{1, 1}, true)
fv := avgpool.FlatView()
avgpool.MustDrop()
retVal := fv.ApplyOpt(ts.WithModule(fc))
fv.MustDrop()
return retVal
})
} else {
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
c1 := xs.Apply(conv1)
xs.MustDrop()
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
maxpool := relu.MustMaxPool2d([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
l1 := maxpool.ApplyT(layer1, train)
maxpool.MustDrop()
l2 := l1.ApplyT(layer2, train)
l1.MustDrop()
l3 := l2.ApplyT(layer3, train)
l2.MustDrop()
l4 := l3.ApplyT(layer4, train)
l3.MustDrop()
avgpool := l4.MustAdaptiveAvgPool2d([]int64{1, 1}, false)
l4.MustDrop()
// no final layer
return nn.NewFuncT(func(x *ts.Tensor, train bool) *ts.Tensor {
output := seq.ForwardT(x, train)
avgpool := output.MustAdaptiveAvgPool2d([]int64{1, 1}, true)
retVal := avgpool.FlatView()
avgpool.MustDrop()
@ -257,26 +248,52 @@ func bottleneckResnet(path *nn.Path, nclasses int64, c1, c2, c3, c4 int64) ts.Mo
}
}
// ResNet18 creates a ResNet-18 model.
func ResNet18(path *nn.Path, numClasses int64) nn.FuncT {
return resnet(path, numClasses, 2, 2, 2, 2)
}
// ResNet18NoFinalLayer creates a ResNet-18 model without the final fully connected layer.
func ResNet18NoFinalLayer(path *nn.Path) nn.FuncT {
return resnet(path, 0, 2, 2, 2, 2)
}
// ResNet34 creates a ResNet-34 model.
func ResNet34(path *nn.Path, numClasses int64) nn.FuncT {
return resnet(path, numClasses, 3, 4, 6, 3)
}
// ResNet34NoFinalLayer creates a ResNet-34 model without the final fully connected layer.
func ResNet34NoFinalLayer(path *nn.Path) nn.FuncT {
return resnet(path, 0, 3, 4, 6, 3)
}
// ResNet50 creates a ResNet-50 model.
func ResNet50(path *nn.Path, numClasses int64) ts.ModuleT {
return bottleneckResnet(path, numClasses, 3, 4, 6, 3)
}
// ResNet50NoFinalLayer creates a ResNet-50 model without the final fully connected layer.
func ResNet50NoFinalLayer(path *nn.Path) ts.ModuleT {
return bottleneckResnet(path, 0, 3, 4, 6, 3)
}
// ResNet101 creates a ResNet-101 model.
func ResNet101(path *nn.Path, numClasses int64) ts.ModuleT {
return bottleneckResnet(path, numClasses, 3, 4, 23, 3)
}
// ResNet101NoFinalLayer creates a ResNet-101 model without the final fully connected layer.
func ResNet101NoFinalLayer(path *nn.Path) ts.ModuleT {
return bottleneckResnet(path, 0, 3, 4, 23, 3)
}
// ResNet152 creates a ResNet-152 model.
func ResNet152(path *nn.Path, numClasses int64) ts.ModuleT {
return bottleneckResnet(path, numClasses, 3, 8, 36, 3)
}
// ResNet150NoFinalLayer creates a ResNet-152 model without the final fully connected layer.
func ResNet150NoFinalLayer(path *nn.Path) ts.ModuleT {
return bottleneckResnet(path, 0, 3, 8, 36, 3)
}
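
A hedged sketch exercising the reworked constructors above. Note that ResNet18/ResNet34 return nn.FuncT while the bottleneck variants return ts.ModuleT, yet both are driven through ForwardT. The import paths, var store setup, and input shape are assumptions for illustration.

package example

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
	"github.com/sugarme/gotch/vision"
)

func forwardResNets() {
	vs := nn.NewVarStore(gotch.CPU)
	r18 := vision.ResNet18(vs.Root().Sub("resnet18"), 10) // basic blocks
	r50 := vision.ResNet50(vs.Root().Sub("resnet50"), 10) // bottleneck blocks

	x := ts.MustZeros([]int64{1, 3, 224, 224}, gotch.Float, gotch.CPU)
	out18 := r18.ForwardT(x, false) // [1, 10]
	out50 := r50.ForwardT(x, false) // [1, 10]

	out18.MustDrop()
	out50.MustDrop()
	x.MustDrop()
}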