feat: closes #39

This commit is contained in:
2023-10-12 12:08:12 +01:00
parent c7c6cfcd00
commit f163e25fba
6 changed files with 131 additions and 12 deletions

View File

@@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"DefId": definition_id,
}); err != nil {
return
}
@@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
}
for _, def := range definitions {
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinition(c, model, def.id)
if err != nil {
c.Logger.Error("Failed to train definition!Err:")
@@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
return nil
})
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
// TODO check auth level
if c.Mode != NORMAL {
// This should only handle normal requests
c.Logger.Warn("This function only works with normal")
return c.UnsafeErrorCode(nil, 400, nil)
}
f := r.URL.Query()
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
c.Logger.Warn("Invalid: model_id or definition or epoch")
return c.UnsafeErrorCode(nil, 400, nil)
}
model_id := f.Get("model_id")
def_id := f.Get("definition")
epoch, err := strconv.Atoi(f.Get("epoch"))
if err != nil {
c.Logger.Warn("Epoch is not a number")
// No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil)
}
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
if err != nil {
return c.Error500(err)
}
defer rows.Close()
if !rows.Next() {
c.Logger.Error("Could not get status of model definition")
return c.Error500(nil)
}
var status int
err = rows.Scan(&status)
if err != nil {
return c.Error500(err)
}
if status != 3 {
c.Logger.Warn("Definition not on status 3(training)", "status", status)
// No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil)
}
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
if err != nil {
return c.Error500(err)
}
return nil
})
}