Move model retraining to the task runners (closes #94)

This commit is contained in:
Andre Henriques 2024-04-16 17:48:52 +01:00
parent f182b205f8
commit 06642dcb1e
9 changed files with 165 additions and 96 deletions

View File

@ -1,14 +1,17 @@
PORT=5002 PORT = 5002
HOSTNAME="https://testing.andr3h3nriqu3s.com" HOSTNAME = "https://testing.andr3h3nriqu3s.com"
NUMBER_OF_WORKERS=20 NUMBER_OF_WORKERS = 20
SUPRESS_CUDA=1 SUPRESS_CUDA = 1
[ServiceUser] [ServiceUser]
USER="service" USER = "service"
[Worker] [Worker]
PULLING_TIME="500ms" PULLING_TIME = "500ms"
NUMBER_OF_WORKERS=1 NUMBER_OF_WORKERS = 1
[DB]
MAX_CONNECTIONS = 600

View File

@ -1,9 +1,9 @@
version: '3.1' version: "3.1"
services: services:
db: db:
image: docker.andr3h3nriqu3s.com/services/postgres image: docker.andr3h3nriqu3s.com/services/postgres
command: -c 'max_connections=400' command: -c 'max_connections=600'
restart: always restart: always
environment: environment:
POSTGRES_PASSWORD: verysafepassword POSTGRES_PASSWORD: verysafepassword

View File

@ -281,16 +281,17 @@ func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_pr
return return
} }
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) { func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
l, db := c.GetLogger(), c.GetDb()
var co struct { var co struct {
Count int `db:"count(*)"` Count int `db:"count(*)"`
} }
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING) err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
return return
} }
c.Logger.Info("test here", "count", co) l.Info("test here", "count", co)
count_re = co.Count count_re = co.Count
count := co.Count count := co.Count
@ -304,7 +305,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
return generateCvsExpandExp(c, run_path, model_id, offset, true) return generateCvsExpandExp(c, run_path, model_id, offset, true)
} }
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING) data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
return return
} }
@ -338,7 +339,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
// This is to load some extra data so that the model has more things to train on // This is to load some extra data so that the model has more things to train on
// //
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10) data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
if err != nil { if err != nil {
return return
} }
@ -361,10 +362,11 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
return return
} }
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) { func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0 accuracy = 0
c.Logger.Warn("About to retrain model") l := c.GetLogger()
l.Warn("About to retrain model")
// Get untrained models heads // Get untrained models heads
@ -375,7 +377,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
} }
// status = 2 (INIT) 3 (TRAINING) // status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id) heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil { if err != nil {
return return
} else if len(heads) == 0 { } else if len(heads) == 0 {
@ -389,13 +391,13 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
exp := heads[0] exp := heads[0]
c.Logger.Info("Got exp head", "head", exp) l.Info("Got exp head", "head", exp)
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil { if err = UpdateStatus(c.GetDb(), "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return return
} }
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) layers, err := c.GetDb().Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
return return
} }
@ -447,7 +449,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
LayerNum: i, LayerNum: i,
}) })
c.Logger.Info("Got layers", "layers", got) l.Info("Got layers", "layers", got)
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain") run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
@ -462,7 +464,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
return return
} }
c.Logger.Info("Generated cvs", "classCount", classCount) l.Info("Generated cvs", "classCount", classCount)
// TODO update the run script // TODO update the run script
@ -473,7 +475,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
} }
defer f.Close() defer f.Close()
c.Logger.Info("About to run python!") l.Info("About to run python!")
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py") tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
if err != nil { if err != nil {
@ -498,7 +500,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id), "SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
"Depth": classCount, "Depth": classCount,
"StartPoint": 0, "StartPoint": 0,
"Host": (*c.Handle).Config.Hostname, "Host": c.GetHost(),
}); err != nil { }); err != nil {
return return
} }
@ -506,11 +508,11 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
// Run the command // Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput() out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
if err != nil { if err != nil {
c.Logger.Warn("Python failed to run", "err", err, "out", string(out)) l.Warn("Python failed to run", "err", err, "out", string(out))
return return
} }
c.Logger.Info("Python finished running") l.Info("Python finished running")
if err = os.MkdirAll(result_path, os.ModePerm); err != nil { if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return return
@ -533,7 +535,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
} }
os.RemoveAll(run_path) os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy) l.Info("Model finished training!", "accuracy", accuracy)
return return
} }
@ -1555,10 +1557,10 @@ func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy
return nil return nil
} }
func ResetClasses(c *Context, model *BaseModel) { func ResetClasses(c BasePack, model *BaseModel) {
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id) _, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
if err != nil { if err != nil {
c.Logger.Error("Error while reseting the classes", "error", err) c.GetLogger().Error("Error while reseting the classes", "error", err)
} }
} }
@ -1620,44 +1622,6 @@ func trainExpandable(c *Context, model *BaseModel) {
ModelUpdateStatus(c, model.Id, READY) ModelUpdateStatus(c, model.Id, READY)
} }
// trainRetrain retrains the definition defId of model and, on success, moves
// the model to READY and marks every class that was in
// MODEL_CLASS_STATUS_TRAINING as MODEL_CLASS_STATUS_TRAINED.
//
// It reports nothing to the caller: every failure is only logged, the classes
// are reset through ResetClasses and the model is moved to
// READY_RETRAIN_FAILED.
func trainRetrain(c *Context, model *BaseModel, defId string) {
// err is shared with the failed closure below, so the closure logs whichever
// error was assigned last before it is invoked.
var err error
// failed rolls the model back after any error: classes are reset and the
// model is flagged as READY_RETRAIN_FAILED.
failed := func() {
ResetClasses(c, model)
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
return
}
// This is something I have to check
// The last argument (load_prev) is false.
acc, err := trainDefinitionExpandExp(c, model, defId, false)
if err != nil {
c.Logger.Error("Failed to retrain the model", "err", err)
failed()
return
}
c.Logger.Info("Retrained model", "accuracy", acc)
// TODO check accuracy
// The model is flagged READY before the classes are promoted; if the class
// update below fails, failed() overrides this with READY_RETRAIN_FAILED.
err = UpdateStatus(c, "models", model.Id, READY)
if err != nil {
failed()
return
}
c.Logger.Info("model updaded")
// Promote the classes that took part in this retraining run.
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
if err != nil {
c.Logger.Error("Error while updating the classes", "error", err)
failed()
return
}
}
func RunTaskTrain(b BasePack, task Task) (err error) { func RunTaskTrain(b BasePack, task Task) (err error) {
l := b.GetLogger() l := b.GetLogger()
@ -1718,6 +1682,62 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
return return
} }
// RunTaskRetrain executes a queued retraining task for the model task.ModelId.
//
// The model must be in the READY_RETRAIN status. The function looks up the
// model's definition id, retrains it through trainDefinitionExpandExp, moves
// the model to READY and marks every class that was in
// MODEL_CLASS_STATUS_TRAINING as MODEL_CLASS_STATUS_TRAINED. The task is moved
// to TASK_RUNNING while the work is in progress and to TASK_DONE at the end.
//
// On failure the task is always marked TASK_FAILED_RUNNING and the causing
// error is returned. If the failure happens after retraining has started, the
// classes are additionally reset through ResetClasses and the model is moved
// to READY_RETRAIN_FAILED. A model that could not be loaded, or that is not in
// READY_RETRAIN, is left untouched.
func RunTaskRetrain(b BasePack, task Task) (err error) {
	l := b.GetLogger()
	db := b.GetDb()

	model, err := GetBaseModel(db, task.ModelId)
	if err != nil {
		// Mark the task as failed so it is not left dangling in its queued state.
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to load the model for retraining")
		return err
	}
	if model.Status != READY_RETRAIN {
		// The model is owned by some other state; only the task is failed here.
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model in invalid status for re-training")
		return errors.New("Model in invalid status for re-training")
	}

	// failed rolls back a retraining attempt. It captures the named result
	// err, so it logs whichever error was assigned last before it is called.
	failed := func() {
		ResetClasses(b, model)
		ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining")
		l.Error("Failed to retrain", "err", err)
	}

	task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining")

	defId, err := GetDbVar[string](db, "md.id", "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
	if err != nil {
		failed()
		return err
	}

	// This is something I have to check
	// The last argument (load_prev) is false.
	acc, err := trainDefinitionExpandExp(b, model, *defId, false)
	if err != nil {
		failed()
		return err
	}

	l.Info("Retrained model", "accuracy", acc)

	// TODO check accuracy

	// The model is flagged READY before the classes are promoted; if the
	// class update below fails, failed() overrides this with
	// READY_RETRAIN_FAILED.
	err = UpdateStatus(db, "models", model.Id, READY)
	if err != nil {
		failed()
		return err
	}

	l.Info("Model updaded")

	// Promote the classes that took part in this retraining run.
	_, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
	if err != nil {
		l.Error("Error while updating the classes", "error", err)
		failed()
		return err
	}

	task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")

	return nil
}
func handleTrain(handle *Handle) { func handleTrain(handle *Handle) {
type TrainReq struct { type TrainReq struct {
@ -1899,17 +1919,29 @@ func handleTrain(handle *Handle) {
return failed() return failed()
} }
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id) _, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil { if err != nil {
fmt.Println("Failed to update model status") return c.E500M("Failed to update model status", err)
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
} }
return c.SendJSON(model.Id) newTask := struct {
UserId string `db:"user_id"`
ModelId string `db:"model_id"`
TaskType TaskType `db:"task_type"`
Status int `db:"status"`
}{
UserId: c.User.Id,
ModelId: model.Id,
TaskType: TASK_TYPE_RETRAINING,
Status: 1,
}
id, err := InsertReturnId(c, &newTask, "tasks", "id")
if err != nil {
return c.E500M("Failed to create task", err)
}
return c.SendJSON(JustId{Id: id})
}) })
handle.Get("/model/epoch/update", func(c *Context) *Error { handle.Get("/model/epoch/update", func(c *Context) *Error {

View File

@ -63,6 +63,14 @@ func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_c
logger.Error("Failed to tain the model", "error", err) logger.Error("Failed to tain the model", "error", err)
} }
back_channel <- index
continue
} else if task.TaskType == int(TASK_TYPE_RETRAINING) {
logger.Info("Retraining Task")
if err = RunTaskRetrain(base, task); err != nil {
logger.Error("Failed to tain the model", "error", err)
}
back_channel <- index back_channel <- index
continue continue
} }

View File

@ -38,6 +38,7 @@ type TaskType int
const ( const (
TASK_TYPE_CLASSIFICATION TaskType = 1 + iota TASK_TYPE_CLASSIFICATION TaskType = 1 + iota
TASK_TYPE_TRAINING TASK_TYPE_TRAINING
TASK_TYPE_RETRAINING
) )
func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) { func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) {

View File

@ -21,6 +21,10 @@ type ServiceUser struct {
UserId string `toml:"__user__id__"` UserId string `toml:"__user__id__"`
} }
// DbInfo holds the database settings decoded from the "DB" table of the
// configuration (see the DbInfo field of Config).
type DbInfo struct {
// MaxConnections is handed to db.SetMaxOpenConns at startup; LoadConfig
// defaults it to 200.
// NOTE(review): the sample config spells the key MAX_CONNECTIONS while the
// tag is lower-case — confirm the TOML decoder matches keys case-insensitively.
MaxConnections int `toml:"max_connections"`
}
type Config struct { type Config struct {
Hostname string Hostname string
Port int Port int
@ -30,6 +34,8 @@ type Config struct {
GpuWorker WorkerConfig `toml:"Worker"` GpuWorker WorkerConfig `toml:"Worker"`
ServiceUser ServiceUser `toml:"ServiceUser"` ServiceUser ServiceUser `toml:"ServiceUser"`
DbInfo DbInfo `toml:"DB"`
} }
func LoadConfig() Config { func LoadConfig() Config {
@ -52,6 +58,9 @@ func LoadConfig() Config {
User: "Service", User: "Service",
UserId: "", UserId: "",
}, },
DbInfo: DbInfo{
MaxConnections: 200,
},
} }
} }

View File

@ -39,6 +39,8 @@ func main() {
log.Info("Config loaded!", "config", config) log.Info("Config loaded!", "config", config)
config.GenerateToken(db) config.GenerateToken(db)
db.SetMaxOpenConns(config.DbInfo.MaxConnections)
StartRunners(db, config) StartRunners(db, config)
//TODO check if file structure exists to save data //TODO check if file structure exists to save data

View File

@ -33,8 +33,10 @@
$effect(() => { $effect(() => {
if (data && ctx) { if (data && ctx) {
if (chart) { if (chart) {
console.log('update'); chart.destroy();
chart.data = { chart = new Chart(ctx, {
type: 'bar',
data: {
labels: data.map((a) => a.name), labels: data.map((a) => a.name),
datasets: [ datasets: [
{ {
@ -46,10 +48,20 @@
data: data.map((a) => a.testing) data: data.map((a) => a.testing)
} }
] ]
}; },
chart.update('resize'); options: {
animation: false,
scales: {
x: {
stacked: true
},
y: {
stacked: true
}
}
}
});
} else { } else {
console.log('create');
chart = new Chart(ctx, { chart = new Chart(ctx, {
type: 'bar', type: 'bar',
data: { data: {

View File

@ -85,6 +85,8 @@
Image Run Image Run
{:else if task.type == 2} {:else if task.type == 2}
Model training Model training
{:else if task.type == 3}
Model Re-training
{:else} {:else}
{task.type} {task.type}
{/if} {/if}
@ -98,7 +100,7 @@
width="30px" width="30px"
style="object-fit: contain;" style="object-fit: contain;"
/> />
{:else if [2].includes(task.type)}{:else} {:else if [2, 3].includes(task.type)}{:else}
TODO Show more information {task.status} TODO Show more information {task.status}
{/if} {/if}
</td> </td>
@ -117,7 +119,7 @@
{:else} {:else}
- -
{/if} {/if}
{:else if [2].includes(task.type)}{:else} {:else if [2, 3].includes(task.type)}{:else}
TODO Handle {task.type} TODO Handle {task.type}
{/if} {/if}
</td> </td>