runner-go #102

Merged
andr3 merged 9 commits from runner-go into main 2024-05-10 02:13:02 +01:00
5 changed files with 436 additions and 102 deletions
Showing only changes of commit a0aed71b3c - Show all commits

View File

@ -18,10 +18,10 @@ const (
type Layer struct { type Layer struct {
Id string `db:"mdl.id" json:"id"` Id string `db:"mdl.id" json:"id"`
DefinitionId string `db:"mdl.def_id" json:"definition_id"` DefinitionId string `db:"mdl.def_id" json:"definition_id"`
LayerOrder string `db:"mdl.layer_order" json:"layer_order"` LayerOrder int `db:"mdl.layer_order" json:"layer_order"`
LayerType LayerType `db:"mdl.layer_type" json:"layer_type"` LayerType LayerType `db:"mdl.layer_type" json:"layer_type"`
Shape string `db:"mdl.shape" json:"shape"` Shape string `db:"mdl.shape" json:"shape"`
ExpType string `db:"mdl.exp_type" json:"exp_type"` ExpType int `db:"mdl.exp_type" json:"exp_type"`
} }
func ShapeToString(args ...int) string { func ShapeToString(args ...int) string {

View File

@ -6,6 +6,7 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/charmbracelet/log"
"github.com/goccy/go-json" "github.com/goccy/go-json"
) )
@ -39,7 +40,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
} }
if model.ModelType == 2 { if model.ModelType == 2 {
panic("TODO")
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
if full_error != nil { if full_error != nil {
l.Error("Failed to generate defintions", "err", full_error) l.Error("Failed to generate defintions", "err", full_error)
@ -75,4 +75,44 @@ func CleanUpFailed(b BasePack, task *Task) {
l.Error("Failed to get status", err) l.Error("Failed to get status", err)
} }
} }
// Set the class status to trained
err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
l.Error("Failed to set class status")
return
}
}
func CleanUpFailedRetrain(b BasePack, task *Task) {
db := b.GetDb()
l := b.GetLogger()
model, err := GetBaseModel(db, *task.ModelId)
if err != nil {
l.Error("Failed to get model", "err", err)
} else {
err = model.UpdateStatus(db, FAILED_TRAINING)
if err != nil {
l.Error("Failed to get status", err)
}
}
ResetClasses(b, model)
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
var defData struct {
Id string `db:"md.id"`
TargetAcuuracy float64 `db:"md.target_accuracy"`
}
err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
if err != nil {
log.Error("failed to get def data", err)
return
}
_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
if err_ != nil {
panic(err_)
}
} }

View File

@ -101,6 +101,10 @@ func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, arg
return return
} }
func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
return setModelClassStatus(c, status, filter, args...)
}
func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) { func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) {
db := c.GetDb() db := c.GetDb()
@ -1090,7 +1094,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
return err return err
} }
if err = splitModel(c, model); err != nil { if err = SplitModel(c, model); err != nil {
err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
l.Error("Failed to split the model! And Failed to set class status") l.Error("Failed to split the model! And Failed to set class status")
@ -1123,7 +1127,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
return return
} }
func splitModel(c BasePack, model *BaseModel) (err error) { func SplitModel(c BasePack, model *BaseModel) (err error) {
db := c.GetDb() db := c.GetDb()
l := c.GetLogger() l := c.GetLogger()
@ -1260,7 +1264,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
} }
db := c.GetDb() db := c.GetDb()
l := c.GetLogger()
def, err := MakeDefenition(db, model.Id, target_accuracy) def, err := MakeDefenition(db, model.Id, target_accuracy)
if err != nil { if err != nil {
@ -1279,33 +1282,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
} }
order++ order++
if complexity == 0 {
/*
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
if err != nil {
failed()
return
}
order++
*/
_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
if err != nil {
failed()
return
}
order++
loop := int(math.Log2(float64(number_of_classes)))
for i := 0; i < loop; i++ {
_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
order++
if err != nil {
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
return
}
}
} else if complexity == 1 || complexity == 2 {
loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10))))) loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "") _, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
@ -1335,11 +1311,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
return return
} }
} }
} else {
l.Error("Unkown complexity", "complexity", complexity)
failed()
return
}
return def.UpdateStatus(db, DEFINITION_STATUS_INIT) return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
} }
@ -1410,19 +1381,12 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
order := 1 order := 1
width := model.Width err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height), 1)
height := model.Height if err != nil {
failed()
// Note the shape of the first layer defines the import size return
if complexity == 2 {
// Note the shape for now is no used
width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
} }
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
order++ order++
// handle the errors inside the pervious if block // handle the errors inside the pervious if block
@ -1460,7 +1424,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
order++ order++
// Flatten the blocks into dense // Flatten the blocks into dense
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1) err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*2), 1)
if err != nil { if err != nil {
failed() failed()
return return
@ -1474,7 +1438,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
loop = max(loop, 3) loop = max(loop, 3)
for i := 0; i < loop; i++ { for i := 0; i < loop; i++ {
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)*2), 2) err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)*2), 2)
order++ order++
if err != nil { if err != nil {
failed() failed()
@ -1747,6 +1711,13 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
return return
} }
_, err = db.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id)
if err != nil {
l.Error("Error while updating the classes", "error", err)
failed()
return
}
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining") task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
return return

View File

@ -10,6 +10,7 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/charmbracelet/log"
) )
func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) { func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) {
@ -34,6 +35,12 @@ type VerifyTask struct {
TaskId string `json:"taskId" validate:"required"` TaskId string `json:"taskId" validate:"required"`
} }
type RunnerTrainDef struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"`
}
func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) { func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) {
mutex := x.DataMap["runners_mutex"].(*sync.Mutex) mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock() mutex.Lock()
@ -53,6 +60,18 @@ func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Erro
return runner_data["task"].(*Task), nil return runner_data["task"].(*Task), nil
} }
func clearRunnerTask(x *Handle, runner_id string) {
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[runner_id].(map[string]interface{})
runner_data["task"] = nil
runners[runner_id] = runner_data
x.DataMap["runners"] = runners
}
func handleRemoteRunner(x *Handle) { func handleRemoteRunner(x *Handle) {
type RegisterRunner struct { type RegisterRunner struct {
@ -202,6 +221,8 @@ func handleRemoteRunner(x *Handle) {
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
CleanUpFailed(c, task) CleanUpFailed(c, task)
case int(TASK_TYPE_RETRAINING):
CleanUpFailedRetrain(c, task)
case int(TASK_TYPE_CLASSIFICATION): case int(TASK_TYPE_CLASSIFICATION):
// DO nothing // DO nothing
default: default:
@ -237,6 +258,8 @@ func handleRemoteRunner(x *Handle) {
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
status = DEFINITION_STATUS_INIT status = DEFINITION_STATUS_INIT
case int(TASK_TYPE_RETRAINING):
fallthrough
case int(TASK_TYPE_CLASSIFICATION): case int(TASK_TYPE_CLASSIFICATION):
status = DEFINITION_STATUS_READY status = DEFINITION_STATUS_READY
default: default:
@ -268,27 +291,35 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
//DO NOTHING
case int(TASK_TYPE_CLASSIFICATION):
//DO NOTHING
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId) model, err := GetBaseModel(c, *task.ModelId)
if err != nil { if err != nil {
return c.E500M("Failed to get model information", err) return c.E500M("Failed to get model information", err)
} }
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN) switch task.TaskType {
case int(TASK_TYPE_TRAINING):
classes, err := model.GetClasses(c, "and status in ($2, $3) order by mc.class_order asc", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
return c.E500M("Failed to get the model classes", err) return c.E500M("Failed to get the model classes", err)
} }
return c.SendJSON(classes) return c.SendJSON(classes)
case int(TASK_TYPE_RETRAINING):
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
case int(TASK_TYPE_CLASSIFICATION):
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINED)
if err != nil {
return c.E500M("Failed to get the model classes", err)
}
return c.SendJSON(classes)
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
}) })
type RunnerTrainDefStatus struct { type RunnerTrainDefStatus struct {
@ -326,12 +357,13 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
type RunnerTrainDefLayers struct { type RunnerTrainDefHeadStatus struct {
Id string `json:"id" validate:"required"` Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"` TaskId string `json:"taskId" validate:"required"`
DefId string `json:"defId" validate:"required"` DefId string `json:"defId" validate:"required"`
Status ModelHeadStatus `json:"status" validate:"required"`
} }
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error { PostAuthJson(x, "/tasks/runner/train/def/head/status", User_Normal, func(c *Context, dat *RunnerTrainDefHeadStatus) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
return error return error
@ -352,6 +384,69 @@ func handleRemoteRunner(x *Handle) {
return c.E500M("Failed to get definition information", err) return c.E500M("Failed to get definition information", err)
} }
_, err = c.Exec("update exp_model_head set status=$1 where def_id=$2;", dat.Status, def.Id)
if err != nil {
log.Error("Failed to train definition!")
return c.E500M("Failed to train definition", err)
}
return c.SendJSON("Ok")
})
type RunnerRetrainDefHeadStatus struct {
Id string `json:"id" validate:"required"`
TaskId string `json:"taskId" validate:"required"`
HeadId string `json:"defId" validate:"required"`
Status ModelHeadStatus `json:"status" validate:"required"`
}
PostAuthJson(x, "/tasks/runner/retrain/def/head/status", User_Normal, func(c *Context, dat *RunnerRetrainDefHeadStatus) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_RETRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
if err := UpdateStatus(c.GetDb(), "exp_model_head", dat.HeadId, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return c.E500M("Failed to update head status", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
// Do nothing
case int(TASK_TYPE_RETRAINING):
// Do nothing
default:
c.Logger.Error("Task not is not the right type to get the layers", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the layers")
}
def, err := GetDefinition(c, dat.DefId)
if err != nil {
return c.E500M("Failed to get definition information", err)
}
layers, err := def.GetLayers(c, " order by layer_order asc") layers, err := def.GetLayers(c, " order by layer_order asc")
if err != nil { if err != nil {
return c.E500M("Failed to get layers", err) return c.E500M("Failed to get layers", err)
@ -371,7 +466,12 @@ func handleRemoteRunner(x *Handle) {
return error return error
} }
if task.TaskType != int(TASK_TYPE_TRAINING) { switch task.TaskType {
case int(TASK_TYPE_TRAINING):
// DO nothing
case int(TASK_TYPE_RETRAINING):
// DO nothing
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType) c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions") return c.JsonBadRequest("Task is not the right type go get the definitions")
} }
@ -486,6 +586,8 @@ func handleRemoteRunner(x *Handle) {
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
//DO NOTHING //DO NOTHING
case int(TASK_TYPE_RETRAINING):
//DO NOTHING
case int(TASK_TYPE_CLASSIFICATION): case int(TASK_TYPE_CLASSIFICATION):
//DO NOTHING //DO NOTHING
default: default:
@ -501,6 +603,74 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON(model) return c.SendJSON(model)
}) })
PostAuthJson(x, "/tasks/runner/heads", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
if error != nil {
return error
}
type ExpHead struct {
Id string `json:"id"`
Start int `db:"range_start" json:"start"`
End int `db:"range_end" json:"end"`
}
switch task.TaskType {
case int(TASK_TYPE_TRAINING):
fallthrough
case int(TASK_TYPE_RETRAINING):
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and status in (2,3)", dat.DefId)
if err != nil {
return c.E500M("Failed getting active heads", err)
}
return c.SendJSON(heads)
case int(TASK_TYPE_CLASSIFICATION):
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1", dat.DefId)
if err != nil {
return c.E500M("Failed getting active heads", err)
}
return c.SendJSON(heads)
default:
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
})
PostAuthJson(x, "/tasks/runner/train_exp/class/status/train", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
return c.E500M("Failed to get model", err)
}
err = SetModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN)
if err != nil {
return c.E500M("Failed update status", err)
}
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error { PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id}) _, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil { if error != nil {
@ -568,6 +738,13 @@ func handleRemoteRunner(x *Handle) {
return c.E500M("Failed to delete unsed definitions", err) return c.E500M("Failed to delete unsed definitions", err)
} }
// Set the class status to trained
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1;", model.Id)
if err != nil {
c.Logger.Error("Failed to set class status")
return c.E500M("Failed to set class status", err)
}
if err = model.UpdateStatus(c, READY); err != nil { if err = model.UpdateStatus(c, READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING) model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions") task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
@ -576,16 +753,7 @@ func handleRemoteRunner(x *Handle) {
task.UpdateStatusLog(c, TASK_DONE, "Model finished training") task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
mutex := x.DataMap["runners_mutex"].(*sync.Mutex) clearRunnerTask(x, dat.Id)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
runner_data["task"] = nil
runners[dat.Id] = runner_data
x.DataMap["runners"] = runners
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
@ -635,4 +803,153 @@ func handleRemoteRunner(x *Handle) {
return c.SendJSON("Ok") return c.SendJSON("Ok")
}) })
PostAuthJson(x, "/tasks/runner/train_exp/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_TRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
// TODO add check the to the model
var def Definition
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// TODO Make the Model status have a message
c.Logger.Error("All definitions failed to train!")
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
} else if err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
return c.E500M("Failed to get model definition", err)
}
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
return c.E500M("Failed to update model definition", err)
}
to_delete, err := GetDbMultitple[JustId](c, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
c.GetLogger().Error("Failed to select model_definition to delete")
return c.E500M("Failed to select model definition to delete", err)
}
for _, d := range to_delete {
os.RemoveAll(path.Join("savedData", model.Id, "defs", d.Id))
}
// TODO Check if returning also works here
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to delete unsed definitions", err)
}
if err = SplitModel(c, model); err != nil {
err = SetModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
c.Logger.Error("Failed to split the model! And Failed to set class status")
return c.E500M("Failed to split the model", err)
}
c.Logger.Error("Failed to split the model")
return c.E500M("Failed to split the model", err)
}
// Set the class status to trained
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
c.Logger.Error("Failed to set class status")
return c.E500M("Failed to set class status", err)
}
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
if err = model.UpdateStatus(c, READY); err != nil {
model.UpdateStatus(c, FAILED_TRAINING)
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
return c.E500M("Failed to update status of model", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
mutex.Lock()
defer mutex.Unlock()
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
runner_data["task"] = nil
runners[dat.Id] = runner_data
x.DataMap["runners"] = runners
return c.SendJSON("Ok")
})
PostAuthJson(x, "/tasks/runner/retrain/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
_, error := verifyRunner(c, &JustId{Id: dat.Id})
if error != nil {
return error
}
task, error := verifyTask(x, c, dat)
if error != nil {
return error
}
if task.TaskType != int(TASK_TYPE_RETRAINING) {
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
return c.JsonBadRequest("Task is not the right type go get the definitions")
}
model, err := GetBaseModel(c, *task.ModelId)
if err != nil {
c.Logger.Error("Failed to get model", "err", err)
return c.E500M("Failed to get mode", err)
}
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
if err != nil {
return c.E500M("Failed to set class status", err)
}
_, err = c.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id)
if err != nil {
return c.E500M("Failed to set head status", err)
}
err = model.UpdateStatus(c, READY)
if err != nil {
return c.E500M("Failed to set class status", err)
}
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
clearRunnerTask(x, dat.Id)
return c.SendJSON("Ok")
})
} }

View File

@ -120,6 +120,12 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
defer mutex.Unlock() defer mutex.Unlock()
switch task.TaskType { switch task.TaskType {
case int(TASK_TYPE_RETRAINING):
runners := handler.DataMap["runners"].(map[string]interface{})
runner := runners[runner_id].(map[string]interface{})
runner["task"] = &task
runners[runner_id] = runner
handler.DataMap["runners"] = runners
case int(TASK_TYPE_TRAINING): case int(TASK_TYPE_TRAINING):
if err := PrepareTraining(handler, base, task, runner_id); err != nil { if err := PrepareTraining(handler, base, task, runner_id); err != nil {
logger.Error("Failed to prepare for training", "err", err) logger.Error("Failed to prepare for training", "err", err)