fix(vision/efficientnet)

This commit is contained in:
sugarme 2020-07-12 11:48:26 +10:00
parent cda08ca450
commit f961f84389
2 changed files with 14 additions and 11 deletions

View File

@ -96,7 +96,7 @@ func main() {
log.Fatal(err)
}
fmt.Println("InceptionV3 weights loaded.")
case "efficient-b4":
case "efficientnet-b4":
net = vision.EfficientNetB4(vs.Root(), in.ClassCount())
err = vs.Load(modelPath)
if err != nil {

View File

@ -62,9 +62,9 @@ func (p params) roundRepeats(repeats int64) (retVal int64) {
func (p params) roundFilters(filters int64) (retVal int64) {
var divisor int64 = 8
filF := p.Width * float64(filters)
filI := int64(filF + float64(divisor))
filI := int64(filF + float64(divisor)/2.0)
newFilters := int64(math.Max(float64(divisor), float64(filI/(divisor*divisor))))
newFilters := int64(math.Max(float64(divisor), float64(filI/divisor*divisor)))
if float64(newFilters) < (0.9 * filF) {
newFilters += int64(divisor)
@ -82,16 +82,16 @@ func enConv2d(vs nn.Path, i, o, k int64, c nn.Conv2DConfig, train bool) (retVal
size := xs.MustSize()
ih := size[2]
iw := size[3]
oh := (ih + s[0] - 1)
ow := (iw + s[0] - 1)
oh := (ih + s[0] - 1) / s[0]
ow := (iw + s[0] - 1) / s[0]
var padH int64 = 0
if (oh-1)*s[0]+k-ih > 0 {
padH = (oh-1)*s[0] + k - ih
if (((oh - 1) * s[0]) + k - ih) > 0 {
padH = ((oh - 1) * s[0]) + k - ih
}
var padW int64 = 0
if (ow-1)*s[0]+k-iw > 0 {
padW = (ow-1)*s[0] + k - iw
if (((ow - 1) * s[0]) + k - iw) > 0 {
padW = ((ow - 1) * s[0]) + k - iw
}
if padW > 0 || padH > 0 {
@ -181,14 +181,17 @@ func block(p nn.Path, args BlockArgs) (retVal ts.ModuleT) {
var se nn.SequentialT // se will be nil if args.SeRatio == 0
if args.SeRatio > 0 {
var nsc int64 = 1
if float64(inp)*args.SeRatio > 1 {
nsc = inp * int64(args.SeRatio)
if (float64(inp) * args.SeRatio) > 1 {
nsc = int64(float64(inp) * args.SeRatio)
}
se = nn.SeqT()
se.Add(enConv2d(p.Sub("_se_reduce"), oup, nsc, 1, nn.DefaultConv2DConfig(), false))
se.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
return xs.Swish()
}))
se.Add(enConv2d(p.Sub("_se_expand"), nsc, oup, 1, nn.DefaultConv2DConfig(), false))
}