feat: closes #39
This commit is contained in:
parent
c7c6cfcd00
commit
f163e25fba
@ -92,7 +92,34 @@ func handleEdit(handle *Handle) {
|
|||||||
"Model": model,
|
"Model": model,
|
||||||
}))
|
}))
|
||||||
case TRAINING:
|
case TRAINING:
|
||||||
fallthrough
|
|
||||||
|
type defrow struct {
|
||||||
|
Status int
|
||||||
|
EpochProgress int
|
||||||
|
Accuracy int
|
||||||
|
}
|
||||||
|
|
||||||
|
def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
defer def_rows.Close()
|
||||||
|
|
||||||
|
defs := []defrow{}
|
||||||
|
|
||||||
|
for def_rows.Next() {
|
||||||
|
var def defrow
|
||||||
|
err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
defs = append(defs, def)
|
||||||
|
}
|
||||||
|
|
||||||
|
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||||
|
"Model": model,
|
||||||
|
"Defs": defs,
|
||||||
|
}))
|
||||||
case PREPARING_ZIP_FILE:
|
case PREPARING_ZIP_FILE:
|
||||||
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
|
||||||
"Model": model,
|
"Model": model,
|
||||||
|
@ -160,6 +160,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
|
|||||||
"RunPath": run_path,
|
"RunPath": run_path,
|
||||||
"ColorMode": model.ImageMode,
|
"ColorMode": model.ImageMode,
|
||||||
"Model": model,
|
"Model": model,
|
||||||
|
"DefId": definition_id,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -239,6 +240,7 @@ func trainModel(c *Context, model *BaseModel) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, def := range definitions {
|
for _, def := range definitions {
|
||||||
|
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||||
accuracy, err := trainDefinition(c, model, def.id)
|
accuracy, err := trainDefinition(c, model, def.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.Logger.Error("Failed to train definition!Err:")
|
c.Logger.Error("Failed to train definition!Err:")
|
||||||
@ -480,4 +482,58 @@ func handleTrain(handle *Handle) {
|
|||||||
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
|
Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
|
||||||
|
// TODO check auth level
|
||||||
|
if c.Mode != NORMAL {
|
||||||
|
// This should only handle normal requests
|
||||||
|
c.Logger.Warn("This function only works with normal")
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
f := r.URL.Query()
|
||||||
|
|
||||||
|
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
|
||||||
|
c.Logger.Warn("Invalid: model_id or definition or epoch")
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
model_id := f.Get("model_id")
|
||||||
|
def_id := f.Get("definition")
|
||||||
|
epoch, err := strconv.Atoi(f.Get("epoch"))
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Warn("Epoch is not a number")
|
||||||
|
// No need to improve message because this function is only called internaly
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
c.Logger.Error("Could not get status of model definition")
|
||||||
|
return c.Error500(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
var status int
|
||||||
|
err = rows.Scan(&status)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status != 3 {
|
||||||
|
c.Logger.Warn("Definition not on status 3(training)", "status", status)
|
||||||
|
// No need to improve message because this function is only called internaly
|
||||||
|
return c.UnsafeErrorCode(nil, 400, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
|
||||||
|
if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
@ -387,6 +387,14 @@ func (c Context) ErrorCode(err error, code int, data AnyMap) *Error {
|
|||||||
return &Error{code, nil, c.AddMap(data)}
|
return &Error{code, nil, c.AddMap(data)}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c Context) UnsafeErrorCode(err error, code int, data AnyMap) *Error {
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
|
||||||
|
c.Logger.Error(err)
|
||||||
|
}
|
||||||
|
return &Error{code, nil, c.AddMap(data)}
|
||||||
|
}
|
||||||
|
|
||||||
func (c Context) Error500(err error) *Error {
|
func (c Context) Error500(err error) *Error {
|
||||||
return c.ErrorCode(err, http.StatusInternalServerError, nil)
|
return c.ErrorCode(err, http.StatusInternalServerError, nil)
|
||||||
}
|
}
|
||||||
@ -414,6 +422,12 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
|||||||
|
|
||||||
var token *string
|
var token *string
|
||||||
|
|
||||||
|
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||||
|
ReportTimestamp: true,
|
||||||
|
TimeFormat: time.Kitchen,
|
||||||
|
Prefix: r.URL.Path,
|
||||||
|
})
|
||||||
|
|
||||||
for _, r := range r.Cookies() {
|
for _, r := range r.Cookies() {
|
||||||
if r.Name == "auth" {
|
if r.Name == "auth" {
|
||||||
token = &r.Value
|
token = &r.Value
|
||||||
@ -425,6 +439,8 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
|||||||
if token == nil {
|
if token == nil {
|
||||||
return &Context{
|
return &Context{
|
||||||
Mode: mode,
|
Mode: mode,
|
||||||
|
Logger: logger,
|
||||||
|
Db: handler.Db,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -433,12 +449,6 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request)
|
|||||||
return nil, errors.Join(err, LogoffError)
|
return nil, errors.Join(err, LogoffError)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
|
||||||
ReportTimestamp: true,
|
|
||||||
TimeFormat: time.Kitchen,
|
|
||||||
Prefix: r.URL.Path,
|
|
||||||
})
|
|
||||||
|
|
||||||
return &Context{token, user, mode, logger, handler.Db}, nil
|
return &Context{token, user, mode, logger, handler.Db}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,7 +55,8 @@ create table if not exists model_definition (
|
|||||||
-- 4: Tranied
|
-- 4: Tranied
|
||||||
-- 5: Ready
|
-- 5: Ready
|
||||||
status integer default 1,
|
status integer default 1,
|
||||||
created_on timestamp default current_timestamp
|
created_on timestamp default current_timestamp,
|
||||||
|
epoch_progress integer default 0
|
||||||
);
|
);
|
||||||
|
|
||||||
-- drop table if exists model_definition_layer;
|
-- drop table if exists model_definition_layer;
|
||||||
|
@ -434,7 +434,20 @@
|
|||||||
{{/* TODO improve this */}}
|
{{/* TODO improve this */}}
|
||||||
Training the model...<br/>
|
Training the model...<br/>
|
||||||
{{/* TODO Add progress status on definitions */}}
|
{{/* TODO Add progress status on definitions */}}
|
||||||
{{/* TODO Add aility to stop training */}}
|
{{ range .Defs}}
|
||||||
|
<div>
|
||||||
|
<div>
|
||||||
|
{{.Status}}
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
{{.EpochProgress}}
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
{{.Accuracy}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{{ end }}
|
||||||
|
{{/* TODO Add ability to stop training */}}
|
||||||
</div>
|
</div>
|
||||||
{{/* Model Ready */}}
|
{{/* Model Ready */}}
|
||||||
{{ else if (eq .Model.Status 5)}}
|
{{ else if (eq .Model.Status 5)}}
|
||||||
|
@ -4,6 +4,14 @@ import pandas as pd
|
|||||||
from tensorflow import keras
|
from tensorflow import keras
|
||||||
from tensorflow.data import AUTOTUNE
|
from tensorflow.data import AUTOTUNE
|
||||||
from keras import layers, losses, optimizers
|
from keras import layers, losses, optimizers
|
||||||
|
import requests
|
||||||
|
|
||||||
|
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||||
|
def on_epoch_begin(self, epoch, *args, **kwargs):
|
||||||
|
if (epoch % 5) == 0:
|
||||||
|
# TODO change this
|
||||||
|
requests.get(f'http://localhost:8000/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch}&definition={{.DefId}}')
|
||||||
|
|
||||||
|
|
||||||
DATA_DIR = "{{ .DataDir }}"
|
DATA_DIR = "{{ .DataDir }}"
|
||||||
image_size = ({{ .Size }})
|
image_size = ({{ .Size }})
|
||||||
@ -26,11 +34,15 @@ DATA_DIR_PREPARE = DATA_DIR + "/"
|
|||||||
#based on https://www.tensorflow.org/tutorials/load_data/images
|
#based on https://www.tensorflow.org/tutorials/load_data/images
|
||||||
def pathToLabel(path):
|
def pathToLabel(path):
|
||||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||||
|
{{ if eq .Model.Format "png" }}
|
||||||
|
path = tf.strings.regex_replace(path, ".png", "")
|
||||||
|
{{ else if eq .Model.Format "jpeg" }}
|
||||||
path = tf.strings.regex_replace(path, ".jpg", "")
|
path = tf.strings.regex_replace(path, ".jpg", "")
|
||||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||||
path = tf.strings.regex_replace(path, ".png", "")
|
{{ else }}
|
||||||
|
ERROR
|
||||||
|
{{ end }}
|
||||||
return table.lookup(tf.strings.as_string([path]))
|
return table.lookup(tf.strings.as_string([path]))
|
||||||
#return tf.strings.as_string([path])
|
|
||||||
|
|
||||||
def decode_image(img):
|
def decode_image(img):
|
||||||
{{ if eq .Model.Format "png" }}
|
{{ if eq .Model.Format "png" }}
|
||||||
@ -100,7 +112,7 @@ model.compile(
|
|||||||
optimizer=tf.keras.optimizers.Adam(),
|
optimizer=tf.keras.optimizers.Adam(),
|
||||||
metrics=['accuracy'])
|
metrics=['accuracy'])
|
||||||
|
|
||||||
his = model.fit(dataset, validation_data= dataset_validation, epochs=50)
|
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()])
|
||||||
|
|
||||||
acc = his.history["accuracy"]
|
acc = his.history["accuracy"]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user