feat: closes #39

This commit is contained in:
Andre Henriques 2023-10-12 12:08:12 +01:00
parent c7c6cfcd00
commit f163e25fba
6 changed files with 131 additions and 12 deletions

View File

@ -92,7 +92,34 @@ func handleEdit(handle *Handle) {
"Model": model,
}))
case TRAINING:
fallthrough
type defrow struct {
Status int
EpochProgress int
Accuracy int
}
def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)
if err != nil {
return c.Error500(err)
}
defer def_rows.Close()
defs := []defrow{}
for def_rows.Next() {
var def defrow
err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy)
if err != nil {
return c.Error500(err)
}
defs = append(defs, def)
}
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model,
"Defs": defs,
}))
case PREPARING_ZIP_FILE:
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
"Model": model,

View File

@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"DefId": definition_id,
}); err != nil {
return
}
@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
}
for _, def := range definitions {
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
accuracy, err := trainDefinition(c, model, def.id)
if err != nil {
c.Logger.Error("Failed to train definition!Err:")
@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
return nil
})
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
// TODO check auth level
if c.Mode != NORMAL {
// This should only handle normal requests
c.Logger.Warn("This function only works with normal")
return c.UnsafeErrorCode(nil, 400, nil)
}
f := r.URL.Query()
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
c.Logger.Warn("Invalid: model_id or definition or epoch")
return c.UnsafeErrorCode(nil, 400, nil)
}
model_id := f.Get("model_id")
def_id := f.Get("definition")
epoch, err := strconv.Atoi(f.Get("epoch"))
if err != nil {
c.Logger.Warn("Epoch is not a number")
// No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil)
}
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
if err != nil {
return c.Error500(err)
}
defer rows.Close()
if !rows.Next() {
c.Logger.Error("Could not get status of model definition")
return c.Error500(nil)
}
var status int
err = rows.Scan(&status)
if err != nil {
return c.Error500(err)
}
if status != 3 {
c.Logger.Warn("Definition not on status 3(training)", "status", status)
// No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil)
}
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
if err != nil {
return c.Error500(err)
}
return nil
})
}

View File

@ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error {
return &Error{code, nil, c.AddMap(data)}
}
func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error {
if err != nil {
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
c.Logger.Error(err)
}
return &Error{code, nil, c.AddMap(data)}
}
func (c Context) Error500(err error) *Error {
return c.ErrorCode(err, http.StatusInternalServerError, nil)
}
@ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
var token *string
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: r.URL.Path,
})
for _, r := range r.Cookies() {
if r.Name == "auth" {
token = &r.Value
@ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
if token == nil {
return &Context{
Mode: mode,
Logger: logger,
Db: handler.Db,
}, nil
}
@ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
return nil, errors.Join(err, LogoffError)
}
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: r.URL.Path,
})
return &Context{token, user, mode, logger, handler.Db}, nil
}

View File

@ -55,7 +55,8 @@ create table if not exists model_definition (
-- 4: Tranied
-- 5: Ready
status integer default 1,
created_on timestamp default current_timestamp
created_on timestamp default current_timestamp,
epoch_progress integer default 0
);
-- drop table if exists model_definition_layer;

View File

@ -434,7 +434,20 @@
{{/* TODO improve this */}}
Training the model...<br/>
{{/* TODO Add progress status on definitions */}}
{{/* TODO Add aility to stop training */}}
{{ range .Defs}}
<div>
<div>
{{.Status}}
</div>
<div>
{{.EpochProgress}}
</div>
<div>
{{.Accuracy}}
</div>
</div>
{{ end }}
{{/* TODO Add ability to stop training */}}
</div>
{{/* Model Ready */}}
{{ else if (eq .Model.Status 5)}}

View File

@ -4,6 +4,14 @@ import pandas as pd
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers
import requests
class NotifyServerCallback(tf.keras.callbacks.Callback):
def on_epoch_begin(self, epoch, *args, **kwargs):
if (epoch % 5) == 0:
# TODO change this
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}')
DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }})
@ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/"
#based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path):
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
{{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpg", "")
path = tf.strings.regex_replace(path, ".jpeg", "")
path = tf.strings.regex_replace(path, ".png", "")
{{ else }}
ERROR
{{ end }}
return table.lookup(tf.strings.as_string([path]))
#return tf.strings.as_string([path])
def decode_image(img):
{{ if eq .Model.Format "png" }}
@ -100,7 +112,7 @@ model.compile(
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
his = model.fit(dataset, validation_data= dataset_validation, epochs=50)
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
acc = his.history["accuracy"]