diff --git a/logic/models/edit.go b/logic/models/edit.go index ab3b4bd..c9ece01 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -92,7 +92,34 @@ func handleEdit(handle *Handle) { "Model": model, })) case TRAINING: - fallthrough + + type defrow struct { + Status int + EpochProgress int + Accuracy int + } + + def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id) + if err != nil { + return c.Error500(err) + } + defer def_rows.Close() + + defs := []defrow{} + + for def_rows.Next() { + var def defrow + err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy) + if err != nil { + return c.Error500(err) + } + defs = append(defs, def) + } + + LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ + "Model": model, + "Defs": defs, + })) case PREPARING_ZIP_FILE: LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ "Model": model, diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 076f843..42d9487 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura "RunPath": run_path, "ColorMode": model.ImageMode, "Model": model, + "DefId": definition_id, }); err != nil { return } @@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) { } for _, def := range definitions { + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) accuracy, err := trainDefinition(c, model, def.id) if err != nil { c.Logger.Error("Failed to train definition!Err:") @@ -480,4 +482,58 @@ func handleTrain(handle *Handle) { Redirect("/models/edit?id="+model.Id, c.Mode, w, r) return nil }) + + handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + // TODO check auth level + if c.Mode != NORMAL { + // This should only handle normal requests + c.Logger.Warn("This function only works with normal") + return c.UnsafeErrorCode(nil, 400, nil) + } + + f := r.URL.Query() + + if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") { + c.Logger.Warn("Invalid: model_id or definition or epoch") + return c.UnsafeErrorCode(nil, 400, nil) + } + + model_id := f.Get("model_id") + def_id := f.Get("definition") + epoch, err := strconv.Atoi(f.Get("epoch")) + if err != nil { + c.Logger.Warn("Epoch is not a number") + // No need to improve message because this function is only called internaly + return c.UnsafeErrorCode(nil, 400, nil) + } + + rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id) + if err != nil { + return c.Error500(err) + } + defer rows.Close() + + if !rows.Next() { + c.Logger.Error("Could not get status of model definition") + return c.Error500(nil) + } + + var status int + err = rows.Scan(&status) + if err != nil { + return c.Error500(err) + } + + if status != 3 { + c.Logger.Warn("Definition not on status 3(training)", "status", status) + // No need to improve message because this function is only called internaly + return c.UnsafeErrorCode(nil, 400, nil) + } + + _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id) + if err != nil { + return c.Error500(err) + } + return nil + }) } diff --git a/logic/utils/handler.go b/logic/utils/handler.go index bd9ea3a..0d4ebdf 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error { return &Error{code, nil, c.AddMap(data)} } +func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error { + if err != nil { + c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code) + c.Logger.Error(err) + } + return &Error{code, nil, c.AddMap(data)} +} + func (c Context) Error500(err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, nil) } @@ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) var token *string + logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportTimestamp: true, + TimeFormat: time.Kitchen, + Prefix: r.URL.Path, + }) + for _, r := range r.Cookies() { if r.Name == "auth" { token = &r.Value @@ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) if token == nil { return &Context{ Mode: mode, + Logger: logger, + Db: handler.Db, }, nil } @@ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) return nil, errors.Join(err, LogoffError) } - logger := log.NewWithOptions(os.Stdout, log.Options{ - ReportTimestamp: true, - TimeFormat: time.Kitchen, - Prefix: r.URL.Path, - }) - return &Context{token, user, mode, logger, handler.Db}, nil } diff --git a/sql/models.sql b/sql/models.sql index 3dae684..943923e 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -55,7 +55,8 @@ create table if not exists model_definition ( -- 4: Tranied -- 5: Ready status integer default 1, - created_on timestamp default current_timestamp + created_on timestamp default current_timestamp, + epoch_progress integer default 0 ); -- drop table if exists model_definition_layer; diff --git a/views/models/edit.html b/views/models/edit.html index fb42885..d180e6d 100644 --- a/views/models/edit.html +++ b/views/models/edit.html @@ -434,7 +434,20 @@ {{/* TODO improve this */}} Training the model...
{{/* TODO Add progress status on definitions */}} - {{/* TODO Add aility to stop training */}} + {{ range .Defs}} +
+
+ {{.Status}} +
+
+ {{.EpochProgress}} +
+
+ {{.Accuracy}} +
+
+ {{ end }} + {{/* TODO Add ability to stop training */}} {{/* Model Ready */}} {{ else if (eq .Model.Status 5)}} diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 6edcaf2..872fab7 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -4,6 +4,14 @@ import pandas as pd from tensorflow import keras from tensorflow.data import AUTOTUNE from keras import layers, losses, optimizers +import requests + +class NotifyServerCallback(tf.keras.callbacks.Callback): + def on_epoch_begin(self, epoch, *args, **kwargs): + if (epoch % 5) == 0: + # TODO change this + requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}') + DATA_DIR = "{{ .DataDir }}" image_size = ({{ .Size }}) @@ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/" #based on https://www.tensorflow.org/tutorials/load_data/images def pathToLabel(path): path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") + {{ if eq .Model.Format "png" }} + path = tf.strings.regex_replace(path, ".png", "") + {{ else if eq .Model.Format "jpeg" }} path = tf.strings.regex_replace(path, ".jpg", "") path = tf.strings.regex_replace(path, ".jpeg", "") - path = tf.strings.regex_replace(path, ".png", "") + {{ else }} + ERROR + {{ end }} return table.lookup(tf.strings.as_string([path])) - #return tf.strings.as_string([path]) def decode_image(img): {{ if eq .Model.Format "png" }} @@ -100,7 +112,7 @@ model.compile( optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy']) -his = model.fit(dataset, validation_data= dataset_validation, epochs=50) +his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()]) acc = his.history["accuracy"]