fix(vision/resnet): fixed mem blow out

This commit is contained in:
sugarme 2020-07-02 17:40:29 +10:00
parent d23a606a64
commit 9ad62a1a26

View File

@ -46,9 +46,21 @@ func basicBlock(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
downsample := downSample(path.Sub("downsample"), cIn, cOut, stride)
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
ys := xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).Apply(conv2).ApplyT(bn2, train)
downsampleLayer := xs.ApplyT(downsample, train).MustAdd(ys, true)
res := downsampleLayer.MustRelu(true)
// ys := xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).Apply(conv2).ApplyT(bn2, train)
// downsampleLayer := xs.ApplyT(downsample, train).MustAdd(ys, true)
// res := downsampleLayer.MustRelu(true)
c1 := xs.Apply(conv1)
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
c2 := relu.Apply(conv2)
relu.MustDrop()
bn2 := c2.ApplyT(bn2, train)
c2.MustDrop()
dsl := xs.ApplyT(downsample, train)
dslAdd := dsl.MustAdd(bn2, true)
res := dslAdd.MustRelu(true)
return res
})
@ -75,7 +87,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
if nclasses > 0 {
// With final layer
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
c1 := xs.Apply(conv1)
bn1 := c1.ApplyT(bn1, train)
@ -107,6 +119,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
// No final layer
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
c1 := xs.Apply(conv1)
xs.MustDrop()
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)