diff --git a/config.toml b/config.toml index f6b16db..8490224 100644 --- a/config.toml +++ b/config.toml @@ -3,3 +3,9 @@ PORT=5002 HOSTNAME="https://testing.andr3h3nriqu3s.com" NUMBER_OF_WORKERS=20 + +SUPRESS_CUDA=1 + +[Worker] +PULLING_TIME="500ms" +NUMBER_OF_WORKERS=1 diff --git a/logic/models/data.go b/logic/models/data.go index 1d03c3d..a55c305 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -89,7 +89,7 @@ func fileProcessor( defer f.Close() f.Write(file_data) - if !testImgForModel(c, model, file_path) { + if !TestImgForModel(c, model, file_path) { c.Logger.Errorf("Image did not have valid format for model %s (in zip: %s)!", file_path, file.Name) c.Logger.Warn("Not failling updating data point to status -1") message := "Image did not have valid format for the model" diff --git a/logic/models/index.go b/logic/models/index.go index ef3e194..fa0258b 100644 --- a/logic/models/index.go +++ b/logic/models/index.go @@ -18,7 +18,6 @@ func HandleModels (handle *Handle) { model_classes.HandleList(handle) // Train endpoints - handleRun(handle) models_train.HandleTrainEndpoints(handle) } diff --git a/logic/models/run.go b/logic/models/run.go index d7969d6..7a064c9 100644 --- a/logic/models/run.go +++ b/logic/models/run.go @@ -1,12 +1,12 @@ package models import ( - "bytes" - "io" + "errors" "os" "path" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" tf "github.com/galeone/tensorflow/tensorflow/go" @@ -35,7 +35,7 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image { return image.Scale(0, 255) } -func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) { +func runModelNormal(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) { order = 0 err = nil @@ -62,7 +62,7 @@ func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf. return } -func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) { +func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, confidence float32, err error) { err = nil order = 0 @@ -82,12 +82,12 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten Range_start int } - heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id) + heads, err := GetDbMultitple[head](base.GetDb(), "exp_model_head where def_id=$1;", def_id) if err != nil { return } - c.Logger.Info("test", "count", len(heads)) + base.GetLogger().Info("test", "count", len(heads)) var vmax float32 = 0.0 @@ -102,9 +102,8 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten var predictions = results[0].Value().([][]float32)[0] - for i, v := range predictions { - c.Logger.Info("predictions", "class", i, "preds", v) + base.GetLogger().Debug("predictions", "class", i, "preds", v) if v > vmax { order = element.Range_start + i vmax = v @@ -115,139 +114,105 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten // TODO runthe head model confidence = vmax - c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", vmax) + base.GetLogger().Debug("Got", "heads", len(heads), "order", order, "vmax", vmax) return } -func handleRun(handle *Handle) { - handle.Post("/models/run", func(c *Context) *Error { - if !c.CheckAuthLevel(1) { - return nil - } +func ClassifyTask(base BasePack, task Task) (err error) { + task.UpdateStatusLog(base, TASK_RUNNING, "Runner running task") - read_form, err := c.R.MultipartReader() + model, err := GetBaseModel(base.GetDb(), task.ModelId) + if err != nil { + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model") + return err + } + + if !model.CanEval() { + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model") + return errors.New("Model not in the right state for evaluation") + } + + def := JustId{} + err = GetDBOnce(base.GetDb(), &def, "model_definition where model_id=$1", model.Id) + if err != nil { + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model") + return + } + + def_id := def.Id + + // TODO create a database table with tasks + run_path := path.Join("/tmp", model.Id, "runs") + os.MkdirAll(run_path, os.ModePerm) + + img_path := path.Join("savedData", model.Id, "tasks", task.Id+"."+model.Format) + + root := tg.NewRoot() + + var tf_img *image.Image = nil + + switch model.Format { + case "png": + tf_img = ReadPNG(root, img_path, int64(model.ImageMode)) + case "jpeg": + tf_img = ReadJPG(root, img_path, int64(model.ImageMode)) + default: + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain the model") + } + + exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) + inputImage, err := tf.NewTensor(exec_results[0].Value()) + if err != nil { + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model") + return + } + + vi := -1 + var confidence float32 = 0 + + if model.ModelType == 2 { + base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id) + vi, confidence, err = runModelExp(base, model, def_id, inputImage) if err != nil { - return c.JsonBadRequest("Invalid muilpart body") + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model") + return } - - var id string - var file []byte - - for { - part, err_part := read_form.NextPart() - if err_part == io.EOF { - break - } else if err_part != nil { - return c.JsonBadRequest("Invalid multipart data") - } - if part.FormName() == "id" { - buf := new(bytes.Buffer) - buf.ReadFrom(part) - id = buf.String() - } - if part.FormName() == "file" { - buf := new(bytes.Buffer) - buf.ReadFrom(part) - file = buf.Bytes() - } - } - - model, err := GetBaseModel(handle.Db, id) - if err == ModelNotFoundError { - return c.JsonBadRequest("Models not found") - } else if err != nil { - return c.Error500(err) - } - - if model.Status != READY && model.Status != READY_RETRAIN && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION && model.Status != READY_ALTERATION_FAILED { - return c.JsonBadRequest("Model not ready to run images") - } - - def := JustId{} - err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id) - if err == NotFoundError { - return c.JsonBadRequest("Could not find definition") - } else if err != nil { - return c.Error500(err) - } - - def_id := def.Id - - // TODO create a database table with tasks - run_path := path.Join("/tmp", model.Id, "runs") - os.MkdirAll(run_path, os.ModePerm) - img_path := path.Join(run_path, "img."+model.Format) - - img_file, err := os.Create(img_path) + } else { + base.GetLogger().Info("Running model normal", "model", model.Id, "def", def_id) + vi, confidence, err = runModelNormal(base, model, def_id, inputImage) if err != nil { - return c.Error500(err) + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to run model") + return } - defer img_file.Close() - img_file.Write(file) + } - if !testImgForModel(c, model, img_path) { - return c.JsonBadRequest("Provided image does not match the model") - } + var GetName struct { + Name string + Id string + } + err = GetDBOnce(base.GetDb(), &GetName, "model_classes where model_id=$1 and class_order=$2;", model.Id, vi) + if err != nil { + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to obtain model results") + return + } - root := tg.NewRoot() + returnValue := struct { + ClassId string `json:"class_id"` + Class string `json:"class"` + Confidence float32 `json:"confidence"` + }{ + Class: GetName.Name, + ClassId: GetName.Id, + Confidence: confidence, + } - var tf_img *image.Image = nil + err = task.SetResult(base, returnValue) + if err != nil { + task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to save model results") + return + } - switch model.Format { - case "png": - tf_img = ReadPNG(root, img_path, int64(model.ImageMode)) - case "jpeg": - tf_img = ReadJPG(root, img_path, int64(model.ImageMode)) - default: - panic("Not sure what to do with '" + model.Format + "'") - } + task.UpdateStatusLog(base, TASK_DONE, "Model ran successfully") - exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) - inputImage, err := tf.NewTensor(exec_results[0].Value()) - if err != nil { - return c.Error500(err) - } - - vi := -1 - var confidence float32 = 0 - - if model.ModelType == 2 { - c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) - vi, confidence, err = runModelExp(c, model, def_id, inputImage) - if err != nil { - return c.Error500(err) - } - } else { - c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) - vi, confidence, err = runModelNormal(c, model, def_id, inputImage) - if err != nil { - return c.Error500(err) - } - } - - os.RemoveAll(run_path) - - rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi) - if err != nil { - return c.Error500(err) - } - if !rows.Next() { - return c.SendJSON(nil) - } - - var name string - if err = rows.Scan(&name); err != nil { - return c.Error500(err) - } - - returnValue := struct { - Class string `json:"class"` - Confidence float32 `json:"confidence"` - }{ - Class: name, - Confidence: confidence, - } - - return c.SendJSON(returnValue) - }) + return } diff --git a/logic/models/test.go b/logic/models/test.go index 5b7f915..e1a3891 100644 --- a/logic/models/test.go +++ b/logic/models/test.go @@ -10,7 +10,7 @@ import ( . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) -func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { +func TestImgForModel(c *Context, model *BaseModel, path string) (result bool) { result = false infile, err := os.Open(path) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index e2abaf2..0f1cde1 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -58,11 +58,6 @@ func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionSt return } -func UpdateStatus(c *Context, table string, id string, status int) (err error) { - _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) - return -} - func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) { _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape) return @@ -341,7 +336,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i // This is to load some extra data so that the model has more things to train on // - data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count) + data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count * 10) if err != nil { return } diff --git a/logic/models/utils/types.go b/logic/models/utils/types.go index d770ff6..3db6c18 100644 --- a/logic/models/utils/types.go +++ b/logic/models/utils/types.go @@ -5,18 +5,6 @@ import ( "errors" ) -type BaseModel struct { - Name string - Status int - Id string - - ModelType int - ImageMode int - Width int - Height int - Format string -} - const ( FAILED_TRAINING = -4 FAILED_PREPARING_TRAINING = -3 @@ -75,6 +63,18 @@ const ( MODEL_HEAD_STATUS_READY = 5 ) +type BaseModel struct { + Name string + Status int + Id string + + ModelType int + ImageMode int + Width int + Height int + Format string +} + var ModelNotFoundError = errors.New("Model not found error") func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { @@ -99,6 +99,13 @@ func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { return } +func (m BaseModel) CanEval() bool { + if m.Status != READY && m.Status != READY_RETRAIN && m.Status != READY_RETRAIN_FAILED && m.Status != READY_ALTERATION && m.Status != READY_ALTERATION_FAILED { + return false + } + return true +} + func StringToImageMode(colorMode string) int { switch colorMode { case "greyscale": diff --git a/logic/tasks/handleUpload.go b/logic/tasks/handleUpload.go new file mode 100644 index 0000000..c80352d --- /dev/null +++ b/logic/tasks/handleUpload.go @@ -0,0 +1,123 @@ +package tasks + +import ( + "bytes" + "io" + "net/http" + "os" + "path" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func handleUpload(handler *Handle) { + handler.PostAuth("/tasks/start/image", 1, func(c *Context) *Error { + + read_form, err := c.R.MultipartReader() + if err != nil { + return c.JsonBadRequest("Please provide a valid form data request!") + } + + var json_data string + var file []byte + + for { + part, err_part := read_form.NextPart() + if err_part == io.EOF { + break + } else if err_part != nil { + return c.JsonBadRequest("Please provide a valid form data request!") + } + if part.FormName() == "json_data" { + buf := new(bytes.Buffer) + buf.ReadFrom(part) + json_data = buf.String() + } + if part.FormName() == "file" { + buf := new(bytes.Buffer) + buf.ReadFrom(part) + file = buf.Bytes() + } + } + + var requestData struct { + ModelId string `json:"id" validate:"required"` + } + + _err := c.ParseJson(&requestData, json_data) + if _err != nil { + return _err + } + + model, err := GetBaseModel(c.Db, requestData.ModelId) + if err != nil { + return c.Error500(err) + } + + switch model.Status { + case READY: + case READY_RETRAIN: + case READY_ALTERATION: + case READY_ALTERATION_FAILED: + case READY_RETRAIN_FAILED: + // Model can run + + default: + return c.SendJSONStatus(http.StatusBadRequest, "Model not in the correct status to be able to evaludate a model") + } + + // TODO Check if the user can use this model + + type CreateNewTask struct { + UserId string `db:"user_id"` + ModelId string `db:"model_id"` + TaskType int `db:"task_type"` + Status int `db:"status"` + } + + newTask := CreateNewTask{ + UserId: c.User.Id, + ModelId: model.Id, + // TODO move this to an enum + TaskType: 1, + Status: 0, + } + + id, err := InsertReturnId(c, &newTask, "tasks", "id") + if err != nil { + return c.E500M("Error 500", err) + } + + save_path := path.Join("savedData", model.Id, "tasks") + os.MkdirAll(save_path, os.ModePerm) + + img_path := path.Join(save_path, id+"."+model.Format) + + img_file, err := os.Create(img_path) + if err != nil { + if _err := UpdateTaskStatus(c,id, -1, "Failed to create the file"); _err != nil { + c.Logger.Error("Failed to update tasks") + } + return c.E500M("Failed to create the file", err) + } + defer img_file.Close() + img_file.Write(file) + + if !TestImgForModel(c, model, img_path) { + if _err := UpdateTaskStatus(c, id, -1, "The provided image is not a valid image for this model"); _err != nil { + c.Logger.Error("Failed to update tasks") + } + return c.JsonBadRequest(struct { + Message string `json:"message"` + Id string `json:"task_id"` + } { "Provided image does not match the model", id}) + } + + UpdateStatus(c, "tasks", id, 1) + + return c.SendJSON(struct {Id string `json:"id"`}{id}) + }) +} diff --git a/logic/tasks/index.go b/logic/tasks/index.go new file mode 100644 index 0000000..22ccdc3 --- /dev/null +++ b/logic/tasks/index.go @@ -0,0 +1,11 @@ +package tasks + +import ( + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func HandleTasks (handle *Handle) { + handleUpload(handle) + handleList(handle) +} + diff --git a/logic/tasks/list.go b/logic/tasks/list.go new file mode 100644 index 0000000..4f79f21 --- /dev/null +++ b/logic/tasks/list.go @@ -0,0 +1,61 @@ +package tasks + +import ( + dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" +) + +func handleList(handler *Handle) { + handler.PostAuth("/tasks/list", 1, func(c *Context) *Error { + var err error = nil + + var requestData struct { + ModelId string `json:"model_id"` + Page int `json:"page"` + } + + if _err := c.ToJSON(&requestData); _err != nil { + return _err + } + + if requestData.ModelId == "" && c.User.UserType < int(dbtypes.User_Admin) { + return c.SendJSONStatus(400, "Please provide a model_id") + } + + if requestData.ModelId != "" { + _, err := GetBaseModel(c.Db, requestData.ModelId) + if err == ModelNotFoundError { + return c.SendJSONStatus(404, "Model not found!") + } else if err != nil { + return c.Error500(err) + } + } + + var rows []*Task = nil + + if requestData.ModelId != "" { + rows, err = GetDbMultitple[Task](c, "tasks where model_id=$1 order by created_on desc limit 11 offset $2", requestData.ModelId, requestData.Page * 10) + if err != nil { + return c.Error500(err) + } + } else { + rows, err = GetDbMultitple[Task](c, "tasks order by created_on desc limit 11 offset $1", requestData.Page * 10) + if err != nil { + return c.Error500(err) + } + } + + max_len := min(11, len(rows)) + + c.ShowMessage = false + return c.SendJSON(struct { + TaskList []*Task `json:"task_list"` + ShowNext bool `json:"show_next"` + } { + rows[0:max_len], + len(rows) > 10, + }) + }) +} diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go new file mode 100644 index 0000000..521d940 --- /dev/null +++ b/logic/tasks/runner/runner.go @@ -0,0 +1,160 @@ +package task_runner + +import ( + "database/sql" + "fmt" + "os" + "time" + + "github.com/charmbracelet/log" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" +) + +/** +* Actually runs the code + */ +func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int) { + logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportCaller: true, + ReportTimestamp: true, + TimeFormat: time.Kitchen, + Prefix: fmt.Sprintf("Runner %d", index), + }) + + defer func() { + if r := recover(); r != nil { + logger.Error("Recovered in file processor", "processor id", index, "due to", r) + back_channel <- -index + } + }() + + logger.Info("Started up") + + var err error + + base := BasePackStruct{ + Db: db, + Logger: logger, + } + + for task := range task_channel { + logger.Info("Got task", "task", task) + + if task.TaskType == int(TASK_TYPE_CLASSIFICATION) { + logger.Info("Classification Task") + if err = ClassifyTask(base, task); err != nil { + logger.Error("Classification task failed", "error", "err") + } + + back_channel <- index + continue + } + + + logger.Error("Do not know how to route task", "task", task) + back_channel <- index + } +} + +/** +* Tells the orcchestator to look at the task list from time to time + */ +func attentionSeeker(config Config, back_channel chan int) { + logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportCaller: true, + ReportTimestamp: true, + TimeFormat: time.Kitchen, + Prefix: "Runner Orchestrator Logger [Attention]", + }) + + logger.Info("Started up") + + t, err := time.ParseDuration(config.GpuWorker.Pulling) + if err != nil { + logger.Error("Failed to load", "error", err) + return + } + + for true { + back_channel <- 0 + + time.Sleep(t) + } +} + +/** +* Manages what worker should to Work + */ +func RunnerOrchestrator(db *sql.DB, config Config) { + logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportCaller: true, + ReportTimestamp: true, + TimeFormat: time.Kitchen, + Prefix: "Runner Orchestrator Logger", + }) + + gpu_workers := config.GpuWorker.NumberOfWorkers + + logger.Info("Starting runners") + + task_runners := make([]chan Task, gpu_workers) + task_runners_used := make([]bool, gpu_workers) + // One more to accomudate the Attention Seeker channel + back_channel := make(chan int, gpu_workers+1) + + go attentionSeeker(config, back_channel) + + // Start the runners + for i := 0; i < gpu_workers; i++ { + task_runners[i] = make(chan Task, 10) + task_runners_used[i] = false + go runner(db, task_runners[i], i+1, back_channel) + } + + var task_to_dispatch *Task = nil + + for i := range back_channel { + + if i > 0 { + logger.Info("Runner freed", "runner", i) + task_runners_used[i-1] = false + } else if i < 0 { + logger.Error("Runner died! Restarting!", "runner", i) + task_runners_used[i-1] = false + go runner(db, task_runners[i-1], i, back_channel) + } + + if task_to_dispatch == nil { + var task Task + err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO) + if err != NotFoundError && err != nil{ + log.Error("Failed to get tasks from db") + continue + } + if err == NotFoundError { + task_to_dispatch = nil + } else { + task_to_dispatch = &task + } + } + + if task_to_dispatch != nil { + for i := 0; i < len(task_runners_used); i += 1 { + if !task_runners_used[i] { + task_runners[i] <- *task_to_dispatch + task_runners_used[i] = true + task_to_dispatch = nil + break + } + } + } + + } +} + +func StartRunners(db *sql.DB, config Config) { + go RunnerOrchestrator(db, config) +} diff --git a/logic/tasks/utils/utils.go b/logic/tasks/utils/utils.go new file mode 100644 index 0000000..c9b44c3 --- /dev/null +++ b/logic/tasks/utils/utils.go @@ -0,0 +1,68 @@ +package tasks_utils + +import ( + "time" + + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + "github.com/goccy/go-json" +) + +type Task struct { + Id string `db:"id" json:"id"` + UserId string `db:"user_id" json:"user_id"` + ModelId string `db:"model_id" json:"model_id"` + Status int `db:"status" json:"status"` + StatusMessage string `db:"status_message" json:"status_message"` + UserConfirmed int `db:"user_confirmed" json:"user_confirmed"` + Compacted int `db:"compacted" json:"compacted"` + TaskType int `db:"task_type" json:"type"` + Result string `db:"result" json:"result"` + CreatedOn time.Time `db:"created_on" json:"created"` +} + +type TaskStatus int + +const ( + TASK_FAILED_RUNNING TaskStatus = -2 + TASK_FAILED_CREATION = -1 + TASK_PREPARING = 0 + TASK_TODO = 1 + TASK_PICKED_UP = 2 + TASK_RUNNING = 3 + TASK_DONE = 4 +) + +type TaskType int + +const ( + TASK_TYPE_CLASSIFICATION TaskType = 1 +) + +func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) { + return UpdateTaskStatus(base, t.Id, status, message) +} + +/** +* Call the UpdateStatus function and logs on the case of failure! +* This varient does not return any error message + */ +func (t Task) UpdateStatusLog(base BasePack, status TaskStatus, message string) { + err := t.UpdateStatus(base, status, message) + if err != nil { + base.GetLogger().Error("Failed to update task status", "error", err, "task", t.Id) + } +} + +func UpdateTaskStatus(base BasePack, id string, status TaskStatus, message string) (err error) { + _, err = base.GetDb().Exec("update tasks set status=$1, status_message=$2 where id=$3", status, message, id) + return +} + +func (t Task) SetResult(base BasePack, result any) (err error) { + text, err := json.Marshal(result) + if err != nil { + return + } + _, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", text, t.Id) + return +} diff --git a/logic/utils/config.go b/logic/utils/config.go index aeaa846..c764186 100644 --- a/logic/utils/config.go +++ b/logic/utils/config.go @@ -7,10 +7,18 @@ import ( "github.com/charmbracelet/log" ) +type WorkerConfig struct { + NumberOfWorkers int `toml:"number_of_workers"` + Pulling string `toml:"pulling_time"` +} + type Config struct { Hostname string Port int - NumberOfWorkers int `toml:"number_of_workers"` + NumberOfWorkers int `toml:"number_of_workers"` + SupressCuda int `toml:"supress_cuda"` + + GpuWorker WorkerConfig `toml:"Worker"` } func LoadConfig() Config { @@ -25,10 +33,21 @@ func LoadConfig() Config { Hostname: "localhost", Port: 8000, NumberOfWorkers: 10, + GpuWorker: WorkerConfig{ + NumberOfWorkers: 1, + Pulling: "500ms", + }, } } var conf Config _, err = toml.Decode(string(dat), &conf) + + if conf.SupressCuda == 1 { + log.Warn("Supressing Cuda Messages!") + os.Setenv("TF_CPP_MIN_VLOG_LEVEL", "3") + os.Setenv("TF_CPP_MIN_LOG_LEVEL", "3") + } + return conf } diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 5e49567..1713dd7 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -67,7 +67,7 @@ func handleError(err *Error, c *Context) { e = c.SendJSON(500) } if e != nil { - c.Logger.Error("Something went very wront while trying to send and error message") + c.Logger.Error("Something went very wrong while trying to send and error message") c.Writer.Write([]byte("505")) } } @@ -81,7 +81,6 @@ func (x *Handle) Post(path string, fn func(c *Context) *Error) { x.posts = append(x.posts, HandleFunc{path, fn}) } - func (x *Handle) PostAuth(path string, authLevel int, fn func(c *Context) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(authLevel) { @@ -97,6 +96,13 @@ func (x *Handle) Delete(path string, fn func(c *Context) *Error) { } func (x *Handle) handleGets(context *Context) { + defer func() { + if r := recover(); r != nil { + context.Logger.Error("Something went very wrong", "Error", r) + handleError(&Error{500, "500"}, context) + } + }() + for _, s := range x.gets { if s.path == context.R.URL.Path { handleError(s.fn(context), context) @@ -108,6 +114,13 @@ func (x *Handle) handleGets(context *Context) { } func (x *Handle) handlePosts(context *Context) { + defer func() { + if r := recover(); r != nil { + context.Logger.Error("Something went very wrong", "Error", r) + handleError(&Error{500, "500"}, context) + } + }() + for _, s := range x.posts { if s.path == context.R.URL.Path { handleError(s.fn(context), context) @@ -119,6 +132,13 @@ func (x *Handle) handlePosts(context *Context) { } func (x *Handle) handleDeletes(context *Context) { + defer func() { + if r := recover(); r != nil { + context.Logger.Error("Something went very wrong", "Error", r) + handleError(&Error{500, "500"}, context) + } + }() + for _, s := range x.deletes { if s.path == context.R.URL.Path { handleError(s.fn(context), context) @@ -155,6 +175,20 @@ type Context struct { Handle *Handle } + +func (c Context) GetDb() (*sql.DB) { + return c.Db +} + +func (c Context) GetLogger() (*log.Logger) { + return c.Logger +} + +func (c Context) Query(query string, args ...any) (*sql.Rows, error) { + return c.Db.Query(query, args...) +} + + func (c Context) Prepare(str string) (*sql.Stmt, error) { if c.Tx == nil { return c.Db.Prepare(str) @@ -199,19 +233,32 @@ func (c *Context) RollbackTx() error { return nil } +/** +* Parse and vailidates the json +*/ +func (c Context) ParseJson(dat any, str string) *Error { + decoder := json.NewDecoder(strings.NewReader(str)) + + return c.decodeAndValidade(decoder, dat) +} + func (c Context) ToJSON(dat any) *Error { - decoder := json.NewDecoder(c.R.Body) + + return c.decodeAndValidade(decoder, dat) +} +func (c Context) decodeAndValidade(decoder *json.Decoder, dat any) *Error { err := decoder.Decode(dat) if err != nil { - return c.Error500(err) + c.Logger.Error("Failed to decode json", "dat", dat, "err", err) + return c.JsonBadRequest("Bad Request! Invalid json passed!"); } err = c.Handle.validate.Struct(dat) if err != nil { c.Logger.Error("Failed invalid json passed", "dat", dat, "err", err) - return c.JsonBadRequest("Bad Request! Invalid body passed!") + return c.JsonBadRequest("Bad Request! Invalid json passed!"); } return nil @@ -246,7 +293,7 @@ func (c Context) JsonBadRequest(dat any) *Error { c.SetReportCaller(true) c.Logger.Warn("Request failed with a bad request", "dat", dat) c.SetReportCaller(false) - return c.SendJSONStatus(http.StatusBadRequest, dat) + return c.ErrorCode(nil, 404, dat) } func (c Context) JsonErrorBadRequest(err error, dat any) *Error { @@ -308,6 +355,10 @@ func (c Context) Error500(err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, nil) } +func (c Context) E500M(msg string, err error) *Error { + return c.ErrorCode(err, http.StatusInternalServerError, msg) +} + func (c *Context) requireAuth() bool { if c.User == nil { return true diff --git a/logic/utils/utils.go b/logic/utils/utils.go index 136c295..83ba440 100644 --- a/logic/utils/utils.go +++ b/logic/utils/utils.go @@ -12,9 +12,29 @@ import ( "strconv" "strings" + "github.com/charmbracelet/log" "github.com/google/uuid" ) + +type BasePack interface { + GetDb() *sql.DB + GetLogger() *log.Logger +} + +type BasePackStruct struct { + Db *sql.DB + Logger *log.Logger +} + +func (b BasePackStruct) GetDb() (*sql.DB) { + return b.Db +} + +func (b BasePackStruct) GetLogger() (*log.Logger) { + return b.Logger +} + func CheckEmpty(f url.Values, path string) bool { return !f.Has(path) || f.Get(path) == "" } @@ -184,7 +204,7 @@ func MyParseForm(r *http.Request) (vs url.Values, err error) { return } -type JustId struct { Id string } +type JustId struct{ Id string } type Generic struct{ reflect.Type } @@ -196,34 +216,39 @@ func generateQuery(t reflect.Type) (query string, nargs int) { query = "" for i := 0; i < nargs; i += 1 { - field := t.Field(i) - name, ok := field.Tag.Lookup("db") - if !ok { - name = field.Name; - } + field := t.Field(i) + name, ok := field.Tag.Lookup("db") + if !ok { + name = field.Name + } - if name == "__nil__" { - continue - } - query += strings.ToLower(name) + "," + if name == "__nil__" { + continue + } + query += strings.ToLower(name) + "," } - - // Remove the last comma + + // Remove the last comma query = query[0 : len(query)-1] - - return + + return } -func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) { - t := reflect.TypeFor[T]() +type QueryInterface interface { + Prepare(str string) (*sql.Stmt, error) + Query(query string, args ...any) (*sql.Rows, error) +} - query, nargs := generateQuery(t) +func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) { + t := reflect.TypeFor[T]() - db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename)) - if err != nil { - return nil, err - } - defer db_query.Close() + query, nargs := generateQuery(t) + + db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename)) + if err != nil { + return nil, err + } + defer db_query.Close() rows, err := db_query.Query(args...) if err != nil { @@ -231,55 +256,55 @@ func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([ } defer rows.Close() - list := []*T{} + list := []*T{} - for rows.Next() { - item := new(T) - if err = mapRow(item, rows, nargs); err != nil { - return nil, err - } - list = append(list, item) - } + for rows.Next() { + item := new(T) + if err = mapRow(item, rows, nargs); err != nil { + return nil, err + } + list = append(list, item) + } - return list, nil + return list, nil } func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { - err = nil + err = nil - val := reflect.Indirect(reflect.ValueOf(store)) - scan_args := make([]interface{}, nargs); - for i := 0; i < nargs; i++ { - valueField := val.Field(i) - scan_args[i] = valueField.Addr().Interface() - } + val := reflect.Indirect(reflect.ValueOf(store)) + scan_args := make([]interface{}, nargs) + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Addr().Interface() + } err = rows.Scan(scan_args...) if err != nil { return } - return nil + return nil } func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) { - t := reflect.TypeOf(store).Elem() + t := reflect.TypeOf(store).Elem() - query, nargs := generateQuery(t) + query, nargs := generateQuery(t) - query2 := "" - for i := 0; i < nargs; i += 1 { - query2 += fmt.Sprintf("$%d,", i) - } - // Remove last quotation + query2 := "" + for i := 0; i < nargs; i += 1 { + query2 += fmt.Sprintf("$%d,", i+1) + } + // Remove last quotation query2 = query2[0 : len(query2)-1] - val := reflect.ValueOf(store).Elem() - scan_args := make([]interface{}, nargs); - for i := 0; i < nargs; i++ { - valueField := val.Field(i) - scan_args[i] = valueField.Interface() - } + val := reflect.ValueOf(store).Elem() + scan_args := make([]interface{}, nargs) + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Interface() + } rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...) if err != nil { @@ -296,15 +321,15 @@ func InsertReturnId(c *Context, store interface{}, tablename string, returnName return } - return + return } -func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error { - t := reflect.TypeOf(store).Elem() +func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error { + t := reflect.TypeOf(store).Elem() - query, nargs := generateQuery(t) + query, nargs := generateQuery(t) - rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) + rows, err := db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) if err != nil { return err } @@ -314,20 +339,24 @@ func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) err return NotFoundError } - err = nil + err = nil - val := reflect.ValueOf(store).Elem() - scan_args := make([]interface{}, nargs); - for i := 0; i < nargs; i++ { - valueField := val.Field(i) - scan_args[i] = valueField.Addr().Interface() - } + val := reflect.ValueOf(store).Elem() + scan_args := make([]interface{}, nargs) + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Addr().Interface() + } err = rows.Scan(scan_args...) if err != nil { return err } - return nil + return nil } +func UpdateStatus(c *Context, table string, id string, status int) (err error) { + _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) + return +} diff --git a/main.go b/main.go index 0a77edd..b31e8bf 100644 --- a/main.go +++ b/main.go @@ -8,8 +8,10 @@ import ( _ "github.com/lib/pq" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks" models_utils "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner" ) const ( @@ -36,6 +38,8 @@ func main() { config := LoadConfig() log.Info("Config loaded!", "config", config) + StartRunners(db, config) + //TODO check if file structure exists to save data handle := NewHandler(db, config) @@ -55,6 +59,7 @@ func main() { usersEndpints(db, handle) HandleModels(handle) + HandleTasks(handle) handle.Startup() } diff --git a/sql/tasks.sql b/sql/tasks.sql new file mode 100644 index 0000000..536cb50 --- /dev/null +++ b/sql/tasks.sql @@ -0,0 +1,33 @@ +-- drop table if exists tasks +create table if not exists tasks ( + id uuid primary key default gen_random_uuid(), + user_id uuid references users (id) not null, + model_id uuid references models (id) on delete cascade default null, + + -- -2: Failed Running + -- -1: Failed Creation + -- 0: Preparing + -- 1: TODO + -- 2: Picked up + -- 3: Running + -- 4: Failed + status integer default 1, + status_message text default '', + + result text default '', + + -- -1: user said task is wrong + -- 0: no user input + -- 1: user said task is ok + user_confirmed integer default 0, + + -- Tells the user if the file has been already compacted into + -- embendings + compacted integer default 0, + + -- TODO move the training tasks to here + -- 1: Classification + task_type integer, + + created_on timestamp default current_timestamp +) diff --git a/webpage/src/routes/models/edit/+page.svelte b/webpage/src/routes/models/edit/+page.svelte index 52b1ee4..c9f0e6f 100644 --- a/webpage/src/routes/models/edit/+page.svelte +++ b/webpage/src/routes/models/edit/+page.svelte @@ -35,12 +35,14 @@ import ModelData from './ModelData.svelte'; import DeleteZip from './DeleteZip.svelte'; + import RunModel from './RunModel.svelte'; + import Tabs from 'src/lib/Tabs.svelte'; + import TasksDataPage from './TasksDataPage.svelte'; import ModelDataPage from './ModelDataPage.svelte'; import 'src/styles/forms.css'; - import RunModel from './RunModel.svelte'; - import Tabs from 'src/lib/Tabs.svelte'; + let model: Promise = $state(new Promise(() => {})); let _model: Model | undefined = $state(undefined); let definitions: Promise = $state(new Promise(() => {})); @@ -148,9 +150,19 @@ Model Data {/if} + {#if _model && [5, 6, 7].includes(_model.status)} + + {/if} {#if _model} + {/if}
{#await model} diff --git a/webpage/src/routes/models/edit/RunModel.svelte b/webpage/src/routes/models/edit/RunModel.svelte index 0214028..4575759 100644 --- a/webpage/src/routes/models/edit/RunModel.svelte +++ b/webpage/src/routes/models/edit/RunModel.svelte @@ -3,11 +3,14 @@ import type { Model } from "./+page.svelte"; import FileUpload from "src/lib/FileUpload.svelte"; import MessageSimple from "src/lib/MessageSimple.svelte"; + import { createEventDispatcher } from "svelte"; let {model} = $props<{model: Model}>(); let file: File | undefined = $state(); + const dispatch = createEventDispatcher<{ upload: void }>(); + type Result = { class: string, confidence: number, @@ -15,23 +18,24 @@ let _result: Promise = $state(new Promise(() => {})); let run = $state(false); + let last_task: string | undefined = $state(); let messages: MessageSimple; async function submit() { - console.log("here", file); if (!file) return; messages.clear(); let form = new FormData(); - form.append("id", model.id); + form.append("json_data", JSON.stringify({id: model.id})); form.append("file", file, "file"); run = true; - + try { - _result = await postFormData('models/run', form); - console.log(await _result); + const r = await postFormData('tasks/start/image', form); + last_task = r.id + file = undefined; } catch (e) { if (e instanceof Response) { messages.display(await e.json()); @@ -39,7 +43,8 @@ messages.display("Could not run the model"); } } - + + dispatch('upload'); }
@@ -66,7 +71,11 @@ Run {#if run} - {#await _result then result} + {#await _result} +

+ Processing Image {last_task} +

+ {:then result} {#if !result}

diff --git a/webpage/src/routes/models/edit/TasksDataPage.svelte b/webpage/src/routes/models/edit/TasksDataPage.svelte new file mode 100644 index 0000000..e0c263b --- /dev/null +++ b/webpage/src/routes/models/edit/TasksDataPage.svelte @@ -0,0 +1,16 @@ + + + +
+ table.getList()} /> + +
diff --git a/webpage/src/routes/models/edit/TasksTable.svelte b/webpage/src/routes/models/edit/TasksTable.svelte new file mode 100644 index 0000000..ce007c5 --- /dev/null +++ b/webpage/src/routes/models/edit/TasksTable.svelte @@ -0,0 +1,185 @@ + + + + +
+

Tasks

+ + + + + + + + + + + + + + {#each task_list as task} + + + + + + + + + + {/each} + +
Task type + + User Confirmed Result Status Status Message Created
+ {#if task.type == 1} + Image Run + {:else} + {task.type} + {/if} + + {#if task.type == 1} + + {:else} + TODO Show more information {task.status} + {/if} + + {#if task.type == 1} + {#if task.status == 4} + {#if task.user_confirmed == 0} + User has not agreed to the result of this task + {:else if task.user_confirmed == -1} + User has disagred with the result of this task + {:else if task.user_confirmed == 1} + User has aggred with the result of this task + {:else} + TODO {task.user_confirmed} + {/if} + {:else} + - + {/if} + {:else} + TODO Handle {task.type} + {/if} + + {#if task.status == 4} + {#if task.type == 1} + {@const temp = JSON.parse(task.result)} + {temp.class}({temp.confidence * 100}%) + {:else} + {task.result} + {/if} + {/if} + + {task.status} + + {task.status_message} + + {(new Date(task.created)).toLocaleString()} +
+
+
+ {#if page > 0} + + {/if} +
+ +
+ {page} +
+ +
+ {#if showNext} + + {/if} +
+
+
+ +