fix(vision/efficientnet)
This commit is contained in:
parent
cda08ca450
commit
f961f84389
|
@ -96,7 +96,7 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("InceptionV3 weights loaded.")
|
||||
case "efficient-b4":
|
||||
case "efficientnet-b4":
|
||||
net = vision.EfficientNetB4(vs.Root(), in.ClassCount())
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
|
|
|
@ -62,9 +62,9 @@ func (p params) roundRepeats(repeats int64) (retVal int64) {
|
|||
func (p params) roundFilters(filters int64) (retVal int64) {
|
||||
var divisor int64 = 8
|
||||
filF := p.Width * float64(filters)
|
||||
filI := int64(filF + float64(divisor))
|
||||
filI := int64(filF + float64(divisor)/2.0)
|
||||
|
||||
newFilters := int64(math.Max(float64(divisor), float64(filI/(divisor*divisor))))
|
||||
newFilters := int64(math.Max(float64(divisor), float64(filI/divisor*divisor)))
|
||||
|
||||
if float64(newFilters) < (0.9 * filF) {
|
||||
newFilters += int64(divisor)
|
||||
|
@ -82,16 +82,16 @@ func enConv2d(vs nn.Path, i, o, k int64, c nn.Conv2DConfig, train bool) (retVal
|
|||
size := xs.MustSize()
|
||||
ih := size[2]
|
||||
iw := size[3]
|
||||
oh := (ih + s[0] - 1)
|
||||
ow := (iw + s[0] - 1)
|
||||
oh := (ih + s[0] - 1) / s[0]
|
||||
ow := (iw + s[0] - 1) / s[0]
|
||||
|
||||
var padH int64 = 0
|
||||
if (oh-1)*s[0]+k-ih > 0 {
|
||||
padH = (oh-1)*s[0] + k - ih
|
||||
if (((oh - 1) * s[0]) + k - ih) > 0 {
|
||||
padH = ((oh - 1) * s[0]) + k - ih
|
||||
}
|
||||
var padW int64 = 0
|
||||
if (ow-1)*s[0]+k-iw > 0 {
|
||||
padW = (ow-1)*s[0] + k - iw
|
||||
if (((ow - 1) * s[0]) + k - iw) > 0 {
|
||||
padW = ((ow - 1) * s[0]) + k - iw
|
||||
}
|
||||
|
||||
if padW > 0 || padH > 0 {
|
||||
|
@ -181,14 +181,17 @@ func block(p nn.Path, args BlockArgs) (retVal ts.ModuleT) {
|
|||
var se nn.SequentialT // se will be nil if args.SeRatio == 0
|
||||
if args.SeRatio > 0 {
|
||||
var nsc int64 = 1
|
||||
if float64(inp)*args.SeRatio > 1 {
|
||||
nsc = inp * int64(args.SeRatio)
|
||||
if (float64(inp) * args.SeRatio) > 1 {
|
||||
nsc = int64(float64(inp) * args.SeRatio)
|
||||
}
|
||||
|
||||
se = nn.SeqT()
|
||||
se.Add(enConv2d(p.Sub("_se_reduce"), oup, nsc, 1, nn.DefaultConv2DConfig(), false))
|
||||
|
||||
se.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||
return xs.Swish()
|
||||
}))
|
||||
|
||||
se.Add(enConv2d(p.Sub("_se_expand"), nsc, oup, 1, nn.DefaultConv2DConfig(), false))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user