runner-go #102
@ -18,10 +18,10 @@ const (
|
||||
type Layer struct {
|
||||
Id string `db:"mdl.id" json:"id"`
|
||||
DefinitionId string `db:"mdl.def_id" json:"definition_id"`
|
||||
LayerOrder string `db:"mdl.layer_order" json:"layer_order"`
|
||||
LayerOrder int `db:"mdl.layer_order" json:"layer_order"`
|
||||
LayerType LayerType `db:"mdl.layer_type" json:"layer_type"`
|
||||
Shape string `db:"mdl.shape" json:"shape"`
|
||||
ExpType string `db:"mdl.exp_type" json:"exp_type"`
|
||||
ExpType int `db:"mdl.exp_type" json:"exp_type"`
|
||||
}
|
||||
|
||||
func ShapeToString(args ...int) string {
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
@ -39,7 +40,6 @@ func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (
|
||||
}
|
||||
|
||||
if model.ModelType == 2 {
|
||||
panic("TODO")
|
||||
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
|
||||
if full_error != nil {
|
||||
l.Error("Failed to generate defintions", "err", full_error)
|
||||
@ -75,4 +75,44 @@ func CleanUpFailed(b BasePack, task *Task) {
|
||||
l.Error("Failed to get status", err)
|
||||
}
|
||||
}
|
||||
// Set the class status to trained
|
||||
err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
l.Error("Failed to set class status")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func CleanUpFailedRetrain(b BasePack, task *Task) {
|
||||
db := b.GetDb()
|
||||
l := b.GetLogger()
|
||||
|
||||
model, err := GetBaseModel(db, *task.ModelId)
|
||||
if err != nil {
|
||||
l.Error("Failed to get model", "err", err)
|
||||
} else {
|
||||
err = model.UpdateStatus(db, FAILED_TRAINING)
|
||||
if err != nil {
|
||||
l.Error("Failed to get status", err)
|
||||
}
|
||||
}
|
||||
|
||||
ResetClasses(b, model)
|
||||
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
|
||||
|
||||
var defData struct {
|
||||
Id string `db:"md.id"`
|
||||
TargetAcuuracy float64 `db:"md.target_accuracy"`
|
||||
}
|
||||
|
||||
err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
|
||||
if err != nil {
|
||||
log.Error("failed to get def data", err)
|
||||
return
|
||||
}
|
||||
|
||||
_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
|
||||
if err_ != nil {
|
||||
panic(err_)
|
||||
}
|
||||
}
|
||||
|
@ -101,6 +101,10 @@ func setModelClassStatus(c BasePack, status ModelClassStatus, filter string, arg
|
||||
return
|
||||
}
|
||||
|
||||
func SetModelClassStatus(c BasePack, status ModelClassStatus, filter string, args ...any) (err error) {
|
||||
return setModelClassStatus(c, status, filter, args...)
|
||||
}
|
||||
|
||||
func generateCvsExp(c BasePack, run_path string, model_id string, doPanic bool) (count int, err error) {
|
||||
|
||||
db := c.GetDb()
|
||||
@ -1090,7 +1094,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = splitModel(c, model); err != nil {
|
||||
if err = SplitModel(c, model); err != nil {
|
||||
err = setModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
l.Error("Failed to split the model! And Failed to set class status")
|
||||
@ -1123,7 +1127,7 @@ func trainModelExp(c BasePack, model *BaseModel) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func splitModel(c BasePack, model *BaseModel) (err error) {
|
||||
func SplitModel(c BasePack, model *BaseModel) (err error) {
|
||||
db := c.GetDb()
|
||||
l := c.GetLogger()
|
||||
|
||||
@ -1260,7 +1264,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
|
||||
}
|
||||
|
||||
db := c.GetDb()
|
||||
l := c.GetLogger()
|
||||
|
||||
def, err := MakeDefenition(db, model.Id, target_accuracy)
|
||||
if err != nil {
|
||||
@ -1279,33 +1282,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
|
||||
}
|
||||
order++
|
||||
|
||||
if complexity == 0 {
|
||||
/*
|
||||
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
order++
|
||||
*/
|
||||
|
||||
_, err = def.MakeLayer(db, order, LAYER_FLATTEN, "")
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
order++
|
||||
|
||||
loop := int(math.Log2(float64(number_of_classes)))
|
||||
for i := 0; i < loop; i++ {
|
||||
_, err = def.MakeLayer(db, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)))
|
||||
order++
|
||||
if err != nil {
|
||||
ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else if complexity == 1 || complexity == 2 {
|
||||
loop := max(1, int((math.Log(float64(model.Width)) / math.Log(float64(10)))))
|
||||
for i := 0; i < loop; i++ {
|
||||
_, err = def.MakeLayer(db, order, LAYER_SIMPLE_BLOCK, "")
|
||||
@ -1335,11 +1311,6 @@ func generateDefinition(c BasePack, model *BaseModel, target_accuracy int, numbe
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
l.Error("Unkown complexity", "complexity", complexity)
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
return def.UpdateStatus(db, DEFINITION_STATUS_INIT)
|
||||
}
|
||||
@ -1410,19 +1381,12 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
|
||||
|
||||
order := 1
|
||||
|
||||
width := model.Width
|
||||
height := model.Height
|
||||
|
||||
// Note the shape of the first layer defines the import size
|
||||
if complexity == 2 {
|
||||
// Note the shape for now is no used
|
||||
width := int(math.Pow(2, math.Floor(math.Log(float64(model.Width))/math.Log(2.0))))
|
||||
height := int(math.Pow(2, math.Floor(math.Log(float64(model.Height))/math.Log(2.0))))
|
||||
l.Warn("Complexity 2 creating model with smaller size", "width", width, "height", height)
|
||||
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, ShapeToString(3, model.Width, model.Height), 1)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
err = MakeLayerExpandable(c.GetDb(), def_id, order, LAYER_INPUT, fmt.Sprintf("%d,%d,1", width, height), 1)
|
||||
|
||||
order++
|
||||
|
||||
// handle the errors inside the pervious if block
|
||||
@ -1460,7 +1424,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
|
||||
order++
|
||||
|
||||
// Flatten the blocks into dense
|
||||
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*2), 1)
|
||||
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*2), 1)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
@ -1474,7 +1438,7 @@ func generateExpandableDefinition(c BasePack, model *BaseModel, target_accuracy
|
||||
loop = max(loop, 3)
|
||||
|
||||
for i := 0; i < loop; i++ {
|
||||
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)*2), 2)
|
||||
err = MakeLayerExpandable(db, def_id, order, LAYER_DENSE, ShapeToString(number_of_classes*(loop-i)*2), 2)
|
||||
order++
|
||||
if err != nil {
|
||||
failed()
|
||||
@ -1747,6 +1711,13 @@ func RunTaskRetrain(b BasePack, task Task) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = db.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
l.Error("Error while updating the classes", "error", err)
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
|
||||
|
||||
return
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
"github.com/charmbracelet/log"
|
||||
)
|
||||
|
||||
func verifyRunner(c *Context, dat *JustId) (runner *Runner, e *Error) {
|
||||
@ -34,6 +35,12 @@ type VerifyTask struct {
|
||||
TaskId string `json:"taskId" validate:"required"`
|
||||
}
|
||||
|
||||
type RunnerTrainDef struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
TaskId string `json:"taskId" validate:"required"`
|
||||
DefId string `json:"defId" validate:"required"`
|
||||
}
|
||||
|
||||
func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Error) {
|
||||
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||
mutex.Lock()
|
||||
@ -53,6 +60,18 @@ func verifyTask(x *Handle, c *Context, dat *VerifyTask) (task *Task, error *Erro
|
||||
return runner_data["task"].(*Task), nil
|
||||
}
|
||||
|
||||
func clearRunnerTask(x *Handle, runner_id string) {
|
||||
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
|
||||
var runner_data map[string]interface{} = runners[runner_id].(map[string]interface{})
|
||||
runner_data["task"] = nil
|
||||
runners[runner_id] = runner_data
|
||||
x.DataMap["runners"] = runners
|
||||
}
|
||||
|
||||
func handleRemoteRunner(x *Handle) {
|
||||
|
||||
type RegisterRunner struct {
|
||||
@ -202,6 +221,8 @@ func handleRemoteRunner(x *Handle) {
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
CleanUpFailed(c, task)
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
CleanUpFailedRetrain(c, task)
|
||||
case int(TASK_TYPE_CLASSIFICATION):
|
||||
// DO nothing
|
||||
default:
|
||||
@ -237,6 +258,8 @@ func handleRemoteRunner(x *Handle) {
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
status = DEFINITION_STATUS_INIT
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
fallthrough
|
||||
case int(TASK_TYPE_CLASSIFICATION):
|
||||
status = DEFINITION_STATUS_READY
|
||||
default:
|
||||
@ -268,27 +291,35 @@ func handleRemoteRunner(x *Handle) {
|
||||
return error
|
||||
}
|
||||
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
//DO NOTHING
|
||||
case int(TASK_TYPE_CLASSIFICATION):
|
||||
//DO NOTHING
|
||||
default:
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c, *task.ModelId)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get model information", err)
|
||||
}
|
||||
|
||||
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TO_TRAIN)
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
classes, err := model.GetClasses(c, "and status in ($2, $3) order by mc.class_order asc", CLASS_STATUS_TO_TRAIN, CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get the model classes", err)
|
||||
}
|
||||
|
||||
return c.SendJSON(classes)
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get the model classes", err)
|
||||
}
|
||||
return c.SendJSON(classes)
|
||||
case int(TASK_TYPE_CLASSIFICATION):
|
||||
classes, err := model.GetClasses(c, "and status=$2 order by mc.class_order asc", CLASS_STATUS_TRAINED)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get the model classes", err)
|
||||
}
|
||||
return c.SendJSON(classes)
|
||||
default:
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
type RunnerTrainDefStatus struct {
|
||||
@ -326,12 +357,13 @@ func handleRemoteRunner(x *Handle) {
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
type RunnerTrainDefLayers struct {
|
||||
type RunnerTrainDefHeadStatus struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
TaskId string `json:"taskId" validate:"required"`
|
||||
DefId string `json:"defId" validate:"required"`
|
||||
Status ModelHeadStatus `json:"status" validate:"required"`
|
||||
}
|
||||
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDefLayers) *Error {
|
||||
PostAuthJson(x, "/tasks/runner/train/def/head/status", User_Normal, func(c *Context, dat *RunnerTrainDefHeadStatus) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
@ -352,6 +384,69 @@ func handleRemoteRunner(x *Handle) {
|
||||
return c.E500M("Failed to get definition information", err)
|
||||
}
|
||||
|
||||
_, err = c.Exec("update exp_model_head set status=$1 where def_id=$2;", dat.Status, def.Id)
|
||||
if err != nil {
|
||||
log.Error("Failed to train definition!")
|
||||
return c.E500M("Failed to train definition", err)
|
||||
}
|
||||
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
type RunnerRetrainDefHeadStatus struct {
|
||||
Id string `json:"id" validate:"required"`
|
||||
TaskId string `json:"taskId" validate:"required"`
|
||||
HeadId string `json:"defId" validate:"required"`
|
||||
Status ModelHeadStatus `json:"status" validate:"required"`
|
||||
}
|
||||
PostAuthJson(x, "/tasks/runner/retrain/def/head/status", User_Normal, func(c *Context, dat *RunnerRetrainDefHeadStatus) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
if task.TaskType != int(TASK_TYPE_RETRAINING) {
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
|
||||
if err := UpdateStatus(c.GetDb(), "exp_model_head", dat.HeadId, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||
return c.E500M("Failed to update head status", err)
|
||||
}
|
||||
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
PostAuthJson(x, "/tasks/runner/train/def/layers", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
// Do nothing
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
// Do nothing
|
||||
default:
|
||||
c.Logger.Error("Task not is not the right type to get the layers", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the layers")
|
||||
}
|
||||
|
||||
def, err := GetDefinition(c, dat.DefId)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get definition information", err)
|
||||
}
|
||||
|
||||
layers, err := def.GetLayers(c, " order by layer_order asc")
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get layers", err)
|
||||
@ -371,7 +466,12 @@ func handleRemoteRunner(x *Handle) {
|
||||
return error
|
||||
}
|
||||
|
||||
if task.TaskType != int(TASK_TYPE_TRAINING) {
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
// DO nothing
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
// DO nothing
|
||||
default:
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
@ -486,6 +586,8 @@ func handleRemoteRunner(x *Handle) {
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
//DO NOTHING
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
//DO NOTHING
|
||||
case int(TASK_TYPE_CLASSIFICATION):
|
||||
//DO NOTHING
|
||||
default:
|
||||
@ -501,6 +603,74 @@ func handleRemoteRunner(x *Handle) {
|
||||
return c.SendJSON(model)
|
||||
})
|
||||
|
||||
PostAuthJson(x, "/tasks/runner/heads", User_Normal, func(c *Context, dat *RunnerTrainDef) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
task, error := verifyTask(x, c, &VerifyTask{Id: dat.Id, TaskId: dat.TaskId})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
type ExpHead struct {
|
||||
Id string `json:"id"`
|
||||
Start int `db:"range_start" json:"start"`
|
||||
End int `db:"range_end" json:"end"`
|
||||
}
|
||||
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
fallthrough
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
// status = 2 (INIT) 3 (TRAINING)
|
||||
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and status in (2,3)", dat.DefId)
|
||||
if err != nil {
|
||||
return c.E500M("Failed getting active heads", err)
|
||||
}
|
||||
return c.SendJSON(heads)
|
||||
case int(TASK_TYPE_CLASSIFICATION):
|
||||
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1", dat.DefId)
|
||||
if err != nil {
|
||||
return c.E500M("Failed getting active heads", err)
|
||||
}
|
||||
return c.SendJSON(heads)
|
||||
default:
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
})
|
||||
|
||||
PostAuthJson(x, "/tasks/runner/train_exp/class/status/train", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
task, error := verifyTask(x, c, dat)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
if task.TaskType != int(TASK_TYPE_TRAINING) {
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c, *task.ModelId)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to get model", err)
|
||||
}
|
||||
|
||||
err = SetModelClassStatus(c, CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TO_TRAIN)
|
||||
if err != nil {
|
||||
return c.E500M("Failed update status", err)
|
||||
}
|
||||
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
PostAuthJson(x, "/tasks/runner/train/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
@ -568,6 +738,13 @@ func handleRemoteRunner(x *Handle) {
|
||||
return c.E500M("Failed to delete unsed definitions", err)
|
||||
}
|
||||
|
||||
// Set the class status to trained
|
||||
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to set class status")
|
||||
return c.E500M("Failed to set class status", err)
|
||||
}
|
||||
|
||||
if err = model.UpdateStatus(c, READY); err != nil {
|
||||
model.UpdateStatus(c, FAILED_TRAINING)
|
||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
||||
@ -576,16 +753,7 @@ func handleRemoteRunner(x *Handle) {
|
||||
|
||||
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
||||
|
||||
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
|
||||
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
|
||||
runner_data["task"] = nil
|
||||
runners[dat.Id] = runner_data
|
||||
x.DataMap["runners"] = runners
|
||||
|
||||
clearRunnerTask(x, dat.Id)
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
@ -635,4 +803,153 @@ func handleRemoteRunner(x *Handle) {
|
||||
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
PostAuthJson(x, "/tasks/runner/train_exp/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
task, error := verifyTask(x, c, dat)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
if task.TaskType != int(TASK_TYPE_TRAINING) {
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c, *task.ModelId)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to get model", "err", err)
|
||||
return c.E500M("Failed to get mode", err)
|
||||
}
|
||||
|
||||
// TODO add check the to the model
|
||||
|
||||
var def Definition
|
||||
err = GetDBOnce(c, &def, "model_definition as md where model_id=$1 and status=$2 order by accuracy desc limit 1;", task.ModelId, DEFINITION_STATUS_TRANIED)
|
||||
if err == NotFoundError {
|
||||
// TODO Make the Model status have a message
|
||||
c.Logger.Error("All definitions failed to train!")
|
||||
model.UpdateStatus(c, FAILED_TRAINING)
|
||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "All definition failed to train!")
|
||||
clearRunnerTask(x, dat.Id)
|
||||
return c.SendJSON("Ok")
|
||||
} else if err != nil {
|
||||
model.UpdateStatus(c, FAILED_TRAINING)
|
||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to get model definition")
|
||||
return c.E500M("Failed to get model definition", err)
|
||||
}
|
||||
|
||||
if err = def.UpdateStatus(c, DEFINITION_STATUS_READY); err != nil {
|
||||
model.UpdateStatus(c, FAILED_TRAINING)
|
||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to update model definition")
|
||||
return c.E500M("Failed to update model definition", err)
|
||||
}
|
||||
|
||||
to_delete, err := GetDbMultitple[JustId](c, "model_definition where status!=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
|
||||
if err != nil {
|
||||
c.GetLogger().Error("Failed to select model_definition to delete")
|
||||
return c.E500M("Failed to select model definition to delete", err)
|
||||
}
|
||||
|
||||
for _, d := range to_delete {
|
||||
os.RemoveAll(path.Join("savedData", model.Id, "defs", d.Id))
|
||||
}
|
||||
|
||||
// TODO Check if returning also works here
|
||||
if _, err = c.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
|
||||
model.UpdateStatus(c, FAILED_TRAINING)
|
||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
||||
return c.E500M("Failed to delete unsed definitions", err)
|
||||
}
|
||||
|
||||
if err = SplitModel(c, model); err != nil {
|
||||
err = SetModelClassStatus(c, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to split the model! And Failed to set class status")
|
||||
return c.E500M("Failed to split the model", err)
|
||||
}
|
||||
|
||||
c.Logger.Error("Failed to split the model")
|
||||
return c.E500M("Failed to split the model", err)
|
||||
}
|
||||
|
||||
// Set the class status to trained
|
||||
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to set class status")
|
||||
return c.E500M("Failed to set class status", err)
|
||||
}
|
||||
|
||||
c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
|
||||
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
|
||||
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
|
||||
|
||||
if err = model.UpdateStatus(c, READY); err != nil {
|
||||
model.UpdateStatus(c, FAILED_TRAINING)
|
||||
task.UpdateStatusLog(c, TASK_FAILED_RUNNING, "Failed to delete unsed definitions")
|
||||
return c.E500M("Failed to update status of model", err)
|
||||
}
|
||||
|
||||
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
||||
|
||||
mutex := x.DataMap["runners_mutex"].(*sync.Mutex)
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
var runners map[string]interface{} = x.DataMap["runners"].(map[string]interface{})
|
||||
var runner_data map[string]interface{} = runners[dat.Id].(map[string]interface{})
|
||||
runner_data["task"] = nil
|
||||
runners[dat.Id] = runner_data
|
||||
x.DataMap["runners"] = runners
|
||||
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
PostAuthJson(x, "/tasks/runner/retrain/done", User_Normal, func(c *Context, dat *VerifyTask) *Error {
|
||||
_, error := verifyRunner(c, &JustId{Id: dat.Id})
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
task, error := verifyTask(x, c, dat)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
if task.TaskType != int(TASK_TYPE_RETRAINING) {
|
||||
c.Logger.Error("Task not is not the right type to get the definitions", "task type", task.TaskType)
|
||||
return c.JsonBadRequest("Task is not the right type go get the definitions")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c, *task.ModelId)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to get model", "err", err)
|
||||
return c.E500M("Failed to get mode", err)
|
||||
}
|
||||
|
||||
err = SetModelClassStatus(c, CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to set class status", err)
|
||||
}
|
||||
|
||||
_, err = c.Exec("update exp_model_head set status=$1 where status=$2 and model_id=$3", MODEL_HEAD_STATUS_READY, MODEL_HEAD_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to set head status", err)
|
||||
}
|
||||
|
||||
err = model.UpdateStatus(c, READY)
|
||||
if err != nil {
|
||||
return c.E500M("Failed to set class status", err)
|
||||
}
|
||||
|
||||
task.UpdateStatusLog(c, TASK_DONE, "Model finished training")
|
||||
clearRunnerTask(x, dat.Id)
|
||||
|
||||
return c.SendJSON("Ok")
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -120,6 +120,12 @@ func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Tas
|
||||
defer mutex.Unlock()
|
||||
|
||||
switch task.TaskType {
|
||||
case int(TASK_TYPE_RETRAINING):
|
||||
runners := handler.DataMap["runners"].(map[string]interface{})
|
||||
runner := runners[runner_id].(map[string]interface{})
|
||||
runner["task"] = &task
|
||||
runners[runner_id] = runner
|
||||
handler.DataMap["runners"] = runners
|
||||
case int(TASK_TYPE_TRAINING):
|
||||
if err := PrepareTraining(handler, base, task, runner_id); err != nil {
|
||||
logger.Error("Failed to prepare for training", "err", err)
|
||||
|
Loading…
Reference in New Issue
Block a user