From 2c3539b81ad2049716538ef71fd9d8fe05abcd86 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Thu, 19 Oct 2023 10:44:13 +0100 Subject: [PATCH] feat: closes #40 --- logic/models/train/train.go | 234 +++++++++++++++++++----------- sql/models.sql | 1 - views/models/edit.html | 43 ++++-- views/py/python_model_template.py | 11 +- 4 files changed, 184 insertions(+), 105 deletions(-) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 42d9487..d5d24dc 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -17,6 +17,9 @@ import ( . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) +const EPOCH_PER_RUN = 20; +const MAX_EPOCH = 100 + func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { id = "" rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy) @@ -34,12 +37,13 @@ func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string type ModelDefinitionStatus int const ( - MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3 - MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 - MODEL_DEFINITION_STATUS_INIT = 2 - MODEL_DEFINITION_STATUS_TRAINING = 3 - MODEL_DEFINITION_STATUS_TRANIED = 4 - MODEL_DEFINITION_STATUS_READY = 5 + MODEL_DEFINITION_STATUS_CANCELD_TRAINING = -4 + MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3 + MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 + MODEL_DEFINITION_STATUS_INIT = 2 + MODEL_DEFINITION_STATUS_TRAINING = 3 + MODEL_DEFINITION_STATUS_TRANIED = 4 + MODEL_DEFINITION_STATUS_READY = 5 ) type LayerType int @@ -104,7 +108,8 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e return } -func trainDefinition(c *Context, model *BaseModel, definition_id string) (accuracy float64, err error) { +func trainDefinition(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) { + c.Logger.Warn("About to start training definition") accuracy = 0 layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { @@ -153,14 +158,20 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura return } + // Copy result around + result_path := path.Join("savedData", model.Id, "defs", definition_id) + if err = tmpl.Execute(f, AnyMap{ - "Layers": got, - "Size": got[0].Shape, - "DataDir": path.Join(getDir(), "savedData", model.Id, "data"), - "RunPath": run_path, - "ColorMode": model.ImageMode, - "Model": model, - "DefId": definition_id, + "Layers": got, + "Size": got[0].Shape, + "DataDir": path.Join(getDir(), "savedData", model.Id, "data"), + "RunPath": run_path, + "ColorMode": model.ImageMode, + "Model": model, + "EPOCH_PER_RUN": EPOCH_PER_RUN, + "DefId": definition_id, + "LoadPrev": load_prev, + "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"), }); err != nil { return } @@ -168,13 +179,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura // Run the command out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output() if err != nil { - c.Logger.Debug(string(out)) + c.Logger.Debug(string(out)) return } - // Copy result around - result_path := path.Join("savedData", model.Id, "defs", definition_id) - if err = os.MkdirAll(result_path, os.ModePerm); err != nil { return } @@ -183,6 +191,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura return } + if err = exec.Command("cp", "-r", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run(); err != nil { + return + } + accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val")) if err != nil { return @@ -194,7 +206,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura return } - fmt.Println(string(accuracy_file_bytes)) + c.Logger.Info("Model finished training!", "accuracy", accuracy) accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64) if err != nil { @@ -205,8 +217,25 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura return } +func remove[T interface{}](lst []T, i int) []T { + lng := len(lst) + if i >= lng { + return []T{} + } + + if i+1 >= lng { + return lst[:lng-1] + } + + if i == 0 { + return lst[1:] + } + + return append(lst[:i], lst[i+1:]...) +} + func trainModel(c *Context, model *BaseModel) { - definitionsRows, err := c.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) + definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) if err != nil { c.Logger.Error("Failed to trainModel!Err:") c.Logger.Error(err) @@ -218,13 +247,14 @@ func trainModel(c *Context, model *BaseModel) { type row struct { id string target_accuracy int + epoch int } definitions := []row{} for definitionsRows.Next() { var rowv row - if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy); err != nil { + if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil { c.Logger.Error("Failed to train Model Could not read definition from db!Err:") c.Logger.Error(err) ModelUpdateStatus(c, model.Id, FAILED_TRAINING) @@ -239,30 +269,58 @@ func trainModel(c *Context, model *BaseModel) { return } - for _, def := range definitions { - ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) - accuracy, err := trainDefinition(c, model, def.id) - if err != nil { - c.Logger.Error("Failed to train definition!Err:") - c.Logger.Error(err) - ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) - continue + toTrain := len(definitions) + firstRound := true + var newDefinitions = []row{} + copy(newDefinitions, definitions) + for { + for i, def := range definitions { + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) + accuracy, err := trainDefinition(c, model, def.id, !firstRound) + if err != nil { + c.Logger.Error("Failed to train definition!Err:", "err", err) + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + toTrain = toTrain - 1 + newDefinitions = remove(newDefinitions, i) + continue + } + def.epoch += EPOCH_PER_RUN + + int_accuracy := int(accuracy * 100) + + if int_accuracy >= def.target_accuracy { + c.Logger.Info("Found a definition that reaches target_accuracy!") + _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id) + if err != nil { + c.Logger.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + _, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + if err != nil { + c.Logger.Error("Failed to train definition!Err:\n", "err", err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) + return + } + + toTrain = 0 + break + } + + if def.epoch > MAX_EPOCH { + fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy) + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + toTrain = toTrain - 1 + newDefinitions = remove(newDefinitions, i) + continue + } + } - - int_accuracy := int(accuracy * 100) - - if int_accuracy < def.target_accuracy { - fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy) - ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) - continue - } - - _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id) - if err != nil { - fmt.Printf("Failed to train definition!Err:\n") - fmt.Println(err) - ModelUpdateStatus(c, model.Id, FAILED_TRAINING) - return + copy(definitions, newDefinitions) + firstRound = false + if toTrain == 0 { + break } } @@ -335,7 +393,7 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) { } defer rows.Close() - base_path := path.Join("savedData", model.Id, "data") + base_path := path.Join("savedData", model.Id, "data") for rows.Next() { var dataPointId string @@ -343,13 +401,13 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) { if err != nil { return } - err = os.RemoveAll(path.Join(base_path, dataPointId + model.Format)) - if err != nil { - return - } + err = os.RemoveAll(path.Join(base_path, dataPointId+model.Format)) + if err != nil { + return + } } - _, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id) + _, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id) return } @@ -484,56 +542,56 @@ func handleTrain(handle *Handle) { }) handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { - // TODO check auth level + // TODO check auth level if c.Mode != NORMAL { - // This should only handle normal requests - c.Logger.Warn("This function only works with normal") - return c.UnsafeErrorCode(nil, 400, nil) + // This should only handle normal requests + c.Logger.Warn("This function only works with normal") + return c.UnsafeErrorCode(nil, 400, nil) } - f := r.URL.Query() + f := r.URL.Query() - if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") { - c.Logger.Warn("Invalid: model_id or definition or epoch") - return c.UnsafeErrorCode(nil, 400, nil) - } + if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") { + c.Logger.Warn("Invalid: model_id or definition or epoch") + return c.UnsafeErrorCode(nil, 400, nil) + } model_id := f.Get("model_id") def_id := f.Get("definition") - epoch, err := strconv.Atoi(f.Get("epoch")) - if err != nil { - c.Logger.Warn("Epoch is not a number") - // No need to improve message because this function is only called internaly - return c.UnsafeErrorCode(nil, 400, nil) - } + epoch, err := strconv.Atoi(f.Get("epoch")) + if err != nil { + c.Logger.Warn("Epoch is not a number") + // No need to improve message because this function is only called internaly + return c.UnsafeErrorCode(nil, 400, nil) + } - rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id) - if err != nil { - return c.Error500(err) - } - defer rows.Close() + rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id) + if err != nil { + return c.Error500(err) + } + defer rows.Close() - if !rows.Next() { - c.Logger.Error("Could not get status of model definition") - return c.Error500(nil) - } + if !rows.Next() { + c.Logger.Error("Could not get status of model definition") + return c.Error500(nil) + } - var status int - err = rows.Scan(&status) - if err != nil { - return c.Error500(err) - } + var status int + err = rows.Scan(&status) + if err != nil { + return c.Error500(err) + } - if status != 3 { - c.Logger.Warn("Definition not on status 3(training)", "status", status) - // No need to improve message because this function is only called internaly - return c.UnsafeErrorCode(nil, 400, nil) - } + if status != 3 { + c.Logger.Warn("Definition not on status 3(training)", "status", status) + // No need to improve message because this function is only called internaly + return c.UnsafeErrorCode(nil, 400, nil) + } - _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id) - if err != nil { - return c.Error500(err) - } + _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id) + if err != nil { + return c.Error500(err) + } return nil }) } diff --git a/sql/models.sql b/sql/models.sql index 943923e..ba72479 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -40,7 +40,6 @@ create table if not exists model_data_point ( status_message text ); --- drop table if exists model_definition; -- drop table if exists model_definition; create table if not exists model_definition ( id uuid primary key default gen_random_uuid(), diff --git a/views/models/edit.html b/views/models/edit.html index d180e6d..fbc5f5a 100644 --- a/views/models/edit.html +++ b/views/models/edit.html @@ -434,19 +434,36 @@ {{/* TODO improve this */}} Training the model...
{{/* TODO Add progress status on definitions */}} - {{ range .Defs}} -
-
- {{.Status}} -
-
- {{.EpochProgress}} -
-
- {{.Accuracy}} -
-
- {{ end }} + + + + + + + + + + {{ range .Defs}} + + + + + + {{ end }} + +
+ Status + + EpochProgress + + Accuracy +
+ {{.Status}} + + {{.EpochProgress}} + + {{.Accuracy}} +
{{/* TODO Add ability to stop training */}} {{/* Model Ready */}} diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 872fab7..0046bde 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -93,6 +93,10 @@ val_ds = list_ds.take(val_size) dataset = prepare_dataset(train_ds) dataset_validation = prepare_dataset(val_ds) + +{{ if .LoadPrev }} +model = tf.keras.saving.load_model('{{.LastModelRunPath}}') +{{ else }} model = keras.Sequential([ {{- range .Layers }} {{- if eq .LayerType 1}} @@ -106,13 +110,14 @@ model = keras.Sequential([ {{- end }} {{- end }} ]) +{{ end }} model.compile( loss=losses.SparseCategoricalCrossentropy(), optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy']) -his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()]) +his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()]) acc = his.history["accuracy"] @@ -120,6 +125,6 @@ f = open("accuracy.val", "w") f.write(str(acc[-1])) f.close() -tf.saved_model.save(model, "model") -# model.save("model.keras", save_format="tf") +tf.saved_model.save(model, "model") +model.save("model.keras")