Merge pull request #82 from StrongerXi/fix-resnet-basic-block-count

This commit is contained in:
sugarme 2022-11-25 13:29:47 +11:00 committed by GitHub
commit 17f2c49e34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -104,10 +104,10 @@ func (bb *basicBlock) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
func resnet(p *nn.Path, nclasses int64, c1, c2, c3, c4 int64) nn.FuncT {
seq := nn.SeqT()
layer0 := layerZero(p)
layer1 := basicLayer(p.Sub("layer1"), 64, 64, 1, 3)
layer2 := basicLayer(p.Sub("layer2"), 64, 128, 2, 4)
layer3 := basicLayer(p.Sub("layer3"), 128, 256, 2, 6)
layer4 := basicLayer(p.Sub("layer4"), 256, 512, 2, 3)
layer1 := basicLayer(p.Sub("layer1"), 64, 64, 1, c1)
layer2 := basicLayer(p.Sub("layer2"), 64, 128, 2, c2)
layer3 := basicLayer(p.Sub("layer3"), 128, 256, 2, c3)
layer4 := basicLayer(p.Sub("layer4"), 256, 512, 2, c4)
seq.Add(layer0)
seq.Add(layer1)
seq.Add(layer2)