fyp/logic/models/run.go

233 lines
6.7 KiB
Go
Raw Normal View History

2023-09-27 21:20:39 +01:00
package models
import (
2024-04-12 20:36:23 +01:00
"errors"
2023-09-27 21:20:39 +01:00
"os"
"path"
2024-04-14 14:51:16 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
2024-04-12 20:36:23 +01:00
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
2023-09-27 21:20:39 +01:00
"github.com/charmbracelet/log"
2023-09-27 21:20:39 +01:00
tf "github.com/galeone/tensorflow/tensorflow/go"
2023-09-29 13:27:43 +01:00
"github.com/galeone/tensorflow/tensorflow/go/op"
2023-09-27 21:20:39 +01:00
tg "github.com/galeone/tfgo"
"github.com/galeone/tfgo/image"
)
2023-09-29 13:27:43 +01:00
func ReadPNG(scope *op.Scope, imagePath string, channels int64) *image.Image {
scope = tg.NewScope(scope)
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
output := op.DecodePng(scope.SubScope("DecodePng"), contents, op.DecodePngChannels(channels))
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
output = op.ExpandDims(scope.SubScope("Stack"), output, op.Const(scope.SubScope("axis"), []int32{1}))
2023-09-29 13:27:43 +01:00
image := &image.Image{
Tensor: tg.NewTensor(scope, output)}
return image.Scale(0, 255)
}
2023-10-12 09:38:00 +01:00
func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
scope = tg.NewScope(scope)
contents := op.ReadFile(scope.SubScope("ReadFile"), op.Const(scope.SubScope("filename"), imagePath))
output := op.DecodePng(scope.SubScope("DecodeJpeg"), contents, op.DecodePngChannels(channels))
output = op.ExpandDims(scope.SubScope("ExpandDims"), output, op.Const(scope.SubScope("axis"), []int32{0}))
output = op.ExpandDims(scope.SubScope("Stack"), output, op.Const(scope.SubScope("axis"), []int32{1}))
2023-10-12 09:38:00 +01:00
image := &image.Image{
Tensor: tg.NewTensor(scope, output)}
return image.Scale(0, 255)
}
2024-04-12 20:36:23 +01:00
func runModelNormal(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
order = 0
2024-03-02 12:45:49 +00:00
err = nil
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
results := tf_model.Exec([]tf.Output{
tf_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
tf_model.Op("serving_default_rescaling_input", 0): inputImage,
})
var vmax float32 = 0.0
var predictions = results[0].Value().([][]float32)[0]
log.Info("preds", "preds", predictions)
for i, v := range predictions {
if v > vmax {
order = i
vmax = v
}
}
confidence = vmax
return
}
2024-04-12 20:36:23 +01:00
// runModelExp runs a classifier definition made of a shared base network plus
// one or more head networks, and returns the global class order with the
// highest score and that score.
//
// The steps are:
//  1. Look up the heads for def_id in exp_model_head.
//  2. Load the base SavedModel from savedData/<model.Id>/defs/<def_id>/base/model
//     and feed it inputImage on "serving_default_input_1".
//  3. Feed the base's "StatefulPartitionedCall" output to each head SavedModel
//     (savedData/<model.Id>/defs/<def_id>/head/<head Id>/model) on
//     "serving_default_head_input".
//  4. Offset each head's local class index by that head's Range_start. The
//     highest score across all heads wins; ties go to the earliest one seen.
//
// An error is returned when the head lookup fails, when def_id has no heads,
// when a head output is not a [][]float32, or when no head produced a score.
func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) {
	log := base.GetLogger()

	type head struct {
		Id          string
		Range_start int
	}

	// Fetch the heads first so a misconfigured definition fails before the
	// base network is loaded and run.
	heads, err := GetDbMultitple[head](base.GetDb(), "exp_model_head where def_id=$1;", def_id)
	if err != nil {
		return 0, 0, err
	}
	if len(heads) == 0 {
		return 0, 0, errors.New("model definition has no heads")
	}

	log.Info("Running base")
	base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
	base_results := base_model.Exec([]tf.Output{
		base_model.Op("StatefulPartitionedCall", 0),
	}, map[tf.Output]*tf.Tensor{
		base_model.Op("serving_default_input_1", 0): inputImage,
	})

	log.Info("Running heads", "heads", heads)

	found := false
	for _, element := range heads {
		head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)

		results := head_model.Exec([]tf.Output{
			head_model.Op("StatefulPartitionedCall", 0),
		}, map[tf.Output]*tf.Tensor{
			head_model.Op("serving_default_head_input", 0): base_results[0],
		})

		rows, ok := results[0].Value().([][]float32)
		if !ok || len(rows) == 0 {
			return 0, 0, errors.New("head model returned no class predictions")
		}

		for i, v := range rows[0] {
			log.Debug("predictions", "class", i, "preds", v)
			// The first score seen always becomes the candidate, so negative
			// scores cannot leave order pointing at a class nobody predicted.
			if !found || v > confidence {
				found = true
				order = element.Range_start + i
				confidence = v
			}
		}
	}

	if !found {
		return 0, 0, errors.New("model heads returned no class predictions")
	}

	log.Debug("Got", "heads", len(heads), "order", order, "vmax", confidence)
	return order, confidence, nil
}
2024-04-12 20:36:23 +01:00
// ClassifyTask runs the classification model attached to task on the image
// stored for that task, and saves the predicted class as the task result.
//
// The steps are:
//  1. Mark the task TASK_RUNNING.
//  2. Load the model and check that it can be evaluated.
//  3. Look up the model's definition id.
//  4. Decode savedData/<model.Id>/tasks/<task.Id>.<model.Format> ("png" or
//     "jpeg") into an input tensor.
//  5. Run runModelExp when ModelType == 2, otherwise runModelNormal.
//  6. Map the winning class order to its model_classes row and store
//     {class_id, class, confidence} with task.SetResult.
//  7. Mark the task TASK_DONE.
//
// On any failure the task is marked TASK_FAILED_RUNNING and a non-nil error
// is returned. Panics raised while running (e.g. from TensorFlow) are
// recovered, the task is marked failed, and an error is returned instead of nil.
func ClassifyTask(base BasePack, task Task) (err error) {
	defer func() {
		if r := recover(); r != nil {
			base.GetLogger().Error("Task failed due to", "error", r)
			task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Task failed running")
			// Surface the panic to the caller. Without this, err stays nil and a
			// failed task looks successful to whoever invoked ClassifyTask.
			if perr, ok := r.(error); ok {
				err = perr
			} else {
				err = errors.New("classify task panicked")
			}
		}
	}()

	task.UpdateStatusLog(base, TASK_RUNNING, "Runner running task")

	if task.ModelId == nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model")
		return errors.New("task has no model id")
	}

	model, err := GetBaseModel(base.GetDb(), *task.ModelId)
	if err != nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model")
		return err
	}

	if !model.CanEval() {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Model not in the right state for evaluation")
		return errors.New("model not in the right state for evaluation")
	}

	def := JustId{}
	err = GetDBOnce(base.GetDb(), &def, "model_definition where model_id=$1", model.Id)
	if err != nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model")
		return err
	}
	def_id := def.Id

	// TODO create a database table with tasks
	run_path := path.Join("/tmp", model.Id, "runs")
	if mkErr := os.MkdirAll(run_path, os.ModePerm); mkErr != nil {
		// run_path is not used further in this function, so the failure is
		// reported but does not abort the classification.
		base.GetLogger().Error("Failed to create run directory", "path", run_path, "error", mkErr)
	}

	img_path := path.Join("savedData", model.Id, "tasks", task.Id+"."+model.Format)

	root := tg.NewRoot()

	var tf_img *image.Image
	switch model.Format {
	case "png":
		tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
	case "jpeg":
		tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
	default:
		// Must stop here: continuing would dereference the nil tf_img below.
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Unsupported image format")
		return errors.New("unsupported image format: " + model.Format)
	}

	exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
	inputImage, err := tf.NewTensor(exec_results[0].Value())
	if err != nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
		return err
	}

	// ModelType 2 selects the base+heads runner; every other type uses the
	// single-network runner. Both share the same signature.
	runModel := runModelNormal
	model_kind := "normal"
	if model.ModelType == 2 {
		runModel = runModelExp
		model_kind = "exp"
	}
	base.GetLogger().Info("Running model "+model_kind, "model", model.Id, "def", def_id)

	vi, confidence, err := runModel(base, model, def_id, inputImage)
	if err != nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model")
		return err
	}

	var class_info struct {
		Name string
		Id   string
	}
	err = GetDBOnce(base.GetDb(), &class_info, "model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
	if err != nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain model results")
		return err
	}

	returnValue := struct {
		ClassId    string  `json:"class_id"`
		Class      string  `json:"class"`
		Confidence float32 `json:"confidence"`
	}{
		Class:      class_info.Name,
		ClassId:    class_info.Id,
		Confidence: confidence,
	}

	err = task.SetResult(base, returnValue)
	if err != nil {
		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to save model results")
		return err
	}

	task.UpdateStatusLog(base, TASK_DONE, "Model ran successfully")
	return nil
}