diff --git a/logic/db_types/layer.go b/logic/db_types/layer.go index a82bd7a..6897770 100644 --- a/logic/db_types/layer.go +++ b/logic/db_types/layer.go @@ -18,10 +18,10 @@ const ( type Layer struct { Id string `db:"mdl.id" json:"id"` DefinitionId string `db:"mdl.def_id" json:"definition_id"` - LayerOrder string `db:"mdl.layer_order" json:"layer_order"` + LayerOrder int `db:"mdl.layer_order" json:"layer_order"` LayerType LayerType `db:"mdl.layer_type" json:"layer_type"` Shape string `db:"mdl.shape" json:"shape"` - ExpType string `db:"mdl.exp_type" json:"exp_type"` + ExpType int `db:"mdl.exp_type" json:"exp_type"` } func ShapeToString(args ...int) string { diff --git a/logic/models/train/remote_train.go b/logic/models/train/remote_train.go index 971e490..7cbf356 100644 --- a/logic/models/train/remote_train.go +++ b/logic/models/train/remote_train.go @@ -6,6 +6,7 @@ import ( . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + "github.com/charmbracelet/log" "github.com/goccy/go-json" ) @@ -39,7 +40,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) ( } if model.ModelType == 2 { - panic("TODO") full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) if full_error != nil { l.Error("Failed to generate defintions", "err", full_error) @@ -75,4 +75,44 @@ func CleanUpFailed(b BasePack, task *Task) { l.Error("Failed to get status", err) } } + // Set the class status to trained + err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) + if err != nil { + l.Error("Failed to set class status") + return + } +} + +func CleanUpFailedRetrain(b BasePack, task *Task) { + db := b.GetDb() + l := b.GetLogger() + + model, err := GetBaseModel(db, *task.ModelId) + if err != nil { + l.Error("Failed to get model", "err", err) + } else { + err = model.UpdateStatus(db, FAILED_TRAINING) + if err != nil { + l.Error("Failed to get status", err) + } + } + + ResetClasses(b, model) + ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED) + + var defData struct { + Id string `db:"md.id"` + TargetAcuuracy float64 `db:"md.target_accuracy"` + } + + err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId) + if err != nil { + log.Error("failed to get def data", err) + return + } + + _, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id) + if err_ != nil { + panic(err_) + } } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 0309cfe..e911d59 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -101,6 +101,10 @@ func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, arg return } +func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) { + return setModelClassStatus(c, status, filter, args...) +} + func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) { db := c.GetDb() @@ -1090,7 +1094,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) { return err } - if err = splitModel(c, model); err != nil { + if err = SplitModel(c, model); err != nil { err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) if err != nil { l.Error("Failed to split the model! And Failed to set class status") @@ -1123,7 +1127,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) { return } -func splitModel(c BasePack, model *BaseModel) (err error) { +func SplitModel(c BasePack, model *BaseModel) (err error) { db := c.GetDb() l := c.GetLogger() @@ -1260,7 +1264,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe } db := c.GetDb() - l := c.GetLogger() def, err := MakeDefenition(db, model.Id, target_accuracy) if err != nil { @@ -1279,67 +1282,35 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe } order++ - if complexity == 0 { - /* - _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "") - if err != nil { - failed() - return - } - order++ - */ - - _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "") + loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10))))) + for i := 0; i < loop; i++ { + _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "") + order++ if err != nil { failed() return } - order++ + } - loop := int(math.Log2(float64(number_of_classes))) - for i := 0; i < loop; i++ { - _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i))) - order++ - if err != nil { - ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) - return - } - } - } else if complexity == 1 || complexity == 2 { - loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10))))) - for i := 0; i < loop; i++ { - _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "") - order++ - if err != nil { - failed() - return - } - } - - _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "") - if err != nil { - failed() - return - } - order++ - - loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) - if loop == 0 { - loop = 1 - } - for i := 0; i < loop; i++ { - _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i))) - order++ - if err != nil { - failed() - return - } - } - } else { - l.Error("Unkown complexity", "complexity", complexity) + _, err = def.MakeLayer(db, order, LAYER_FLATTEN, "") + if err != nil { failed() return } + order++ + + loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) + if loop == 0 { + loop = 1 + } + for i := 0; i < loop; i++ { + _, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i))) + order++ + if err != nil { + failed() + return + } + } return def.UpdateStatus(db, DEFINITION_STATUS_INIT) } @@ -1410,19 +1381,12 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy order := 1 - width := model.Width - height := model.Height - - // Note the shape of the first layer defines the import size - if complexity == 2 { - // Note the shape for now is no used - width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0)))) - height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0)))) - l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height) + err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height), 1) + if err != nil { + failed() + return } - err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1) - order++ // handle the errors inside the pervious if block @@ -1460,7 +1424,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy order++ // Flatten the blocks into dense - err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1) + err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*2), 1) if err != nil { failed() return @@ -1474,7 +1438,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy loop = max(loop, 3) for i := 0; i < loop; i++ { - err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)*2), 2) + err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)*2), 2) order++ if err != nil { failed() @@ -1747,6 +1711,13 @@ func RunTaskRetrain(b BasePack, task Task) (err error) { return } + _, err = db.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id) + if err != nil { + l.Error("Error while updating the classes", "error", err) + failed() + return + } + task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining") return diff --git a/logic/tasks/runner.go b/logic/tasks/runner.go index f84d72d..8be7305 100644 --- a/logic/tasks/runner.go +++ b/logic/tasks/runner.go @@ -10,6 +10,7 @@ import ( . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + "github.com/charmbracelet/log" ) func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) { @@ -34,6 +35,12 @@ type VerifyTask struct { TaskId string `json:"taskId" validate:"required"` } +type RunnerTrainDef struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + DefId string `json:"defId" validate:"required"` +} + func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) { mutex := x.DataMap["runners_mutex"].(*sync.Mutex) mutex.Lock() @@ -53,6 +60,18 @@ func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Erro return runner_data["task"].(*Task), nil } +func clearRunnerTask(x *Handle, runner_id string) { + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + var runner_data map[string]interface{} = runners[runner_id].(map[string]interface{}) + runner_data["task"] = nil + runners[runner_id] = runner_data + x.DataMap["runners"] = runners +} + func handleRemoteRunner(x *Handle) { type RegisterRunner struct { @@ -202,6 +221,8 @@ func handleRemoteRunner(x *Handle) { switch task.TaskType { case int(TASK_TYPE_TRAINING): CleanUpFailed(c, task) + case int(TASK_TYPE_RETRAINING): + CleanUpFailedRetrain(c, task) case int(TASK_TYPE_CLASSIFICATION): // DO nothing default: @@ -237,6 +258,8 @@ func handleRemoteRunner(x *Handle) { switch task.TaskType { case int(TASK_TYPE_TRAINING): status = DEFINITION_STATUS_INIT + case int(TASK_TYPE_RETRAINING): + fallthrough case int(TASK_TYPE_CLASSIFICATION): status = DEFINITION_STATUS_READY default: @@ -268,27 +291,35 @@ func handleRemoteRunner(x *Handle) { return error } - switch task.TaskType { - case int(TASK_TYPE_TRAINING): - //DO NOTHING - case int(TASK_TYPE_CLASSIFICATION): - //DO NOTHING - default: - c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) - return c.JsonBadRequest("Task is not the right type go get the definitions") - } - model, err := GetBaseModel(c, *task.ModelId) if err != nil { return c.E500M("Failed to get model information", err) } - classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN) - if err != nil { - return c.E500M("Failed to get the model classes", err) + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + classes, err := model.GetClasses(c, "and status in ($2, $3) order by mc.class_order asc", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING) + if err != nil { + return c.E500M("Failed to get the model classes", err) + } + return c.SendJSON(classes) + case int(TASK_TYPE_RETRAINING): + classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINING) + if err != nil { + return c.E500M("Failed to get the model classes", err) + } + return c.SendJSON(classes) + case int(TASK_TYPE_CLASSIFICATION): + classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINED) + if err != nil { + return c.E500M("Failed to get the model classes", err) + } + return c.SendJSON(classes) + default: + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") } - return c.SendJSON(classes) }) type RunnerTrainDefStatus struct { @@ -326,12 +357,13 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON("Ok") }) - type RunnerTrainDefLayers struct { - Id string `json:"id" validate:"required"` - TaskId string `json:"taskId" validate:"required"` - DefId string `json:"defId" validate:"required"` + type RunnerTrainDefHeadStatus struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + DefId string `json:"defId" validate:"required"` + Status ModelHeadStatus `json:"status" validate:"required"` } - PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error { + PostAuthJson(x, "/tasks/runner/train/def/head/status", User_Normal, func(c *Context, dat *RunnerTrainDefHeadStatus) *Error { _, error := verifyRunner(c, &JustId{Id: dat.Id}) if error != nil { return error @@ -352,6 +384,69 @@ func handleRemoteRunner(x *Handle) { return c.E500M("Failed to get definition information", err) } + _, err = c.Exec("update exp_model_head set status=$1 where def_id=$2;", dat.Status, def.Id) + if err != nil { + log.Error("Failed to train definition!") + return c.E500M("Failed to train definition", err) + } + + return c.SendJSON("Ok") + }) + + type RunnerRetrainDefHeadStatus struct { + Id string `json:"id" validate:"required"` + TaskId string `json:"taskId" validate:"required"` + HeadId string `json:"defId" validate:"required"` + Status ModelHeadStatus `json:"status" validate:"required"` + } + PostAuthJson(x, "/tasks/runner/retrain/def/head/status", User_Normal, func(c *Context, dat *RunnerRetrainDefHeadStatus) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId}) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_RETRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + if err := UpdateStatus(c.GetDb(), "exp_model_head", dat.HeadId, MODEL_DEFINITION_STATUS_TRAINING); err != nil { + return c.E500M("Failed to update head status", err) + } + + return c.SendJSON("Ok") + }) + PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId}) + if error != nil { + return error + } + + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + // Do nothing + case int(TASK_TYPE_RETRAINING): + // Do nothing + default: + c.Logger.Error("Task not is not the right type to get the layers", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the layers") + } + + def, err := GetDefinition(c, dat.DefId) + if err != nil { + return c.E500M("Failed to get definition information", err) + } + layers, err := def.GetLayers(c, " order by layer_order asc") if err != nil { return c.E500M("Failed to get layers", err) @@ -371,7 +466,12 @@ func handleRemoteRunner(x *Handle) { return error } - if task.TaskType != int(TASK_TYPE_TRAINING) { + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + // DO nothing + case int(TASK_TYPE_RETRAINING): + // DO nothing + default: c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) return c.JsonBadRequest("Task is not the right type go get the definitions") } @@ -486,6 +586,8 @@ func handleRemoteRunner(x *Handle) { switch task.TaskType { case int(TASK_TYPE_TRAINING): //DO NOTHING + case int(TASK_TYPE_RETRAINING): + //DO NOTHING case int(TASK_TYPE_CLASSIFICATION): //DO NOTHING default: @@ -501,6 +603,74 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON(model) }) + PostAuthJson(x, "/tasks/runner/heads", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId}) + if error != nil { + return error + } + + type ExpHead struct { + Id string `json:"id"` + Start int `db:"range_start" json:"start"` + End int `db:"range_end" json:"end"` + } + + switch task.TaskType { + case int(TASK_TYPE_TRAINING): + fallthrough + case int(TASK_TYPE_RETRAINING): + // status = 2 (INIT) 3 (TRAINING) + heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and status in (2,3)", dat.DefId) + if err != nil { + return c.E500M("Failed getting active heads", err) + } + return c.SendJSON(heads) + case int(TASK_TYPE_CLASSIFICATION): + heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1", dat.DefId) + if err != nil { + return c.E500M("Failed getting active heads", err) + } + return c.SendJSON(heads) + default: + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + }) + + PostAuthJson(x, "/tasks/runner/train_exp/class/status/train", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + return c.E500M("Failed to get model", err) + } + + err = SetModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN) + if err != nil { + return c.E500M("Failed update status", err) + } + + return c.SendJSON("Ok") + }) + PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { _, error := verifyRunner(c, &JustId{Id: dat.Id}) if error != nil { @@ -568,6 +738,13 @@ func handleRemoteRunner(x *Handle) { return c.E500M("Failed to delete unsed definitions", err) } + // Set the class status to trained + err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1;", model.Id) + if err != nil { + c.Logger.Error("Failed to set class status") + return c.E500M("Failed to set class status", err) + } + if err = model.UpdateStatus(c, READY); err != nil { model.UpdateStatus(c, FAILED_TRAINING) task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") @@ -576,16 +753,7 @@ func handleRemoteRunner(x *Handle) { task.UpdateStatusLog(c, TASK_DONE, "Model finished training") - mutex := x.DataMap["runners_mutex"].(*sync.Mutex) - mutex.Lock() - defer mutex.Unlock() - - var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) - var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) - runner_data["task"] = nil - runners[dat.Id] = runner_data - x.DataMap["runners"] = runners - + clearRunnerTask(x, dat.Id) return c.SendJSON("Ok") }) @@ -635,4 +803,153 @@ func handleRemoteRunner(x *Handle) { return c.SendJSON("Ok") }) + + PostAuthJson(x, "/tasks/runner/train_exp/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_TRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + c.Logger.Error("Failed to get model", "err", err) + return c.E500M("Failed to get mode", err) + } + + // TODO add check the to the model + + var def Definition + err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED) + if err == NotFoundError { + // TODO Make the Model status have a message + c.Logger.Error("All definitions failed to train!") + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!") + clearRunnerTask(x, dat.Id) + return c.SendJSON("Ok") + } else if err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition") + return c.E500M("Failed to get model definition", err) + } + + if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition") + return c.E500M("Failed to update model definition", err) + } + + to_delete, err := GetDbMultitple[JustId](c, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id) + if err != nil { + c.GetLogger().Error("Failed to select model_definition to delete") + return c.E500M("Failed to select model definition to delete", err) + } + + for _, d := range to_delete { + os.RemoveAll(path.Join("savedData", model.Id, "defs", d.Id)) + } + + // TODO Check if returning also works here + if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") + return c.E500M("Failed to delete unsed definitions", err) + } + + if err = SplitModel(c, model); err != nil { + err = SetModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) + if err != nil { + c.Logger.Error("Failed to split the model! And Failed to set class status") + return c.E500M("Failed to split the model", err) + } + + c.Logger.Error("Failed to split the model") + return c.E500M("Failed to split the model", err) + } + + // Set the class status to trained + err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) + if err != nil { + c.Logger.Error("Failed to set class status") + return c.E500M("Failed to set class status", err) + } + + c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id) + os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model")) + os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras")) + + if err = model.UpdateStatus(c, READY); err != nil { + model.UpdateStatus(c, FAILED_TRAINING) + task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") + return c.E500M("Failed to update status of model", err) + } + + task.UpdateStatusLog(c, TASK_DONE, "Model finished training") + + mutex := x.DataMap["runners_mutex"].(*sync.Mutex) + mutex.Lock() + defer mutex.Unlock() + + var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{}) + var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{}) + runner_data["task"] = nil + runners[dat.Id] = runner_data + x.DataMap["runners"] = runners + + return c.SendJSON("Ok") + }) + + PostAuthJson(x, "/tasks/runner/retrain/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { + _, error := verifyRunner(c, &JustId{Id: dat.Id}) + if error != nil { + return error + } + + task, error := verifyTask(x, c, dat) + if error != nil { + return error + } + + if task.TaskType != int(TASK_TYPE_RETRAINING) { + c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) + return c.JsonBadRequest("Task is not the right type go get the definitions") + } + + model, err := GetBaseModel(c, *task.ModelId) + if err != nil { + c.Logger.Error("Failed to get model", "err", err) + return c.E500M("Failed to get mode", err) + } + + err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) + if err != nil { + return c.E500M("Failed to set class status", err) + } + + _, err = c.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id) + if err != nil { + return c.E500M("Failed to set head status", err) + } + + err = model.UpdateStatus(c, READY) + if err != nil { + return c.E500M("Failed to set class status", err) + } + + task.UpdateStatusLog(c, TASK_DONE, "Model finished training") + clearRunnerTask(x, dat.Id) + + return c.SendJSON("Ok") + }) + } diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go index 01bf4ec..38c3cdc 100644 --- a/logic/tasks/runner/runner.go +++ b/logic/tasks/runner/runner.go @@ -120,6 +120,12 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas defer mutex.Unlock() switch task.TaskType { + case int(TASK_TYPE_RETRAINING): + runners := handler.DataMap["runners"].(map[string]interface{}) + runner := runners[runner_id].(map[string]interface{}) + runner["task"] = &task + runners[runner_id] = runner + handler.DataMap["runners"] = runners case int(TASK_TYPE_TRAINING): if err := PrepareTraining(handler, base, task, runner_id); err != nil { logger.Error("Failed to prepare for training", "err", err)