feat: closes #39

This commit is contained in:
Andre Henriques 2023-10-12 12:08:12 +01:00
parent c7c6cfcd00
commit f163e25fba
6 changed files with 131 additions and 12 deletions

View File

@ -92,7 +92,34 @@ func handleEdit(handle *Handle) {
"Model": model, "Model": model,
})) }))
case TRAINING: case TRAINING:
fallthrough
type defrow struct {
Status int
EpochProgress int
Accuracy int
}
def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)
if err != nil {
return c.Error500(err)
}
defer def_rows.Close()
defs := []defrow{}
for def_rows.Next() {
var def defrow
err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy)
if err != nil {
return c.Error500(err)
}
defs = append(defs, def)
}
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model,
"Defs": defs,
}))
case PREPARING_ZIP_FILE: case PREPARING_ZIP_FILE:
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model, "Model": model,

View File

@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
"RunPath": run_path, "RunPath": run_path,
"ColorMode": model.ImageMode, "ColorMode": model.ImageMode,
"Model": model, "Model": model,
"DefId": definition_id,
}); err != nil { }); err != nil {
return return
} }
@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
} }
for _, def := range definitions { for _, def := range definitions {
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinition(c, model, def.id) accuracy, err := trainDefinition(c, model, def.id)
if err != nil { if err != nil {
c.Logger.Error("Failed to train definition!Err:") c.Logger.Error("Failed to train definition!Err:")
@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
Redirect("/models/edit?id="+model.Id, c.Mode, w, r) Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
return nil return nil
}) })
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
// TODO check auth level
if c.Mode != NORMAL {
// This should only handle normal requests
c.Logger.Warn("This function only works with normal")
return c.UnsafeErrorCode(nil, 400, nil)
}
f := r.URL.Query()
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
c.Logger.Warn("Invalid: model_id or definition or epoch")
return c.UnsafeErrorCode(nil, 400, nil)
}
model_id := f.Get("model_id")
def_id := f.Get("definition")
epoch, err := strconv.Atoi(f.Get("epoch"))
if err != nil {
c.Logger.Warn("Epoch is not a number")
// No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil)
}
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
if err != nil {
return c.Error500(err)
}
defer rows.Close()
if !rows.Next() {
c.Logger.Error("Could not get status of model definition")
return c.Error500(nil)
}
var status int
err = rows.Scan(&status)
if err != nil {
return c.Error500(err)
}
if status != 3 {
c.Logger.Warn("Definition not on status 3(training)", "status", status)
// No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil)
}
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
if err != nil {
return c.Error500(err)
}
return nil
})
} }

View File

@ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error {
return &Error{code, nil, c.AddMap(data)} return &Error{code, nil, c.AddMap(data)}
} }
func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error {
if err != nil {
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
c.Logger.Error(err)
}
return &Error{code, nil, c.AddMap(data)}
}
func (c Context) Error500(err error) *Error { func (c Context) Error500(err error) *Error {
return c.ErrorCode(err, http.StatusInternalServerError, nil) return c.ErrorCode(err, http.StatusInternalServerError, nil)
} }
@ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
var token *string var token *string
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: r.URL.Path,
})
for _, r := range r.Cookies() { for _, r := range r.Cookies() {
if r.Name == "auth" { if r.Name == "auth" {
token = &r.Value token = &r.Value
@ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
if token == nil { if token == nil {
return &Context{ return &Context{
Mode: mode, Mode: mode,
Logger: logger,
Db: handler.Db,
}, nil }, nil
} }
@ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
return nil, errors.Join(err, LogoffError) return nil, errors.Join(err, LogoffError)
} }
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: r.URL.Path,
})
return &Context{token, user, mode, logger, handler.Db}, nil return &Context{token, user, mode, logger, handler.Db}, nil
} }

View File

@ -55,7 +55,8 @@ create table if not exists model_definition (
-- 4: Tranied -- 4: Tranied
-- 5: Ready -- 5: Ready
status integer default 1, status integer default 1,
created_on timestamp default current_timestamp created_on timestamp default current_timestamp,
epoch_progress integer default 0
); );
-- drop table if exists model_definition_layer; -- drop table if exists model_definition_layer;

View File

@ -434,7 +434,20 @@
{{/* TODO improve this */}} {{/* TODO improve this */}}
Training the model...<br/> Training the model...<br/>
{{/* TODO Add progress status on definitions */}} {{/* TODO Add progress status on definitions */}}
{{/* TODO Add aility to stop training */}} {{ range .Defs}}
<div>
<div>
{{.Status}}
</div>
<div>
{{.EpochProgress}}
</div>
<div>
{{.Accuracy}}
</div>
</div>
{{ end }}
{{/* TODO Add ability to stop training */}}
</div> </div>
{{/* Model Ready */}} {{/* Model Ready */}}
{{ else if (eq .Model.Status 5)}} {{ else if (eq .Model.Status 5)}}

View File

@ -4,6 +4,14 @@ import pandas as pd
from tensorflow import keras from tensorflow import keras
from tensorflow.data import AUTOTUNE from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers from keras import layers, losses, optimizers
import requests
class NotifyServerCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, *args, **kwargs):
if (epoch % 5) == 0:
# TODO change this
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}')
DATA_DIR = "{{ .DataDir }}" DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }}) image_size = ({{ .Size }})
@ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/"
#based on https://www.tensorflow.org/tutorials/load_data/images #based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path): def pathToLabel(path):
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "") path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
{{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpg", "") path = tf.strings.regex_replace(path, ".jpg", "")
path = tf.strings.regex_replace(path, ".jpeg", "") path = tf.strings.regex_replace(path, ".jpeg", "")
path = tf.strings.regex_replace(path, ".png", "") {{ else }}
ERROR
{{ end }}
return table.lookup(tf.strings.as_string([path])) return table.lookup(tf.strings.as_string([path]))
#return tf.strings.as_string([path])
def decode_image(img): def decode_image(img):
{{ if eq .Model.Format "png" }} {{ if eq .Model.Format "png" }}
@ -100,7 +112,7 @@ model.compile(
optimizer=tf.keras.optimizers.Adam(), optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy']) metrics=['accuracy'])
his = model.fit(dataset, validation_data= dataset_validation, epochs=50) his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
acc = his.history["accuracy"] acc = his.history["accuracy"]