Merge pull request #82 from StrongerXi/fix-resnet-basic-block-count
This commit is contained in:
commit
17f2c49e34
|
@ -104,10 +104,10 @@ func (bb *basicBlock) ForwardT(x *ts.Tensor, train bool) *ts.Tensor {
|
|||
func resnet(p *nn.Path, nclasses int64, c1, c2, c3, c4 int64) nn.FuncT {
|
||||
seq := nn.SeqT()
|
||||
layer0 := layerZero(p)
|
||||
layer1 := basicLayer(p.Sub("layer1"), 64, 64, 1, 3)
|
||||
layer2 := basicLayer(p.Sub("layer2"), 64, 128, 2, 4)
|
||||
layer3 := basicLayer(p.Sub("layer3"), 128, 256, 2, 6)
|
||||
layer4 := basicLayer(p.Sub("layer4"), 256, 512, 2, 3)
|
||||
layer1 := basicLayer(p.Sub("layer1"), 64, 64, 1, c1)
|
||||
layer2 := basicLayer(p.Sub("layer2"), 64, 128, 2, c2)
|
||||
layer3 := basicLayer(p.Sub("layer3"), 128, 256, 2, c3)
|
||||
layer4 := basicLayer(p.Sub("layer4"), 256, 512, 2, c4)
|
||||
seq.Add(layer0)
|
||||
seq.Add(layer1)
|
||||
seq.Add(layer2)
|
||||
|
|
Loading…
Reference in New Issue
Block a user