Improved classification performance
@@ -16,7 +16,6 @@ import (
 )
 
 func loadBaseImage(c *Context, id string) {
     // TODO handle more types than png
     infile, err := os.Open(path.Join("savedData", id, "baseimage.png"))
     if err != nil {
         c.Logger.Errorf("Failed to read image for model with id %s\n", id)
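The TODO in the hunk above notes that only png is handled when loading the base image. For reference, Go's standard image package can detect the format at decode time once the relevant decoders are registered, so a single code path can cover several formats. A standalone sketch, not code from this repository (the path is illustrative):

package main

import (
    "fmt"
    "image"
    _ "image/jpeg" // register the JPEG decoder
    _ "image/png"  // register the PNG decoder
    "os"
)

// loadImage is a standalone sketch, not code from the repository: image.Decode
// sniffs the format from the stream, so any registered format is accepted.
func loadImage(path string) (image.Image, string, error) {
    f, err := os.Open(path)
    if err != nil {
        return nil, "", err
    }
    defer f.Close()
    return image.Decode(f) // returns the image and the detected format name
}

func main() {
    img, format, err := loadImage("savedData/example/baseimage.png") // illustrative path
    if err != nil {
        fmt.Println("failed to load:", err)
        return
    }
    fmt.Println("decoded", format, "image with bounds", img.Bounds())
}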
@@ -4,6 +4,7 @@ import (
     "errors"
     "os"
     "path"
+    "runtime/debug"
 
     . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
     . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@@ -37,11 +38,19 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
     return image.Scale(0, 255)
 }
 
-func runModelNormal(model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
+func runModelNormal(model *BaseModel, def_id string, inputImage *tf.Tensor, data *RunnerModelData) (order int, confidence float32, err error) {
     order = 0
     err = nil
 
-    tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
+    var tf_model *tg.Model = nil
+
+    if data.Id != nil && *data.Id == def_id {
+        tf_model = data.Model
+    } else {
+        tf_model = tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
+        data.Model = tf_model
+        data.Id = &def_id
+    }
 
     results := tf_model.Exec([]tf.Output{
         tf_model.Op("StatefulPartitionedCall", 0),
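This hunk is the core of the performance change: runModelNormal now accepts a *RunnerModelData and reuses the previously loaded model when the cached definition id matches, falling back to tg.LoadModel only on a miss. A minimal, self-contained sketch of that caching pattern, with loadedModel and loadModelFromDisk as hypothetical stand-ins for *tg.Model and tg.LoadModel:

package main

import "fmt"

// loadedModel and loadModelFromDisk are hypothetical stand-ins for *tg.Model
// and tg.LoadModel; only the caching logic mirrors the diff.
type loadedModel struct {
    defID string
}

// modelCache mirrors the RunnerModelData idea: remember the last definition id
// that was loaded and the handle produced by that load.
type modelCache struct {
    id    *string
    model *loadedModel
}

func loadModelFromDisk(defID string) *loadedModel {
    fmt.Println("expensive load for", defID)
    return &loadedModel{defID: defID}
}

// getModel returns the cached model when the id matches and otherwise loads
// and replaces the cache entry, like the new branch in runModelNormal.
func (c *modelCache) getModel(defID string) *loadedModel {
    if c.id != nil && *c.id == defID {
        return c.model
    }
    m := loadModelFromDisk(defID)
    c.model = m
    c.id = &defID
    return m
}

func main() {
    cache := &modelCache{}
    cache.getModel("def-a") // loads from disk
    cache.getModel("def-a") // served from the cache, no load
    cache.getModel("def-b") // different definition, loads again
}

Because only the last definition is remembered, the cache helps most when consecutive tasks target the same definition; switching definitions still pays the full load cost.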
@@ -125,10 +134,15 @@ func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.
     return
 }
 
-func ClassifyTask(base BasePack, task Task) (err error) {
+type RunnerModelData struct {
+    Id    *string
+    Model *tg.Model
+}
+
+func ClassifyTask(base BasePack, task Task, data *RunnerModelData) (err error) {
     defer func() {
         if r := recover(); r != nil {
-            base.GetLogger().Error("Task failed due to", "error", r)
+            base.GetLogger().Error("Task failed due to", "error", r, "stack", string(debug.Stack()))
             task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Task failed running")
         }
     }()
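The new RunnerModelData struct carries the cached definition id and model between calls, and the recover handler now also logs a stack trace via runtime/debug.Stack, which is what the added "runtime/debug" import is for. A standalone illustration of the recover-plus-stack pattern (names here are illustrative, not from the repo):

package main

import (
    "log"
    "runtime/debug"
)

// run panics; the deferred handler recovers and logs both the panic value and
// the captured stack, mirroring the change to ClassifyTask's recover block.
func run() {
    defer func() {
        if r := recover(); r != nil {
            log.Printf("task failed due to %v\nstack:\n%s", r, debug.Stack())
        }
    }()
    panic("boom")
}

func main() {
    run()
}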
@@ -186,6 +200,8 @@ func ClassifyTask(base BasePack, task Task) (err error) {
 
     if model.ModelType == 2 {
         base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
+        data.Model = nil
+        data.Id = nil
         vi, confidence, err = runModelExp(base, model, def_id, inputImage)
         if err != nil {
             task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
@@ -193,7 +209,7 @@ func ClassifyTask(base BasePack, task Task) (err error) {
         }
     } else {
         base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
-        vi, confidence, err = runModelNormal(model, def_id, inputImage)
+        vi, confidence, err = runModelNormal(model, def_id, inputImage, data)
         if err != nil {
             task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
             return
@@ -1191,7 +1191,7 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
     }
     order++
 
-    loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
+    loop := max(1, int(math.Ceil((math.Log(float64(model.Width))/math.Log(float64(10)))))+1)
     for i := 0; i < loop; i++ {
         _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
         order++
@@ -1299,7 +1299,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
     order++
 
     // Create the blocks
-    loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
+    loop := int(math.Ceil((math.Log(float64(model.Width)) / math.Log(float64(10))))) + 1
 
     /*if model.Width < 50 && model.Height < 50 {
         loop = 0
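Both generateDefinition and generateExpandableDefinition switch the block count from a truncated log10(model.Width) to ceil(log10(model.Width)) + 1, which yields at least one extra simple block for any width. A small sketch comparing the two formulas for a few arbitrary example widths (it uses the max builtin from Go 1.21, as the changed line in generateDefinition does; generateExpandableDefinition uses the same formula without the max guard):

package main

import (
    "fmt"
    "math"
)

// Compare the old and new block-count formulas from the two hunks above for a
// few example widths. The widths are arbitrary, not values from the repository.
func main() {
    for _, width := range []int{28, 100, 500} {
        log10w := math.Log(float64(width)) / math.Log(float64(10))
        oldLoop := max(1, int(log10w))              // previous formula: truncate
        newLoop := max(1, int(math.Ceil(log10w))+1) // new formula: ceil, then add one
        fmt.Printf("width=%3d  old loop=%d  new loop=%d\n", width, oldLoop, newLoop)
    }
}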