feat: closes #39
This commit is contained in:
@@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
||||
"RunPath": run_path,
|
||||
"ColorMode": model.ImageMode,
|
||||
"Model": model,
|
||||
"DefId": definition_id,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
|
||||
}
|
||||
|
||||
for _, def := range definitions {
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
accuracy, err := trainDefinition(c, model, def.id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to train definition!Err:")
|
||||
@@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
|
||||
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
|
||||
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
// TODO check auth level
|
||||
if c.Mode != NORMAL {
|
||||
// This should only handle normal requests
|
||||
c.Logger.Warn("This function only works with normal")
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
f := r.URL.Query()
|
||||
|
||||
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
|
||||
c.Logger.Warn("Invalid: model_id or definition or epoch")
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
model_id := f.Get("model_id")
|
||||
def_id := f.Get("definition")
|
||||
epoch, err := strconv.Atoi(f.Get("epoch"))
|
||||
if err != nil {
|
||||
c.Logger.Warn("Epoch is not a number")
|
||||
// No need to improve message because this function is only called internaly
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
c.Logger.Error("Could not get status of model definition")
|
||||
return c.Error500(nil)
|
||||
}
|
||||
|
||||
var status int
|
||||
err = rows.Scan(&status)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
if status != 3 {
|
||||
c.Logger.Warn("Definition not on status 3(training)", "status", status)
|
||||
// No need to improve message because this function is only called internaly
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user