From 06642dcb1e635f5d787da0f32299008539eae095 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Tue, 16 Apr 2024 17:48:52 +0100 Subject: [PATCH] move model retrain to runners closes #94 --- config.toml | 17 +- docker-compose.yml | 6 +- logic/models/train/train.go | 164 +++++++++++------- logic/tasks/runner/runner.go | 8 + logic/tasks/utils/utils.go | 1 + logic/utils/config.go | 13 +- main.go | 4 +- .../edit/ModelDataPageStatsGraph.svelte | 42 +++-- .../src/routes/models/edit/TasksTable.svelte | 6 +- 9 files changed, 165 insertions(+), 96 deletions(-) diff --git a/config.toml b/config.toml index 4d604d5..494654a 100644 --- a/config.toml +++ b/config.toml @@ -1,14 +1,17 @@ -PORT=5002 +PORT = 5002 -HOSTNAME="https://testing.andr3h3nriqu3s.com" +HOSTNAME = "https://testing.andr3h3nriqu3s.com" -NUMBER_OF_WORKERS=20 +NUMBER_OF_WORKERS = 20 -SUPRESS_CUDA=1 +SUPRESS_CUDA = 1 [ServiceUser] -USER="service" +USER = "service" [Worker] -PULLING_TIME="500ms" -NUMBER_OF_WORKERS=1 +PULLING_TIME = "500ms" +NUMBER_OF_WORKERS = 1 + +[DB] +MAX_CONNECTIONS = 600 diff --git a/docker-compose.yml b/docker-compose.yml index d0ce937..9909955 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,11 +1,11 @@ -version: '3.1' +version: "3.1" services: db: image: docker.andr3h3nriqu3s.com/services/postgres - command: -c 'max_connections=400' + command: -c 'max_connections=600' restart: always environment: POSTGRES_PASSWORD: verysafepassword ports: - - "5432:5432" + - "5432:5432" diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 75570e9..af51834 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -281,16 +281,17 @@ func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_pr return } -func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) { +func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) { + l, db := c.GetLogger(), c.GetDb() var co struct { Count int `db:"count(*)"` } - err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING) + err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING) if err != nil { return } - c.Logger.Info("test here", "count", co) + l.Info("test here", "count", co) count_re = co.Count count := co.Count @@ -304,7 +305,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i return generateCvsExpandExp(c, run_path, model_id, offset, true) } - data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING) + data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING) if err != nil { return } @@ -338,7 +339,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i // This is to load some extra data so that the model has more things to train on // - data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10) + data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10) if err != nil { return } @@ -361,10 +362,11 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i return } -func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) { +func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) { accuracy = 0 - c.Logger.Warn("About to retrain model") + l := c.GetLogger() + l.Warn("About to retrain model") // Get untrained models heads @@ -375,7 +377,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string } // status = 2 (INIT) 3 (TRAINING) - heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id) + heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id) if err != nil { return } else if len(heads) == 0 { @@ -389,13 +391,13 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string exp := heads[0] - c.Logger.Info("Got exp head", "head", exp) + l.Info("Got exp head", "head", exp) - if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil { + if err = UpdateStatus(c.GetDb(), "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil { return } - layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) + layers, err := c.GetDb().Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { return } @@ -447,7 +449,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string LayerNum: i, }) - c.Logger.Info("Got layers", "layers", got) + l.Info("Got layers", "layers", got) // Generate run folder run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain") @@ -462,7 +464,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string return } - c.Logger.Info("Generated cvs", "classCount", classCount) + l.Info("Generated cvs", "classCount", classCount) // TODO update the run script @@ -473,7 +475,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string } defer f.Close() - c.Logger.Info("About to run python!") + l.Info("About to run python!") tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py") if err != nil { @@ -498,7 +500,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string "SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id), "Depth": classCount, "StartPoint": 0, - "Host": (*c.Handle).Config.Hostname, + "Host": c.GetHost(), }); err != nil { return } @@ -506,11 +508,11 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string // Run the command out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput() if err != nil { - c.Logger.Warn("Python failed to run", "err", err, "out", string(out)) + l.Warn("Python failed to run", "err", err, "out", string(out)) return } - c.Logger.Info("Python finished running") + l.Info("Python finished running") if err = os.MkdirAll(result_path, os.ModePerm); err != nil { return @@ -533,7 +535,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string } os.RemoveAll(run_path) - c.Logger.Info("Model finished training!", "accuracy", accuracy) + l.Info("Model finished training!", "accuracy", accuracy) return } @@ -1555,10 +1557,10 @@ func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy return nil } -func ResetClasses(c *Context, model *BaseModel) { - _, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id) +func ResetClasses(c BasePack, model *BaseModel) { + _, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id) if err != nil { - c.Logger.Error("Error while reseting the classes", "error", err) + c.GetLogger().Error("Error while reseting the classes", "error", err) } } @@ -1620,44 +1622,6 @@ func trainExpandable(c *Context, model *BaseModel) { ModelUpdateStatus(c, model.Id, READY) } -func trainRetrain(c *Context, model *BaseModel, defId string) { - var err error - - failed := func() { - ResetClasses(c, model) - ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED) - c.Logger.Error("Failed to retrain", "err", err) - return - } - - // This is something I have to check - acc, err := trainDefinitionExpandExp(c, model, defId, false) - if err != nil { - c.Logger.Error("Failed to retrain the model", "err", err) - failed() - return - - } - c.Logger.Info("Retrained model", "accuracy", acc) - - // TODO check accuracy - - err = UpdateStatus(c, "models", model.Id, READY) - if err != nil { - failed() - return - } - - c.Logger.Info("model updaded") - - _, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id) - if err != nil { - c.Logger.Error("Error while updating the classes", "error", err) - failed() - return - } -} - func RunTaskTrain(b BasePack, task Task) (err error) { l := b.GetLogger() @@ -1718,6 +1682,62 @@ func RunTaskTrain(b BasePack, task Task) (err error) { return } +func RunTaskRetrain(b BasePack, task Task) (err error) { + model, err := GetBaseModel(b.GetDb(), task.ModelId) + if err != nil { + return err + } else if model.Status != READY_RETRAIN { + return errors.New("Model in invalid status for re-training") + } + + l := b.GetLogger() + db := b.GetDb() + + failed := func() { + ResetClasses(b, model) + ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED) + task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining") + l.Error("Failed to retrain", "err", err) + } + + task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining") + + defId, err := GetDbVar[string](db, "md.id", "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId) + if err != nil { + failed() + return + } + + // This is something I have to check + acc, err := trainDefinitionExpandExp(b, model, *defId, false) + if err != nil { + failed() + return + } + + l.Info("Retrained model", "accuracy", acc) + + // TODO check accuracy + err = UpdateStatus(db, "models", model.Id, READY) + if err != nil { + failed() + return + } + + l.Info("Model updaded") + + _, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id) + if err != nil { + l.Error("Error while updating the classes", "error", err) + failed() + return + } + + task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining") + + return +} + func handleTrain(handle *Handle) { type TrainReq struct { @@ -1899,17 +1919,29 @@ func handleTrain(handle *Handle) { return failed() } - go trainRetrain(c, model, def.Id) - _, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id) if err != nil { - fmt.Println("Failed to update model status") - fmt.Println(err) - // TODO improve this response - return c.Error500(err) + return c.E500M("Failed to update model status", err) } - return c.SendJSON(model.Id) + newTask := struct { + UserId string `db:"user_id"` + ModelId string `db:"model_id"` + TaskType TaskType `db:"task_type"` + Status int `db:"status"` + }{ + UserId: c.User.Id, + ModelId: model.Id, + TaskType: TASK_TYPE_RETRAINING, + Status: 1, + } + + id, err := InsertReturnId(c, &newTask, "tasks", "id") + if err != nil { + return c.E500M("Failed to create task", err) + } + + return c.SendJSON(JustId{Id: id}) }) handle.Get("/model/epoch/update", func(c *Context) *Error { diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go index 5a48b44..6f5737a 100644 --- a/logic/tasks/runner/runner.go +++ b/logic/tasks/runner/runner.go @@ -63,6 +63,14 @@ func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_c logger.Error("Failed to tain the model", "error", err) } + back_channel <- index + continue + } else if task.TaskType == int(TASK_TYPE_RETRAINING) { + logger.Info("Retraining Task") + if err = RunTaskRetrain(base, task); err != nil { + logger.Error("Failed to tain the model", "error", err) + } + back_channel <- index continue } diff --git a/logic/tasks/utils/utils.go b/logic/tasks/utils/utils.go index 9f7f3d9..bdeae3b 100644 --- a/logic/tasks/utils/utils.go +++ b/logic/tasks/utils/utils.go @@ -38,6 +38,7 @@ type TaskType int const ( TASK_TYPE_CLASSIFICATION TaskType = 1 + iota TASK_TYPE_TRAINING + TASK_TYPE_RETRAINING ) func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) { diff --git a/logic/utils/config.go b/logic/utils/config.go index ea1f62b..4788309 100644 --- a/logic/utils/config.go +++ b/logic/utils/config.go @@ -21,6 +21,10 @@ type ServiceUser struct { UserId string `toml:"__user__id__"` } +type DbInfo struct { + MaxConnections int `toml:"max_connections"` +} + type Config struct { Hostname string Port int @@ -30,6 +34,8 @@ type Config struct { GpuWorker WorkerConfig `toml:"Worker"` ServiceUser ServiceUser `toml:"ServiceUser"` + + DbInfo DbInfo `toml:"DB"` } func LoadConfig() Config { @@ -52,6 +58,9 @@ func LoadConfig() Config { User: "Service", UserId: "", }, + DbInfo: DbInfo{ + MaxConnections: 200, + }, } } @@ -84,13 +93,13 @@ func (c *Config) GenerateToken(db *sql.DB) { Email string Salt string Password string - UserType UserType `db:"user_type"` + UserType UserType `db:"user_type"` }{ c.ServiceUser.User, c.ServiceUser.User, "", "", - User_Admin, + User_Admin, } id, err := InsertReturnId(db, &newUser, "users", "id") if err != nil { diff --git a/main.go b/main.go index fe44f34..a1c24e5 100644 --- a/main.go +++ b/main.go @@ -37,7 +37,9 @@ func main() { config := LoadConfig() log.Info("Config loaded!", "config", config) - config.GenerateToken(db) + config.GenerateToken(db) + + db.SetMaxOpenConns(config.DbInfo.MaxConnections) StartRunners(db, config) diff --git a/webpage/src/routes/models/edit/ModelDataPageStatsGraph.svelte b/webpage/src/routes/models/edit/ModelDataPageStatsGraph.svelte index 651556d..ac55835 100644 --- a/webpage/src/routes/models/edit/ModelDataPageStatsGraph.svelte +++ b/webpage/src/routes/models/edit/ModelDataPageStatsGraph.svelte @@ -33,23 +33,35 @@ $effect(() => { if (data && ctx) { if (chart) { - console.log('update'); - chart.data = { - labels: data.map((a) => a.name), - datasets: [ - { - label: 'Training', - data: data.map((a) => a.training) - }, - { - label: 'Testing', - data: data.map((a) => a.testing) + chart.destroy(); + chart = new Chart(ctx, { + type: 'bar', + data: { + labels: data.map((a) => a.name), + datasets: [ + { + label: 'Training', + data: data.map((a) => a.training) + }, + { + label: 'Testing', + data: data.map((a) => a.testing) + } + ] + }, + options: { + animation: false, + scales: { + x: { + stacked: true + }, + y: { + stacked: true + } } - ] - }; - chart.update('resize'); + } + }); } else { - console.log('create'); chart = new Chart(ctx, { type: 'bar', data: { diff --git a/webpage/src/routes/models/edit/TasksTable.svelte b/webpage/src/routes/models/edit/TasksTable.svelte index bdd24bb..eadd0fd 100644 --- a/webpage/src/routes/models/edit/TasksTable.svelte +++ b/webpage/src/routes/models/edit/TasksTable.svelte @@ -85,6 +85,8 @@ Image Run {:else if task.type == 2} Model training + {:else if task.type == 3} + Model Re-training {:else} {task.type} {/if} @@ -98,7 +100,7 @@ width="30px" style="object-fit: contain;" /> - {:else if [2].includes(task.type)}{:else} + {:else if [2, 3].includes(task.type)}{:else} TODO Show more information {task.status} {/if} @@ -117,7 +119,7 @@ {:else} - {/if} - {:else if [2].includes(task.type)}{:else} + {:else if [2, 3].includes(task.type)}{:else} TODO Handle {task.type} {/if}