From 2faf90f462e35515b0707a50bfc0049ba9e98b88 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Mon, 8 Apr 2024 15:47:31 +0100 Subject: [PATCH] chore: added config file and fixed updates not showing while training --- config.toml | 2 + go.mod | 1 + go.sum | 2 + logic/models/train/train.go | 43 ++++++++++---------- logic/utils/config.go | 32 +++++++++++++++ logic/utils/handler.go | 50 ++++++++++++------------ main.go | 5 ++- views/py/python_model_template.py | 4 +- views/py/python_model_template_expand.py | 2 +- 9 files changed, 91 insertions(+), 50 deletions(-) create mode 100644 config.toml create mode 100644 logic/utils/config.go diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..fc52843 --- /dev/null +++ b/config.toml @@ -0,0 +1,2 @@ +PORT=5002 +HOSTNAME="https://testing.andr3h3nriqu3s.com" diff --git a/go.mod b/go.mod index 4b001f6..1bfffc2 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( ) require ( + github.com/BurntSushi/toml v1.3.2 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/lipgloss v0.9.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect diff --git a/go.sum b/go.sum index 8032581..e170f8b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= +github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/charmbracelet/lipgloss v0.8.0 h1:IS00fk4XAHcf8uZKc3eHeMUTCxUH6NkaTrdyCQk84RU= diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 3f62a1e..0707349 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -244,6 +244,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr "SaveModelPath": path.Join(getDir(), result_path), "Depth": classCount, "StartPoint": 0, + "Host": (*c.Handle).Config.Hostname, }); err != nil { return } @@ -292,9 +293,9 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i if err != nil { return } - c.Logger.Info("test here", "count", co) + c.Logger.Info("test here", "count", co) count_re = co.Count - count := co.Count + count := co.Count if count == 0 { err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN) @@ -416,7 +417,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string var last *layerrow = nil got_2 := false - var first *layerrow = nil + var first *layerrow = nil for layers.Next() { var row = layerrow{} @@ -424,10 +425,10 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string return } - // Keep track of the first layer so we can keep the size of the image - if first == nil { - first = &row - } + // Keep track of the first layer so we can keep the size of the image + if first == nil { + first = &row + } row.LayerNum = i row.Shape = shapeToSize(row.Shape) @@ -500,6 +501,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string "SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id), "Depth": classCount, "StartPoint": 0, + "Host": (*c.Handle).Config.Hostname, }); err != nil { return } @@ -648,6 +650,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load "SaveModelPath": path.Join(getDir(), result_path), "Depth": classCount, "StartPoint": 0, + "Host": (*c.Handle).Config.Hostname, }); err != nil { return } @@ -1618,8 +1621,8 @@ func trainExpandable(c *Context, model *BaseModel) { } func trainRetrain(c *Context, model *BaseModel, defId string) { - var err error - + var err error + failed := func() { ResetClasses(c, model) ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED) @@ -1631,27 +1634,27 @@ func trainRetrain(c *Context, model *BaseModel, defId string) { acc, err := trainDefinitionExpandExp(c, model, defId, false) if err != nil { c.Logger.Error("Failed to retrain the model", "err", err) - failed() - return + failed() + return } c.Logger.Info("Retrained model", "accuracy", acc) - // TODO check accuracy + // TODO check accuracy - err = UpdateStatus(c, "models", model.Id, READY) - if err != nil { - failed() - return - } + err = UpdateStatus(c, "models", model.Id, READY) + if err != nil { + failed() + return + } - c.Logger.Info("model updaded") + c.Logger.Info("model updaded") _, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id) if err != nil { c.Logger.Error("Error while updating the classes", "error", err) - failed() - return + failed() + return } } diff --git a/logic/utils/config.go b/logic/utils/config.go new file mode 100644 index 0000000..a5257b7 --- /dev/null +++ b/logic/utils/config.go @@ -0,0 +1,32 @@ +package utils + +import ( + "os" + + toml "github.com/BurntSushi/toml" + "github.com/charmbracelet/log" +) + +type Config struct { + Hostname string + Port int +} + +func LoadConfig() Config { + + log.Info("Loading the config file") + + dat, err := os.ReadFile("./config.toml") + if err != nil { + log.Error("Failed to load config file", "err", err) + // Use default values + return Config{ + Hostname: "localhost", + Port: 8000, + } + } + + var conf Config + _, err = toml.Decode(string(dat), &conf) + return conf +} diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 03d470f..94c9edc 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -42,6 +42,7 @@ type Handle struct { gets []HandleFunc posts []HandleFunc deletes []HandleFunc + Config Config } func decodeBody(r *http.Request) (string, *Error) { @@ -89,7 +90,7 @@ func (x *Handle) handleGets(context *Context) { return } } - context.ShowMessage = false + context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } @@ -100,7 +101,7 @@ func (x *Handle) handlePosts(context *Context) { return } } - context.ShowMessage = false + context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } @@ -111,7 +112,7 @@ func (x *Handle) handleDeletes(context *Context) { return } } - context.ShowMessage = false + context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } @@ -138,6 +139,7 @@ type Context struct { R *http.Request Tx *sql.Tx ShowMessage bool + Handle *Handle } func (c Context) Prepare(str string) (*sql.Stmt, error) { @@ -315,11 +317,12 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW // TODO check that the token is still valid if token == nil { return &Context{ - Logger: logger, - Db: handler.Db, - Writer: w, - R: r, - ShowMessage: true, + Logger: logger, + Db: handler.Db, + Writer: w, + R: r, + ShowMessage: true, + Handle: &x, }, nil } @@ -328,7 +331,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW return nil, errors.Join(err, LogoffError) } - return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil + return &Context{token, user, logger, handler.Db, w, r, nil, true, &x}, nil } func contextlessLogoff(w http.ResponseWriter) { @@ -468,12 +471,12 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType }) } -func NewHandler(db *sql.DB) *Handle { +func NewHandler(db *sql.DB, config Config) *Handle { var gets []HandleFunc var posts []HandleFunc var deletes []HandleFunc - x := &Handle{db, gets, posts, deletes} + x := &Handle{db, gets, posts, deletes, config} http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -482,11 +485,11 @@ func NewHandler(db *sql.DB) *Handle { w.Header().Add("Access-Control-Allow-Methods", "*") // Decide answertype - if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") { + /* if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") { w.WriteHeader(500) w.Write([]byte("Please set content-type to application/json or set response-type to application/json\n")) return - } + }*/ if !strings.HasPrefix(r.URL.Path, "/api") { w.WriteHeader(404) @@ -513,13 +516,13 @@ func NewHandler(db *sql.DB) *Handle { x.handleDeletes(context) } else if r.Method == "OPTIONS" { // do nothing - } else { - panic("TODO handle method: " + r.Method) - } - - if context.ShowMessage { - context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path) - } + } else { + panic("TODO handle method: " + r.Method) + } + + if context.ShowMessage { + context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path) + } }) return x @@ -528,10 +531,5 @@ func NewHandler(db *sql.DB) *Handle { func (x Handle) Startup() { log.Info("Starting up!\n") - port := os.Getenv("PORT") - if port == "" { - port = "8000" - } - - log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", port), nil)) + log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", x.Config.Port), nil)) } diff --git a/main.go b/main.go index 0f0fce4..0a77edd 100644 --- a/main.go +++ b/main.go @@ -33,8 +33,11 @@ func main() { defer db.Close() log.Info("Starting server on :5002!") + config := LoadConfig() + log.Info("Config loaded!", "config", config) + //TODO check if file structure exists to save data - handle := NewHandler(db) + handle := NewHandler(db, config) // TODO remove this before commiting _, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING) diff --git a/views/py/python_model_template.py b/views/py/python_model_template.py index 933effc..3552ed1 100644 --- a/views/py/python_model_template.py +++ b/views/py/python_model_template.py @@ -9,9 +9,9 @@ import requests class NotifyServerCallback(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, log, *args, **kwargs): {{ if .HeadId }} - requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}') + requests.get(f'{{ .Host }}/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}') {{ else }} - requests.get(f'http://localhost:8000/api/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}') + requests.get(f'{{ .Host }}/api/model/epoch/update?model_id={{.Model.Id}}&epoch={epoch + 1}&accuracy={log["accuracy"]}&definition={{.DefId}}') {{end}} diff --git a/views/py/python_model_template_expand.py b/views/py/python_model_template_expand.py index 18e388e..35127eb 100644 --- a/views/py/python_model_template_expand.py +++ b/views/py/python_model_template_expand.py @@ -10,7 +10,7 @@ import numpy as np class NotifyServerCallback(tf.keras.callbacks.Callback): def on_epoch_end(self, epoch, log, *args, **kwargs): - requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}') + requests.get(f'{{ .Host }}/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}') DATA_DIR = "{{ .DataDir }}"