diff --git a/logic/models/classes/list.go b/logic/models/classes/list.go index b6433d9..99ac053 100644 --- a/logic/models/classes/list.go +++ b/logic/models/classes/list.go @@ -6,6 +6,7 @@ import ( "strconv" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" ) func HandleList(handle *Handle) { @@ -48,6 +49,11 @@ func HandleList(handle *Handle) { return Error500(nil) } + model, err := GetBaseModel(c.Db, model_id) + if err != nil { + return Error500(err) + } + rows, err := handle.Db.Query("select id, file_path, model_mode, status from model_data_point where class_id=$1 limit 10 offset $2;", id, page * 10) if err != nil { return Error500(err) @@ -95,7 +101,7 @@ func HandleList(handle *Handle) { "Page": page, "Id": id, "Name": name, - "ModelId": model_id, + "Model": model, })) return nil }) diff --git a/logic/models/edit.go b/logic/models/edit.go index c9ece01..b892781 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -24,11 +24,11 @@ func handleEdit(handle *Handle) { } // TODO handle admin users - rows, err := handle.Db.Query("select name, status, width, height, color_mode from models where id=$1 and user_id=$2;", id, c.User.Id) + rows, err := handle.Db.Query("select name, status, width, height, color_mode, format from models where id=$1 and user_id=$2;", id, c.User.Id) if err != nil { return Error500(err) } - defer rows.Close() + defer rows.Close() if !rows.Next() { return ErrorCode(nil, http.StatusNotFound, AnyMap{ @@ -44,12 +44,13 @@ func handleEdit(handle *Handle) { Width *int Height *int Color_mode *string + Format string } var model rowmodel = rowmodel{} model.Id = id - err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode) + err = rows.Scan(&model.Name, &model.Status, &model.Width, &model.Height, &model.Color_mode, &model.Format) if err != nil { return Error500(err) } @@ -70,60 +71,64 @@ func handleEdit(handle *Handle) { })) case CONFIRM_PRE_TRAINING: - wrong_number, err := model_classes.GetNumberOfWrongDataPoints(c.Db, model.Id) - if err != nil { return c.Error500(err) } + wrong_number, err := model_classes.GetNumberOfWrongDataPoints(c.Db, model.Id) + if err != nil { + return c.Error500(err) + } - cls, err := model_classes.ListClasses(handle.Db, id) - if err != nil { return c.Error500(err) } + cls, err := model_classes.ListClasses(handle.Db, id) + if err != nil { + return c.Error500(err) + } - has_data, err := model_classes.ModelHasDataPoints(handle.Db, id) - if err != nil { - return Error500(err) - } + has_data, err := model_classes.ModelHasDataPoints(handle.Db, id) + if err != nil { + return Error500(err) + } LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - "Classes": cls, - "HasData": has_data, - "NumberOfInvalidImages": wrong_number, + "Model": model, + "Classes": cls, + "HasData": has_data, + "NumberOfInvalidImages": wrong_number, })) case READY: LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ "Model": model, })) - case TRAINING: - - type defrow struct { - Status int - EpochProgress int - Accuracy int - } + case TRAINING: - def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id) - if err != nil { - return c.Error500(err) - } - defer def_rows.Close() + type defrow struct { + Status int + EpochProgress int + Accuracy int + } - defs := []defrow{} + def_rows, err := c.Db.Query("select status, epoch_progress, accuracy from model_definition where model_id=$1", model.Id) + if err != nil { + return c.Error500(err) + } + defer def_rows.Close() - for def_rows.Next() { - var def defrow - err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy) - if err != nil { - return c.Error500(err) - } - defs = append(defs, def) - } + defs := []defrow{} - LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - "Defs": defs, - })) - case PREPARING_ZIP_FILE: - LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - })) + for def_rows.Next() { + var def defrow + err = def_rows.Scan(&def.Status, &def.EpochProgress, &def.Accuracy) + if err != nil { + return c.Error500(err) + } + defs = append(defs, def) + } + + LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ + "Model": model, + "Defs": defs, + })) + case PREPARING_ZIP_FILE: + LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ + "Model": model, + })) default: fmt.Printf("Unkown Status: %d\n", model.Status) return Error500(nil) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 5a1bb10..b25e8db 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -387,8 +387,8 @@ func trainModel(c *Context, model *BaseModel) { ModelUpdateStatus(c, model.Id, READY) } -func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) { - rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) +func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { + rows, err := c.Db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model.Id) if err != nil { return } @@ -402,13 +402,18 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) { if err != nil { return } - err = os.RemoveAll(path.Join(base_path, dataPointId+model.Format)) + + p := path.Join(base_path, dataPointId + "." + model.Format) + + c.Logger.Warn("Removing image", "path", p) + + err = os.RemoveAll(p) if err != nil { return } } - _, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id) + _, err = c.Db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id) return } @@ -471,7 +476,7 @@ func generateDefinitions(c *Context, model *BaseModel, number_of_models int) *Er return c.Error500(err) } - err = removeFailedDataPoints(c.Db, model) + err = removeFailedDataPoints(c, model) if err != nil { return c.Error500(err) } diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 0d4ebdf..1c51f59 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -13,24 +13,24 @@ import ( "time" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" - "github.com/charmbracelet/log" + "github.com/charmbracelet/log" ) -func Mul (n1 int, n2 int) int { - return n1 * n2 +func Mul(n1 int, n2 int) int { + return n1 * n2 } -func Add (n1 int, n2 int) int { - return n1 + n2 +func Add(n1 int, n2 int) int { + return n1 + n2 } func baseLoadTemplate(base string, path string) (*template.Template, any) { - funcs := map[string]any { - "startsWith": strings.HasPrefix, - "replace": strings.Replace, - "mul": Mul, - "add": Add, - } + funcs := map[string]any{ + "startsWith": strings.HasPrefix, + "replace": strings.Replace, + "mul": Mul, + "add": Add, + } return template.New(base).Funcs(funcs).ParseFiles( "./views/"+base, "./views/"+path, @@ -97,27 +97,27 @@ func LoadHtml(writer http.ResponseWriter, path string, data interface{}) { } func LoadDefineTemplate(writer http.ResponseWriter, path string, base string, data AnyMap) { - if data == nil { - data = map[string]interface{} { - "Error": true, - } - } else { - data["Error"] = true - } + if data == nil { + data = map[string]interface{}{ + "Error": true, + } + } else { + data["Error"] = true + } - funcs := map[string]any { - "startsWith": strings.HasPrefix, - "mul": Mul, - "replace": strings.Replace, - "add": Add, - } + funcs := map[string]any{ + "startsWith": strings.HasPrefix, + "mul": Mul, + "replace": strings.Replace, + "add": Add, + } tmpl, err := template.New("").Funcs(funcs).Parse("{{template \"" + base + "\" . }}") - if err != nil { - panic("Lol") - } + if err != nil { + panic("Lol") + } - tmpl, err = tmpl.ParseFiles( + tmpl, err = tmpl.ParseFiles( "./views/"+path, "./views/partials/header.html", ) @@ -222,10 +222,10 @@ func handleError(err *Error, w http.ResponseWriter, context *Context) { if err != nil { data := context.AddMap(err.data) if err.Code == http.StatusNotFound { - if context.Mode == HTML { - w.WriteHeader(309) - context.Mode = HTMLFULL - } + if context.Mode == HTML { + w.WriteHeader(309) + context.Mode = HTMLFULL + } LoadBasedOnAnswer(context.Mode, w, "404.html", data) return } @@ -340,46 +340,46 @@ func AnswerTemplate(path string, data AnyMap, authLevel int) func(w http.Respons } type Context struct { - Token *string - User *dbtypes.User - Mode AnswerType - Logger *log.Logger - Db *sql.DB + Token *string + User *dbtypes.User + Mode AnswerType + Logger *log.Logger + Db *sql.DB } func (c Context) Error400(err error, message string, w http.ResponseWriter, path string, base string, data AnyMap) *Error { - c.SetReportCaller(true) - c.Logger.Error(message) - c.SetReportCaller(false) + c.SetReportCaller(true) + c.Logger.Error(message) + c.SetReportCaller(false) if err != nil { c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", http.StatusBadRequest) c.Logger.Error(err) } - if c.Mode == JSON { - return &Error{http.StatusBadRequest, nil, c.AddMap(data)} - } + if c.Mode == JSON { + return &Error{http.StatusBadRequest, nil, c.AddMap(data)} + } - LoadDefineTemplate(w, path, base, c.AddMap(data)) - return nil + LoadDefineTemplate(w, path, base, c.AddMap(data)) + return nil } func (c Context) SetReportCaller(report bool) { - if (report) { - c.Logger.SetCallerOffset(2) - c.Logger.SetReportCaller(true) - } else { - c.Logger.SetCallerOffset(1) - c.Logger.SetReportCaller(false) - } + if report { + c.Logger.SetCallerOffset(2) + c.Logger.SetReportCaller(true) + } else { + c.Logger.SetCallerOffset(1) + c.Logger.SetReportCaller(false) + } } func (c Context) ErrorCode(err error, code int, data AnyMap) *Error { - if (code == 400) { - c.SetReportCaller(true) - c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n") - c.SetReportCaller(false) - } + if code == 400 { + c.SetReportCaller(true) + c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n") + c.SetReportCaller(false) + } if err != nil { c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code) c.Logger.Error(err) @@ -422,11 +422,11 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) var token *string - logger := log.NewWithOptions(os.Stdout, log.Options{ - ReportTimestamp: true, - TimeFormat: time.Kitchen, - Prefix: r.URL.Path, - }) + logger := log.NewWithOptions(os.Stdout, log.Options{ + ReportTimestamp: true, + TimeFormat: time.Kitchen, + Prefix: r.URL.Path, + }) for _, r := range r.Cookies() { if r.Name == "auth" { @@ -438,9 +438,9 @@ func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) if token == nil { return &Context{ - Mode: mode, - Logger: logger, - Db: handler.Db, + Mode: mode, + Logger: logger, + Db: handler.Db, }, nil } @@ -461,7 +461,7 @@ func Redirect(path string, mode AnswerType, w http.ResponseWriter, r *http.Reque return } if mode&(HTMLFULL|HTML) != 0 { - w.Header().Add("HX-Redirect", path) + w.Header().Add("HX-Redirect", path) w.WriteHeader(204) } else { w.WriteHeader(http.StatusSeeOther) @@ -517,7 +517,7 @@ func (x Handle) StaticFiles(pathTest string, fileType string, contentType string } func ErrorCode(err error, code int, data AnyMap) *Error { - log.Warn("This function is deprecated please use the one provided by context") + log.Warn("This function is deprecated please use the one provided by context") // TODO Improve Logging if err != nil { fmt.Printf("Something went wrong returning with: %d\n.Err:\n", code) @@ -527,7 +527,7 @@ func ErrorCode(err error, code int, data AnyMap) *Error { } func Error500(err error) *Error { - log.Warn("This function is deprecated please use the one provided by context") + log.Warn("This function is deprecated please use the one provided by context") return ErrorCode(err, http.StatusInternalServerError, nil) } @@ -556,6 +556,42 @@ func (x Handle) ReadFiles(pathTest string, baseFilePath string, fileType string, }) } +func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { + http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { + user_path := r.URL.Path[len(pathTest):] + + // fmt.Printf("Requested path: %s\n", user_path) + + found := false + index := -1; + + for i, fileType := range fileTypes { + if strings.HasSuffix(user_path, fileType) { + found = true + index = i; + break + } + } + + if !found { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("File not found")) + return + } + + bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) + if err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Failed to load file")) + return + } + + w.Header().Set("Content-Type", contentTypes[index]) + w.Write(bytes) + }) +} + func NewHandler(db *sql.DB) *Handle { var gets []HandleFunc diff --git a/main.go b/main.go index fc60c77..0017644 100644 --- a/main.go +++ b/main.go @@ -31,25 +31,24 @@ func main() { defer db.Close() fmt.Println("Starting server on :8000!") - //TODO check if file structure exists to save data + //TODO check if file structure exists to save data handle := NewHandler(db) - _, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING); - if err != nil { - panic(err) - } - + _, err = db.Exec("update models set status=$1 where status=$2", models_utils.FAILED_TRAINING, models_utils.TRAINING) + if err != nil { + panic(err) + } // TODO Handle this in other way - handle.StaticFiles("/styles/", ".css", "text/css"); - handle.StaticFiles("/js/", ".js", "text/javascript"); - handle.ReadFiles("/imgs/", "views", ".png", "image/png;"); - handle.ReadFiles("/savedData/", ".", ".png", "image/png;"); + handle.StaticFiles("/styles/", ".css", "text/css") + handle.StaticFiles("/js/", ".js", "text/javascript") + handle.ReadFiles("/imgs/", "views", ".png", "image/png;") + handle.ReadTypesFiles("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"}) handle.GetHTML("/", AnswerTemplate("index.html", nil, 0)) - usersEndpints(db, handle) - HandleModels(handle) + usersEndpints(db, handle) + HandleModels(handle) handle.Startup() } diff --git a/views/models/edit.html b/views/models/edit.html index fbc5f5a..011d7c6 100644 --- a/views/models/edit.html +++ b/views/models/edit.html @@ -81,7 +81,7 @@