feat: closes #39
This commit is contained in:
parent
c7c6cfcd00
commit
f163e25fba
@ -92,7 +92,34 @@ func handleEdit(handle *Handle) {
|
||||
"Model": model,
|
||||
}))
|
||||
case TRAINING:
|
||||
fallthrough
|
||||
|
||||
type defrow struct {
|
||||
Status int
|
||||
EpochProgress int
|
||||
Accuracy int
|
||||
}
|
||||
|
||||
def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
defer def_rows.Close()
|
||||
|
||||
defs := []defrow{}
|
||||
|
||||
for def_rows.Next() {
|
||||
var def defrow
|
||||
err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
defs = append(defs, def)
|
||||
}
|
||||
|
||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
"Defs": defs,
|
||||
}))
|
||||
case PREPARING_ZIP_FILE:
|
||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||
"Model": model,
|
||||
|
@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
||||
"RunPath": run_path,
|
||||
"ColorMode": model.ImageMode,
|
||||
"Model": model,
|
||||
"DefId": definition_id,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
|
||||
}
|
||||
|
||||
for _, def := range definitions {
|
||||
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
accuracy, err := trainDefinition(c, model, def.id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to train definition!Err:")
|
||||
@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
|
||||
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
|
||||
return nil
|
||||
})
|
||||
|
||||
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||
// TODO check auth level
|
||||
if c.Mode != NORMAL {
|
||||
// This should only handle normal requests
|
||||
c.Logger.Warn("This function only works with normal")
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
f := r.URL.Query()
|
||||
|
||||
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
|
||||
c.Logger.Warn("Invalid: model_id or definition or epoch")
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
model_id := f.Get("model_id")
|
||||
def_id := f.Get("definition")
|
||||
epoch, err := strconv.Atoi(f.Get("epoch"))
|
||||
if err != nil {
|
||||
c.Logger.Warn("Epoch is not a number")
|
||||
// No need to improve message because this function is only called internaly
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
c.Logger.Error("Could not get status of model definition")
|
||||
return c.Error500(nil)
|
||||
}
|
||||
|
||||
var status int
|
||||
err = rows.Scan(&status)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
if status != 3 {
|
||||
c.Logger.Warn("Definition not on status 3(training)", "status", status)
|
||||
// No need to improve message because this function is only called internaly
|
||||
return c.UnsafeErrorCode(nil, 400, nil)
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
@ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error {
|
||||
return &Error{code, nil, c.AddMap(data)}
|
||||
}
|
||||
|
||||
func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error {
|
||||
if err != nil {
|
||||
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
|
||||
c.Logger.Error(err)
|
||||
}
|
||||
return &Error{code, nil, c.AddMap(data)}
|
||||
}
|
||||
|
||||
func (c Context) Error500(err error) *Error {
|
||||
return c.ErrorCode(err, http.StatusInternalServerError, nil)
|
||||
}
|
||||
@ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
||||
|
||||
var token *string
|
||||
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: r.URL.Path,
|
||||
})
|
||||
|
||||
for _, r := range r.Cookies() {
|
||||
if r.Name == "auth" {
|
||||
token = &r.Value
|
||||
@ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
||||
if token == nil {
|
||||
return &Context{
|
||||
Mode: mode,
|
||||
Logger: logger,
|
||||
Db: handler.Db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
||||
return nil, errors.Join(err, LogoffError)
|
||||
}
|
||||
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: r.URL.Path,
|
||||
})
|
||||
|
||||
return &Context{token, user, mode, logger, handler.Db}, nil
|
||||
}
|
||||
|
||||
|
@ -55,7 +55,8 @@ create table if not exists model_definition (
|
||||
-- 4: Tranied
|
||||
-- 5: Ready
|
||||
status integer default 1,
|
||||
created_on timestamp default current_timestamp
|
||||
created_on timestamp default current_timestamp,
|
||||
epoch_progress integer default 0
|
||||
);
|
||||
|
||||
-- drop table if exists model_definition_layer;
|
||||
|
@ -434,7 +434,20 @@
|
||||
{{/* TODO improve this */}}
|
||||
Training the model...<br/>
|
||||
{{/* TODO Add progress status on definitions */}}
|
||||
{{/* TODO Add aility to stop training */}}
|
||||
{{ range .Defs}}
|
||||
<div>
|
||||
<div>
|
||||
{{.Status}}
|
||||
</div>
|
||||
<div>
|
||||
{{.EpochProgress}}
|
||||
</div>
|
||||
<div>
|
||||
{{.Accuracy}}
|
||||
</div>
|
||||
</div>
|
||||
{{ end }}
|
||||
{{/* TODO Add ability to stop training */}}
|
||||
</div>
|
||||
{{/* Model Ready */}}
|
||||
{{ else if (eq .Model.Status 5)}}
|
||||
|
@ -4,6 +4,14 @@ import pandas as pd
|
||||
from tensorflow import keras
|
||||
from tensorflow.data import AUTOTUNE
|
||||
from keras import layers, losses, optimizers
|
||||
import requests
|
||||
|
||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||
def on_epoch_begin(self, epoch, *args, **kwargs):
|
||||
if (epoch % 5) == 0:
|
||||
# TODO change this
|
||||
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}')
|
||||
|
||||
|
||||
DATA_DIR = "{{ .DataDir }}"
|
||||
image_size = ({{ .Size }})
|
||||
@ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
def pathToLabel(path):
|
||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||
{{ if eq .Model.Format "png" }}
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
{{ else if eq .Model.Format "jpeg" }}
|
||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
{{ else }}
|
||||
ERROR
|
||||
{{ end }}
|
||||
return table.lookup(tf.strings.as_string([path]))
|
||||
#return tf.strings.as_string([path])
|
||||
|
||||
def decode_image(img):
|
||||
{{ if eq .Model.Format "png" }}
|
||||
@ -100,7 +112,7 @@ model.compile(
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50)
|
||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
|
||||
|
||||
acc = his.history["accuracy"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user