From 9eba93cd3c37bb4e12c524513d8f1b76faf600b4 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Tue, 7 May 2024 18:18:48 +0100 Subject: [PATCH] more work on runner --- logic/db_types/types.go | 36 ++++----- logic/models/train/remote_train.go | 1 - logic/tasks/runner.go | 123 ++++++++++++++++++++++++++--- logic/tasks/runner/runner.go | 6 ++ logic/tasks/utils/utils.go | 6 +- logic/utils/handler.go | 2 +- 6 files changed, 142 insertions(+), 32 deletions(-) diff --git a/logic/db_types/types.go b/logic/db_types/types.go index 0bd3537..bad0035 100644 --- a/logic/db_types/types.go +++ b/logic/db_types/types.go @@ -3,7 +3,6 @@ package dbtypes import ( "errors" "fmt" - "path" "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" ) @@ -51,17 +50,16 @@ const ( ) type BaseModel struct { - Name string - Status int - Id string - - ModelType int `db:"model_type"` - ImageModeRaw string `db:"color_mode"` - ImageMode int `db:"0"` - Width int - Height int - Format string - CanTrain int `db:"can_train"` + Name string `json:"name"` + Status int `json:"status"` + Id string `json:"id"` + ModelType int `db:"model_type" json:"model_type"` + ImageModeRaw string `db:"color_mode" json:"image_more_raw"` + ImageMode int `db:"0" json:"image_mode"` + Width int `json:"width"` + Height int `json:"height"` + Format string `json:"format"` + CanTrain int `db:"can_train" json:"can_train"` } var ModelNotFoundError = errors.New("Model not found error") @@ -102,6 +100,7 @@ func (m *BaseModel) UpdateStatus(db db.Db, status ModelStatus) (err error) { } type DataPoint struct { + Id string `json:"id"` Class int `json:"class"` Path string `json:"path"` } @@ -126,14 +125,11 @@ func (m BaseModel) DataPoints(db db.Db, mode DATA_POINT_MODE) (data []DataPoint, if err = rows.Scan(&id, &class_order, &file_path); err != nil { return } - if file_path == "id://" { - data = append(data, DataPoint{ - Path: path.Join("./savedData", m.Id, "data", id+"."+m.Format), - Class: class_order, - }) - } else { - panic("TODO remote file path") - } + data = append(data, DataPoint{ + Id: id, + Path: file_path, + Class: class_order, + }) } return } diff --git a/logic/models/train/remote_train.go b/logic/models/train/remote_train.go index 501a707..971e490 100644 --- a/logic/models/train/remote_train.go +++ b/logic/models/train/remote_train.go @@ -57,7 +57,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) ( runners := handler.DataMap["runners"].(map[string]interface{}) runner := runners[runner_id].(map[string]interface{}) runner["task"] = &task - runners[runner_id] = runner handler.DataMap["runners"] = runners diff --git a/logic/tasks/runner.go b/logic/tasks/runner.go index bfce53f..f84d72d 100644 --- a/logic/tasks/runner.go +++ b/logic/tasks/runner.go @@ -202,6 +202,8 @@ func handleRemoteRunner(x *Handle) { switch task.TaskType { case int(TASK_TYPE_TRAINING): CleanUpFailed(c, task) + case int(TASK_TYPE_CLASSIFICATION): + // DO nothing default: panic("Do not know how to handle this") } @@ -220,7 +222,7 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON("Ok") }) - PostAuthJson(x, "/tasks/runner/train/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error { + PostAuthJson(x, "/tasks/runner/defs", User_Normal, func(c *Context, dat *VerifyTask) *Error { _, error := verifyRunner(c, &JustId{Id: dat.Id}) if error != nil { return error @@ -231,7 +233,13 @@ func handleRemoteRunner(x *Handle) { return error } - if task.TaskType != int(TASK_TYPE_TRAINING) { + var status DefinitionStatus + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + status = DEFINITION_STATUS_INIT + case int(TASK_TYPE_CLASSIFICATION): + status = DEFINITION_STATUS_READY + default: c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) return c.JsonBadRequest("Task is not the right type go get the definitions") } @@ -241,7 +249,7 @@ func handleRemoteRunner(x *Handle) { return c.E500M("Failed to get model information", err) } - defs, err := model.GetDefinitions(c, "and md.status=$2", DEFINITION_STATUS_INIT) + defs, err := model.GetDefinitions(c, "and md.status=$2", status) if err != nil { return c.E500M("Failed to get the model definitions", err) } @@ -249,7 +257,7 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON(defs) }) - PostAuthJson(x, "/tasks/runner/train/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error { + PostAuthJson(x, "/tasks/runner/classes", User_Normal, func(c *Context, dat *VerifyTask) *Error { _, error := verifyRunner(c, &JustId{Id: dat.Id}) if error != nil { return error @@ -260,7 +268,12 @@ func handleRemoteRunner(x *Handle) { return error } - if task.TaskType != int(TASK_TYPE_TRAINING) { + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + //DO NOTHING + case int(TASK_TYPE_CLASSIFICATION): + //DO NOTHING + default: c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) return c.JsonBadRequest("Task is not the right type go get the definitions") } @@ -425,7 +438,7 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON("Ok") }) - PostAuthJson(x, "/task/runner/train/mark-failed", User_Normal, func(c *Context, dat *VerifyTask) *Error { + PostAuthJson(x, "/tasks/runner/train/mark-failed", User_Normal, func(c *Context, dat *VerifyTask) *Error { _, error := verifyRunner(c, &JustId{Id: dat.Id}) if error != nil { return error @@ -459,7 +472,36 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON("Ok") }) - PostAuthJson(x, "/task/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { + PostAuthJson(x, "/tasks/runner/model", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + //DO NOTHING + case int(TASK_TYPE_CLASSIFICATION): + //DO NOTHING + default: + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + return c.E500M("Failed to get model information", err) + } + + return c.SendJSON(model) + }) + + PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { _, error := verifyRunner(c, &JustId{Id: dat.Id}) if error != nil { return error @@ -482,7 +524,7 @@ func handleRemoteRunner(x *Handle) { } var def Definition - err = GetDBOnce(c, &def, "from model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED) + err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED) if err == NotFoundError { // TODO Make the Model status have a message c.Logger.Error("All definitions failed to train!") @@ -526,7 +568,70 @@ func handleRemoteRunner(x *Handle) { return c.E500M("Failed to delete unsed definitions", err) } - model.UpdateStatus(c, READY) + if err = model.UpdateStatus(c, READY); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") + return c.E500M("Failed to update status of model", err) + } + + task.UpdateStatusLog(c, TASK_DONE, "Model finished training") + + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) + runner_data["task"] = nil + runners[dat.Id] = runner_data + x.DataMap["runners"] = runners + + return c.SendJSON("Ok") + }) + + type RunnerClassDone struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + Result string `json:"result" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/class/done", User_Normal, func(c *Context, dat *RunnerClassDone) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{ + Id: dat.Id, + TaskId: dat.TaskId, + }) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_CLASSIFICATION) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + err := task.SetResultText(c, dat.Result) + if err != nil { + return c.E500M("Failed to update the task", err) + } + + err = task.UpdateStatus(c, TASK_DONE, "Task completed") + if err != nil { + return c.E500M("Failed to update task", err) + } + + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) + runner_data["task"] = nil + runners[dat.Id] = runner_data + x.DataMap["runners"] = runners return c.SendJSON("Ok") }) diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go index 7f53e48..01bf4ec 100644 --- a/logic/tasks/runner/runner.go +++ b/logic/tasks/runner/runner.go @@ -124,6 +124,12 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas if err := PrepareTraining(handler, base, task, runner_id); err != nil { logger.Error("Failed to prepare for training", "err", err) } + case int(TASK_TYPE_CLASSIFICATION): + runners := handler.DataMap["runners"].(map[string]interface{}) + runner := runners[runner_id].(map[string]interface{}) + runner["task"] = &task + runners[runner_id] = runner + handler.DataMap["runners"] = runners default: logger.Error("Not sure what to do panicing", "taskType", task.TaskType) panic("not sure what to do") diff --git a/logic/tasks/utils/utils.go b/logic/tasks/utils/utils.go index 8a7f105..e202717 100644 --- a/logic/tasks/utils/utils.go +++ b/logic/tasks/utils/utils.go @@ -101,7 +101,11 @@ func (t Task) SetResult(base BasePack, result any) (err error) { if err != nil { return } - _, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", text, t.Id) + return t.SetResultText(base, string(text)) +} + +func (t Task) SetResultText(base BasePack, text string) (err error) { + _, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", []byte(text), t.Id) return } diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 7ec981b..451bbd6 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -449,7 +449,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW logger := log.NewWithOptions(os.Stdout, log.Options{ ReportCaller: true, ReportTimestamp: true, - TimeFormat: time.Kitchen, + TimeFormat: time.DateTime, Prefix: r.URL.Path, })