diff --git a/logic/models/edit.go b/logic/models/edit.go
index ab3b4bd..c9ece01 100644
--- a/logic/models/edit.go
+++ b/logic/models/edit.go
@@ -92,7 +92,34 @@ func handleEdit(handle *Handle) {
"Model": model,
}))
case TRAINING:
- fallthrough
+
+ type defrow struct {
+ Status int
+ EpochProgress int
+ Accuracy int
+ }
+
+ def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer def_rows.Close()
+
+ defs := []defrow{}
+
+ for def_rows.Next() {
+ var def defrow
+ err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defs = append(defs, def)
+ }
+
+ LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
+ "Model": model,
+ "Defs": defs,
+ }))
case PREPARING_ZIP_FILE:
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model,
diff --git a/logic/models/train/train.go b/logic/models/train/train.go
index 076f843..42d9487 100644
--- a/logic/models/train/train.go
+++ b/logic/models/train/train.go
@@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
+ "DefId": definition_id,
}); err != nil {
return
}
@@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
}
for _, def := range definitions {
+ ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinition(c, model, def.id)
if err != nil {
c.Logger.Error("Failed to train definition!Err:")
@@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
return nil
})
+
+ handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
+ // TODO check auth level
+ if c.Mode != NORMAL {
+ // This should only handle normal requests
+ c.Logger.Warn("This function only works with normal")
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
+
+ f := r.URL.Query()
+
+ if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
+ c.Logger.Warn("Invalid: model_id or definition or epoch")
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
+
+ model_id := f.Get("model_id")
+ def_id := f.Get("definition")
+ epoch, err := strconv.Atoi(f.Get("epoch"))
+ if err != nil {
+ c.Logger.Warn("Epoch is not a number")
+ // No need to improve message because this function is only called internaly
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
+
+ rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ c.Logger.Error("Could not get status of model definition")
+ return c.Error500(nil)
+ }
+
+ var status int
+ err = rows.Scan(&status)
+ if err != nil {
+ return c.Error500(err)
+ }
+
+ if status != 3 {
+ c.Logger.Warn("Definition not on status 3(training)", "status", status)
+ // No need to improve message because this function is only called internaly
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
+
+ _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ return nil
+ })
}
diff --git a/logic/utils/handler.go b/logic/utils/handler.go
index bd9ea3a..0d4ebdf 100644
--- a/logic/utils/handler.go
+++ b/logic/utils/handler.go
@@ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error {
return &Error{code, nil, c.AddMap(data)}
}
+func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error {
+ if err != nil {
+ c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
+ c.Logger.Error(err)
+ }
+ return &Error{code, nil, c.AddMap(data)}
+}
+
func (c Context) Error500(err error) *Error {
return c.ErrorCode(err, http.StatusInternalServerError, nil)
}
@@ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
var token *string
+ logger := log.NewWithOptions(os.Stdout, log.Options{
+ ReportTimestamp: true,
+ TimeFormat: time.Kitchen,
+ Prefix: r.URL.Path,
+ })
+
for _, r := range r.Cookies() {
if r.Name == "auth" {
token = &r.Value
@@ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
if token == nil {
return &Context{
Mode: mode,
+ Logger: logger,
+ Db: handler.Db,
}, nil
}
@@ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
return nil, errors.Join(err, LogoffError)
}
- logger := log.NewWithOptions(os.Stdout, log.Options{
- ReportTimestamp: true,
- TimeFormat: time.Kitchen,
- Prefix: r.URL.Path,
- })
-
return &Context{token, user, mode, logger, handler.Db}, nil
}
diff --git a/sql/models.sql b/sql/models.sql
index 3dae684..943923e 100644
--- a/sql/models.sql
+++ b/sql/models.sql
@@ -55,7 +55,8 @@ create table if not exists model_definition (
-- 4: Tranied
-- 5: Ready
status integer default 1,
- created_on timestamp default current_timestamp
+ created_on timestamp default current_timestamp,
+ epoch_progress integer default 0
);
-- drop table if exists model_definition_layer;
diff --git a/views/models/edit.html b/views/models/edit.html
index fb42885..d180e6d 100644
--- a/views/models/edit.html
+++ b/views/models/edit.html
@@ -434,7 +434,20 @@
{{/* TODO improve this */}}
Training the model...
{{/* TODO Add progress status on definitions */}}
- {{/* TODO Add aility to stop training */}}
+ {{ range .Defs}}
+