Improved classification performance

This commit is contained in:
2024-05-15 05:32:49 +01:00
parent 516d1d7634
commit 652542d261
18 changed files with 211 additions and 98 deletions

View File

@@ -4,6 +4,7 @@ import (
"errors"
"os"
"path"
"runtime/debug"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@@ -37,11 +38,19 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
return image.Scale(0, 255)
}
func runModelNormal(model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
func runModelNormal(model *BaseModel, def_id string, inputImage *tf.Tensor, data *RunnerModelData) (order int, confidence float32, err error) {
order = 0
err = nil
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
var tf_model *tg.Model = nil
if data.Id != nil && *data.Id == def_id {
tf_model = data.Model
} else {
tf_model = tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
data.Model = tf_model
data.Id = &def_id
}
results := tf_model.Exec([]tf.Output{
tf_model.Op("StatefulPartitionedCall", 0),
@@ -125,10 +134,15 @@ func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.
return
}
func ClassifyTask(base BasePack, task Task) (err error) {
type RunnerModelData struct {
Id *string
Model *tg.Model
}
func ClassifyTask(base BasePack, task Task, data *RunnerModelData) (err error) {
defer func() {
if r := recover(); r != nil {
base.GetLogger().Error("Task failed due to", "error", r)
base.GetLogger().Error("Task failed due to", "error", r, "stack", string(debug.Stack()))
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Task failed running")
}
}()
@@ -186,6 +200,8 @@ func ClassifyTask(base BasePack, task Task) (err error) {
if model.ModelType == 2 {
base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
data.Model = nil
data.Id = nil
vi, confidence, err = runModelExp(base, model, def_id, inputImage)
if err != nil {
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
@@ -193,7 +209,7 @@ func ClassifyTask(base BasePack, task Task) (err error) {
}
} else {
base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id)
vi, confidence, err = runModelNormal(model, def_id, inputImage)
vi, confidence, err = runModelNormal(model, def_id, inputImage, data)
if err != nil {
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
return