fix(vision/resnet): fixed mem blow out
This commit is contained in:
parent
d23a606a64
commit
9ad62a1a26
|
@ -46,9 +46,21 @@ func basicBlock(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
|
|||
downsample := downSample(path.Sub("downsample"), cIn, cOut, stride)
|
||||
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||
ys := xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).Apply(conv2).ApplyT(bn2, train)
|
||||
downsampleLayer := xs.ApplyT(downsample, train).MustAdd(ys, true)
|
||||
res := downsampleLayer.MustRelu(true)
|
||||
// ys := xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).Apply(conv2).ApplyT(bn2, train)
|
||||
// downsampleLayer := xs.ApplyT(downsample, train).MustAdd(ys, true)
|
||||
// res := downsampleLayer.MustRelu(true)
|
||||
c1 := xs.Apply(conv1)
|
||||
bn1 := c1.ApplyT(bn1, train)
|
||||
c1.MustDrop()
|
||||
relu := bn1.MustRelu(true)
|
||||
c2 := relu.Apply(conv2)
|
||||
relu.MustDrop()
|
||||
bn2 := c2.ApplyT(bn2, train)
|
||||
c2.MustDrop()
|
||||
|
||||
dsl := xs.ApplyT(downsample, train)
|
||||
dslAdd := dsl.MustAdd(bn2, true)
|
||||
res := dslAdd.MustRelu(true)
|
||||
|
||||
return res
|
||||
})
|
||||
|
@ -75,7 +87,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
|||
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
|
||||
|
||||
if nclasses > 0 {
|
||||
|
||||
// With final layer
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
c1 := xs.Apply(conv1)
|
||||
bn1 := c1.ApplyT(bn1, train)
|
||||
|
@ -107,6 +119,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
|||
// No final layer
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
c1 := xs.Apply(conv1)
|
||||
xs.MustDrop()
|
||||
bn1 := c1.ApplyT(bn1, train)
|
||||
c1.MustDrop()
|
||||
relu := bn1.MustRelu(true)
|
||||
|
|
Loading…
Reference in New Issue
Block a user