diff --git a/logic/models/add.go b/logic/models/add.go index 7ce4e11..db00819 100644 --- a/logic/models/add.go +++ b/logic/models/add.go @@ -4,38 +4,41 @@ import ( "bytes" "fmt" "image" - _ "image/png" "image/color" + _ "image/jpeg" + _ "image/png" "io" "net/http" "os" "path" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) -func loadBaseImage(handle *Handle, id string) { +func loadBaseImage(c *Context, id string) { // TODO handle more types than png infile, err := os.Open(path.Join("savedData", id, "baseimage.png")) if err != nil { - // TODO better logging - fmt.Println(err) - fmt.Printf("Failed to read image for model with id %s\n", id) - ModelUpdateStatus(handle, id, -1) + c.Logger.Errorf("Failed to read image for model with id %s\n", id) + c.Logger.Error(err) + ModelUpdateStatus(c, id, FAILED_PREPARING) return } defer infile.Close() src, format, err := image.Decode(infile) if err != nil { - // TODO better logging - fmt.Println(err) - fmt.Printf("Failed to load image for model with id %s\n", id) - ModelUpdateStatus(handle, id, -1) + c.Logger.Errorf("Failed to decode image for model with id %s\n", id) + c.Logger.Error(err) + ModelUpdateStatus(c, id, FAILED_PREPARING) return } - if format != "png" { + switch format { + case "png": + case "jpeg": + break + default: // TODO better logging fmt.Printf("Found unkown format '%s'\n", format) panic("Handle diferent files than .png") @@ -51,6 +54,8 @@ func loadBaseImage(handle *Handle, id string) { fallthrough case color.GrayModel: model_color = "greyscale" + case color.YCbCrModel: + model_color = "rgb" default: fmt.Println("Do not know how to handle this color model") @@ -58,8 +63,6 @@ func loadBaseImage(handle *Handle, id string) { fmt.Println("Color is rgb") } else if src.ColorModel() == color.NRGBAModel { fmt.Println("Color is nrgb") - } else if src.ColorModel() == color.YCbCrModel { - fmt.Println("Color is ycbcr") } else if src.ColorModel() == color.AlphaModel { fmt.Println("Color is alpha") } else if src.ColorModel() == color.CMYKModel { @@ -68,31 +71,30 @@ func loadBaseImage(handle *Handle, id string) { fmt.Println("Other so assuming color") } - ModelUpdateStatus(handle, id, -1) + ModelUpdateStatus(c, id, -1) return } // Note: this also updates the status to 2 - _, err = handle.Db.Exec("update models set width=$1, height=$2, color_mode=$3, status=$4 where id=$5", width, height, model_color, CONFIRM_PRE_TRAINING, id) + _, err = c.Db.Exec("update models set width=$1, height=$2, color_mode=$3, format=$4, status=$5 where id=$6", width, height, model_color, format, CONFIRM_PRE_TRAINING, id) if err != nil { - // TODO better logging - fmt.Println(err) - fmt.Printf("Could not update model\n") - ModelUpdateStatus(handle, id, -1) + c.Logger.Error("Could not update model") + c.Logger.Error(err) + ModelUpdateStatus(c, id, -1) return } } func handleAdd(handle *Handle) { -handle.GetHTML("/models/add", AnswerTemplate("models/add.html", nil, 1)) - // TODO json + handle.GetHTML("/models/add", AnswerTemplate("models/add.html", nil, 1)) handle.Post("/models/add", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { - if c.Mode == JSON { - panic("TODO JSON") - } if !CheckAuthLevel(1, w, r, c) { return nil } + if c.Mode == JSON { + // TODO json + panic("TODO JSON") + } read_form, err := r.MultipartReader() if err != nil { @@ -176,7 +178,7 @@ handle.GetHTML("/models/add", AnswerTemplate("models/add.html", nil, 1)) f.Write(file) fmt.Printf("Created model with id %s! Started to proccess image!\n", id) - go loadBaseImage(handle, id) + go loadBaseImage(c, id) Redirect("/models/edit?id="+id, c.Mode, w, r) return nil diff --git a/logic/models/data.go b/logic/models/data.go index 354aae5..737047a 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -13,139 +13,145 @@ import ( "strings" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) func InsertIfNotPresent(ss []string, s string) []string { - i := sort.SearchStrings(ss, s) - if len(ss) > i && ss[i] == s { - return ss - } - ss = append(ss, "") - copy(ss[i+1:], ss[i:]) - ss[i] = s - return ss + i := sort.SearchStrings(ss, s) + if len(ss) > i && ss[i] == s { + return ss + } + ss = append(ss, "") + copy(ss[i+1:], ss[i:]) + ss[i] = s + return ss } -func processZipFile(handle *Handle, model_id string) { - reader, err := zip.OpenReader(path.Join("savedData", model_id, "base_data.zip")) - if err != nil { - // TODO add msg to error - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - fmt.Printf("Faield to proccess zip file failed to open reader\n") - fmt.Println(err) - return - } - defer reader.Close() +func processZipFile(c *Context, model *BaseModel) { + reader, err := zip.OpenReader(path.Join("savedData", model.Id, "base_data.zip")) + if err != nil { + // TODO add msg to error + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + fmt.Printf("Faield to proccess zip file failed to open reader\n") + fmt.Println(err) + return + } + defer reader.Close() - training := []string{} - testing := []string{} + training := []string{} + testing := []string{} - for _, file := range reader.Reader.File { + for _, file := range reader.Reader.File { - paths := strings.Split(file.Name, "/") + paths := strings.Split(file.Name, "/") - if paths[1] == "" { - continue - } - - if paths[0] != "training" && paths[0] != "testing" { - fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name) - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return - } - - if paths[0] != "training" { - training = InsertIfNotPresent(training, paths[1]) - } else if paths[0] != "testing" { - testing = InsertIfNotPresent(testing, paths[1]) - } - } - - if !reflect.DeepEqual(testing, training) { - fmt.Printf("testing and training are diferent\n") - fmt.Println(testing) - fmt.Println(training) - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return - } - - base_path := path.Join("savedData", model_id, "data") - if err = os.MkdirAll(base_path, os.ModePerm); err != nil { - fmt.Printf("Failed to create base_path dir\n") - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return - } - - ids := map[string]string{} - - for i, name := range training { - id, err := model_classes.CreateClass(handle.Db, model_id, i, name) - if err != nil { - fmt.Printf("Failed to create class '%s' on db\n", name) - ModelUpdateStatus(handle, id, FAILED_PREPARING_ZIP_FILE) - return - } - ids[name] = id - } - - for _, file := range reader.Reader.File { - if file.Name[len(file.Name) - 1] == '/' { - continue - } - - data, err := reader.Open(file.Name) - if err != nil { - fmt.Printf("Could not open file in zip %s\n", file.Name) - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return + if paths[1] == "" { + continue } - defer data.Close() - file_data, err := io.ReadAll(data) - if err != nil { - fmt.Printf("Could not read file file in zip %s\n", file.Name) - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return - } - // TODO check if the file is a valid photo that matched the defined photo on the database + if paths[0] != "training" && paths[0] != "testing" { + fmt.Printf("Invalid file '%s' TODO add msg to response!!!\n", file.Name) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } - parts := strings.Split(file.Name, "/") + if paths[0] != "training" { + training = InsertIfNotPresent(training, paths[1]) + } else if paths[0] != "testing" { + testing = InsertIfNotPresent(testing, paths[1]) + } + } - mode := model_classes.DATA_POINT_MODE_TRAINING - if parts[0] == "testing" { - mode = model_classes.DATA_POINT_MODE_TESTING - } + if !reflect.DeepEqual(testing, training) { + fmt.Printf("testing and training are diferent\n") + fmt.Println(testing) + fmt.Println(training) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } - data_point_id, err := model_classes.AddDataPoint(handle.Db, ids[parts[1]], "id://", mode) - if err != nil { - fmt.Printf("Failed to add data point for %s\n", model_id) - fmt.Println(err) - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return - } + base_path := path.Join("savedData", model.Id, "data") + if err = os.MkdirAll(base_path, os.ModePerm); err != nil { + fmt.Printf("Failed to create base_path dir\n") + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } - file_path := path.Join(base_path, data_point_id + ".png") + ids := map[string]string{} + + for i, name := range training { + id, err := model_classes.CreateClass(c.Db, model.Id, i, name) + if err != nil { + fmt.Printf("Failed to create class '%s' on db\n", name) + ModelUpdateStatus(c, id, FAILED_PREPARING_ZIP_FILE) + return + } + ids[name] = id + } + + for _, file := range reader.Reader.File { + if file.Name[len(file.Name)-1] == '/' { + continue + } + + data, err := reader.Open(file.Name) + if err != nil { + fmt.Printf("Could not open file in zip %s\n", file.Name) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } + defer data.Close() + file_data, err := io.ReadAll(data) + if err != nil { + fmt.Printf("Could not read file file in zip %s\n", file.Name) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } + + // TODO check if the file is a valid photo that matched the defined photo on the database + + parts := strings.Split(file.Name, "/") + + mode := model_classes.DATA_POINT_MODE_TRAINING + if parts[0] == "testing" { + mode = model_classes.DATA_POINT_MODE_TESTING + } + + data_point_id, err := model_classes.AddDataPoint(c.Db, ids[parts[1]], "id://", mode) + if err != nil { + fmt.Printf("Failed to add data point for %s\n", model.Id) + fmt.Println(err) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } + + file_path := path.Join(base_path, data_point_id+".png") f, err := os.Create(file_path) if err != nil { - fmt.Printf("Could not create file %s\n", file_path) - ModelUpdateStatus(handle, model_id, FAILED_PREPARING_ZIP_FILE) - return + fmt.Printf("Could not create file %s\n", file_path) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return } defer f.Close() f.Write(file_data) - } - - fmt.Printf("Added data to model '%s'!\n", model_id) - ModelUpdateStatus(handle, model_id, CONFIRM_PRE_TRAINING) + + if !testImgForModel(c, model, file_path) { + c.Logger.Errorf("Image did not have valid format for model %s\n", file_path) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_ZIP_FILE) + return + } + } + + fmt.Printf("Added data to model '%s'!\n", model.Id) + ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) } func handleDataUpload(handle *Handle) { handle.Post("/models/data/upload", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { - if !CheckAuthLevel(1, w, r, c) { - return nil - } + if !CheckAuthLevel(1, w, r, c) { + return nil + } if c.Mode == JSON { // TODO improve message return ErrorCode(nil, 400, nil) @@ -179,15 +185,15 @@ func handleDataUpload(handle *Handle) { } } - _, err = GetBaseModel(handle.Db, id) - if err == ModelNotFoundError { - return ErrorCode(nil, http.StatusNotFound, AnyMap{ + model, err := GetBaseModel(handle.Db, id) + if err == ModelNotFoundError { + return c.ErrorCode(nil, http.StatusNotFound, AnyMap{ "NotFoundMessage": "Model not found", "GoBackLink": "/models", }) - } else if err != nil { - return Error500(err) - } + } else if err != nil { + return Error500(err) + } // TODO mk this path configurable dir_path := path.Join("savedData", id) @@ -200,65 +206,65 @@ func handleDataUpload(handle *Handle) { f.Write(file) - ModelUpdateStatus(handle, id, PREPARING_ZIP_FILE) + ModelUpdateStatus(c, id, PREPARING_ZIP_FILE) - go processZipFile(handle, id) + go processZipFile(c, model) - Redirect("/models/edit?id="+id, c.Mode, w, r) + Redirect("/models/edit?id="+id, c.Mode, w, r) return nil }) - handle.Delete("/models/data/delete-zip-file", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { - if !CheckAuthLevel(1, w, r, c) { - return nil - } - if c.Mode == JSON { - panic("Handle delete zip file json") - } + handle.Delete("/models/data/delete-zip-file", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if !CheckAuthLevel(1, w, r, c) { + return nil + } + if c.Mode == JSON { + panic("Handle delete zip file json") + } - f, err := MyParseForm(r) - if err != nil { - return ErrorCode(err, 400, c.AddMap(nil)) - } + f, err := MyParseForm(r) + if err != nil { + return ErrorCode(err, 400, c.AddMap(nil)) + } - if !CheckId(f, "id") { - return ErrorCode(err, 400, c.AddMap(nil)) - } + if !CheckId(f, "id") { + return ErrorCode(err, 400, c.AddMap(nil)) + } - id := f.Get("id") + id := f.Get("id") - model, err := GetBaseModel(handle.Db, id) - if err == ModelNotFoundError { + model, err := GetBaseModel(handle.Db, id) + if err == ModelNotFoundError { return ErrorCode(nil, http.StatusNotFound, AnyMap{ "NotFoundMessage": "Model not found", "GoBackLink": "/models", }) - } else if err != nil { - return Error500(err) - } + } else if err != nil { + return Error500(err) + } - if model.Status != FAILED_PREPARING_ZIP_FILE { - // TODO add message - return ErrorCode(nil, 400, c.AddMap(nil)) - } + if model.Status != FAILED_PREPARING_ZIP_FILE { + // TODO add message + return ErrorCode(nil, 400, c.AddMap(nil)) + } - err = os.Remove(path.Join("savedData", id, "base_data.zip")); - if err != nil { - return Error500(err) - } + err = os.Remove(path.Join("savedData", id, "base_data.zip")) + if err != nil { + return Error500(err) + } - err = os.RemoveAll(path.Join("savedData", id, "data")); - if err != nil { - return Error500(err) - } + err = os.RemoveAll(path.Join("savedData", id, "data")) + if err != nil { + return Error500(err) + } - _, err = handle.Db.Exec("delete from model_classes where model_id=$1;", id) - if err != nil { - return Error500(err) - } + _, err = handle.Db.Exec("delete from model_classes where model_id=$1;", id) + if err != nil { + return Error500(err) + } - ModelUpdateStatus(handle, id, CONFIRM_PRE_TRAINING) - Redirect("/models/edit?id="+id, c.Mode, w, r) - return nil - }) + ModelUpdateStatus(c, id, CONFIRM_PRE_TRAINING) + Redirect("/models/edit?id="+id, c.Mode, w, r) + return nil + }) } diff --git a/logic/models/delete.go b/logic/models/delete.go index 4a639c9..93fff20 100644 --- a/logic/models/delete.go +++ b/logic/models/delete.go @@ -81,15 +81,14 @@ func handleDelete(handle *Handle) { } switch model.Status { - case FAILED_TRAINING: - fallthrough - case FAILED_PREPARING_TRAINING: - fallthrough + case FAILED_TRAINING: fallthrough + case FAILED_PREPARING_ZIP_FILE: fallthrough + case FAILED_PREPARING_TRAINING: fallthrough case FAILED_PREPARING: deleteModel(handle, id, w, c, model) return nil - case READY: - fallthrough + + case READY: fallthrough case CONFIRM_PRE_TRAINING: if CheckEmpty(f, "name") { return c.Error400(nil, "Name is empty", w, "/models/edit.html", "delete-model-card", AnyMap{ diff --git a/logic/models/test.go b/logic/models/test.go index f3da6b3..bb33d7b 100644 --- a/logic/models/test.go +++ b/logic/models/test.go @@ -25,10 +25,6 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { c.Logger.Errorf("Failed to decode image for model with id %s\nErr:%s", model.Id, err) return } - if format != "png" { - c.Logger.Errorf("Found unkown format '%s' while testing an image\n", format) - return - } var model_color string @@ -36,10 +32,11 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { width, height := bounds.Max.X, bounds.Max.Y switch src.ColorModel() { - case color.Gray16Model: - fallthrough + case color.Gray16Model: fallthrough case color.GrayModel: model_color = "greyscale" + case color.YCbCrModel: + model_color = "rgb" default: c.Logger.Error("Do not know how to handle this color model") @@ -47,8 +44,6 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { c.Logger.Info("Color is rgb") } else if src.ColorModel() == color.NRGBAModel { c.Logger.Info("Color is nrgb") - } else if src.ColorModel() == color.YCbCrModel { - c.Logger.Info("Color is ycbcr") } else if src.ColorModel() == color.AlphaModel { c.Logger.Info("Color is alpha") } else if src.ColorModel() == color.CMYKModel { @@ -69,5 +64,10 @@ func testImgForModel(c *Context, model *BaseModel, path string) (result bool) { return } + if format != model.Format { + c.Logger.Warn("Image format does not match model", format, model.Format) + return + } + return true } diff --git a/logic/models/train/reset.go b/logic/models/train/reset.go index 5662693..4255321 100644 --- a/logic/models/train/reset.go +++ b/logic/models/train/reset.go @@ -21,30 +21,30 @@ func handleRest(handle *Handle) { f, err := MyParseForm(r) if err != nil { // TODO improve response - return ErrorCode(nil, 400, c.AddMap(nil)) + return c.ErrorCode(nil, 400, c.AddMap(nil)) } if !CheckId(f, "id") { // TODO improve response - return ErrorCode(nil, 400, c.AddMap(nil)) + return c.ErrorCode(nil, 400, c.AddMap(nil)) } id := f.Get("id") model, err := GetBaseModel(handle.Db, id) if err == ModelNotFoundError { - return ErrorCode(nil, http.StatusNotFound, AnyMap{ + return c.ErrorCode(nil, http.StatusNotFound, AnyMap{ "NotFoundMessage": "Model not found", "GoBackLink": "/models", }) } else if err != nil { // TODO improve response - return Error500(err) + return c.Error500(err) } if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING { // TODO improve response - return ErrorCode(nil, 400, c.AddMap(nil)) + return c.ErrorCode(nil, 400, c.AddMap(nil)) } os.RemoveAll(path.Join("savedData", model.Id, "defs")) @@ -52,10 +52,10 @@ func handleRest(handle *Handle) { _, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id) if err != nil { // TODO improve response - return Error500(err) + return c.Error500(err) } - ModelUpdateStatus(handle, model.Id, CONFIRM_PRE_TRAINING) + ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) Redirect("/models/edit?id=" + model.Id, c.Mode, w, r) return nil }) diff --git a/logic/models/train/train.go b/logic/models/train/train.go index d12cea7..3af8261 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -46,8 +46,8 @@ const ( LAYER_FLATTEN = 3 ) -func ModelDefinitionUpdateStatus(handle *Handle, id string, status ModelDefinitionStatus) (err error) { - _, err = handle.Db.Exec("update model_definition set status = $1 where id = $2", status, id) +func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionStatus) (err error) { + _, err = c.Db.Exec("update model_definition set status = $1 where id = $2", status, id) return } @@ -56,15 +56,15 @@ func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, return } -func generateCvs(handle *Handle, run_path string, model_id string) (count int, err error) { +func generateCvs(c *Context, run_path string, model_id string) (count int, err error) { - classes, err := handle.Db.Query("select count(*) from model_classes where model_id=$1;", model_id) + classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id) if err != nil { return } defer classes.Close() if !classes.Next() { return } if err = classes.Scan(&count); err != nil { return } - data, err := handle.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) + data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1;", model_id) if err != nil { return } defer data.Close() @@ -88,9 +88,9 @@ func generateCvs(handle *Handle, run_path string, model_id string) (count int, return } -func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (accuracy float64, err error) { +func trainDefinition(c *Context, model *BaseModel, definition_id string) (accuracy float64, err error) { accuracy = 0 - layers, err := handle.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) + layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { return } @@ -120,7 +120,7 @@ func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (ac return } - _, err = generateCvs(handle, run_path, model.Id) + _, err = generateCvs(c, run_path, model.Id) if err != nil { return } // Create python script @@ -185,12 +185,12 @@ func trainDefinition(handle *Handle, model *BaseModel, definition_id string) (ac return } -func trainModel(handle *Handle, model *BaseModel) { - definitionsRows, err := handle.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) +func trainModel(c *Context, model *BaseModel) { + definitionsRows, err := c.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) if err != nil { - fmt.Printf("Failed to trainModel!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("Failed to trainModel!Err:") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } defer definitionsRows.Close() @@ -205,27 +205,26 @@ func trainModel(handle *Handle, model *BaseModel) { for definitionsRows.Next() { var rowv row if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy); err != nil { - fmt.Printf("Failed to trainModel!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("Failed to train Model Could not read definition from db!Err:") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } definitions = append(definitions, rowv) } if len(definitions) == 0 { - fmt.Printf("Failed to trainModel!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("No Definitions defined!") + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } for _, def := range definitions { - accuracy, err := trainDefinition(handle, model, def.id) + accuracy, err := trainDefinition(c, model, def.id) if err != nil { - fmt.Printf("Failed to train definition!Err:\n") - fmt.Println(err) - ModelDefinitionUpdateStatus(handle, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + c.Logger.Error("Failed to train definition!Err:") + c.Logger.Error(err) + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) continue } @@ -233,55 +232,55 @@ func trainModel(handle *Handle, model *BaseModel) { if int_accuracy < def.target_accuracy { fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy) - ModelDefinitionUpdateStatus(handle, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) + ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) continue } - _, err = handle.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id) + _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id) if err != nil { fmt.Printf("Failed to train definition!Err:\n") fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } } - rows, err := handle.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) + rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) if err != nil { - fmt.Printf("Db err select!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("DB: failed to read definition") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } defer rows.Close() if !rows.Next() { - // TODO improve message - fmt.Printf("All definitions failed to train!") - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + // TODO Make the Model status have a message + c.Logger.Error("All definitions failed to train!") + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } var id string if err = rows.Scan(&id); err != nil { - fmt.Printf("Db err!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("Failed to read id:") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } - if _, err = handle.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil { - fmt.Printf("Db err!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + if _, err = c.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil { + c.Logger.Error("Failed to update model definition") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } - to_delete, err := handle.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id) + to_delete, err := c.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id) if err != nil { - fmt.Printf("Db err!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("Failed to select model_definition to delete") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } defer to_delete.Close() @@ -289,22 +288,23 @@ func trainModel(handle *Handle, model *BaseModel) { for to_delete.Next() { var id string if to_delete.Scan(&id);err != nil { - fmt.Printf("Db err!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + c.Logger.Error("Failed to scan the id of a model_definition to delete") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) } - - if _, err = handle.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { - fmt.Printf("Db err!Err:\n") - fmt.Println(err) - ModelUpdateStatus(handle, model.Id, FAILED_TRAINING) + + // TODO Check if returning also works here + if _, err = c.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { + c.Logger.Error("Failed to delete model_definition") + c.Logger.Error(err) + ModelUpdateStatus(c, model.Id, FAILED_TRAINING) return } - ModelUpdateStatus(handle, model.Id, READY) + ModelUpdateStatus(c, model.Id, READY) } func handleTrain(handle *Handle) { @@ -371,18 +371,18 @@ func handleTrain(handle *Handle) { cls, err := model_classes.ListClasses(handle.Db, model.Id) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } var fid string for i := 0; i < number_of_models; i++ { def_id, err := MakeDefenition(handle.Db, model.Id, accuracy) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } if fid == "" { @@ -392,42 +392,42 @@ func handleTrain(handle *Handle) { // TODO change shape of it depends on the type of the image err = MakeLayer(handle.Db, def_id, 1, LAYER_INPUT, fmt.Sprintf("%d,%d,1", model.Width, model.Height)) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } err = MakeLayer(handle.Db, def_id, 4, LAYER_FLATTEN, fmt.Sprintf("%d,1", len(cls))) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d,1", len(cls) * 3)) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } // Using sparce err = MakeLayer(handle.Db, def_id, 5, LAYER_DENSE, fmt.Sprintf("%d, 1", len(cls))) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } - err = ModelDefinitionUpdateStatus(handle, def_id, MODEL_DEFINITION_STATUS_INIT) + err = ModelDefinitionUpdateStatus(c, def_id, MODEL_DEFINITION_STATUS_INIT) if err != nil { - ModelUpdateStatus(handle, model.Id, FAILED_PREPARING_TRAINING) + ModelUpdateStatus(c, model.Id, FAILED_PREPARING_TRAINING) // TODO improve this response - return Error500(err) + return c.Error500(err) } } // TODO start training with id fid - go trainModel(handle, model) + go trainModel(c, model) - ModelUpdateStatus(handle, model.Id, TRAINING) + ModelUpdateStatus(c, model.Id, TRAINING) Redirect("/models/edit?id="+model.Id, c.Mode, w, r) return nil }) diff --git a/logic/models/utils/types.go b/logic/models/utils/types.go index 5c3447c..df75962 100644 --- a/logic/models/utils/types.go +++ b/logic/models/utils/types.go @@ -13,6 +13,7 @@ type BaseModel struct { ImageMode int Width int Height int + Format string } const ( @@ -31,31 +32,29 @@ const ( var ModelNotFoundError = errors.New("Model not found error") func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { - rows, err := db.Query("select name, status, id, width, height, color_mode from models where id=$1;", id) - if err != nil { - return - } + rows, err := db.Query("select name, status, id, width, height, color_mode, format from models where id=$1;", id) + if err != nil { return } defer rows.Close() - if !rows.Next() { - return nil, ModelNotFoundError - } + if !rows.Next() { return nil, ModelNotFoundError } base = &BaseModel{} var colorMode string - err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode) + err = rows.Scan(&base.Name, &base.Status, &base.Id, &base.Width, &base.Height, &colorMode, &base.Format) if err != nil { return nil, err } - base.ImageMode = StringToImageMode(colorMode) + base.ImageMode = StringToImageMode(colorMode) return } -func StringToImageMode(colorMode string) (int){ +func StringToImageMode(colorMode string) int { switch colorMode { case "greyscale": return 1 + case "rgb": + return 3 default: panic("unkown color mode") } diff --git a/logic/models/utils/utils.go b/logic/models/utils/utils.go index 97e0fe7..d196376 100644 --- a/logic/models/utils/utils.go +++ b/logic/models/utils/utils.go @@ -7,8 +7,8 @@ import ( ) // TODO make this return and caller handle error -func ModelUpdateStatus(handle *Handle, id string, status int) { - _, err := handle.Db.Exec("update models set status = $1 where id = $2", status, id) +func ModelUpdateStatus(c *Context, id string, status int) { + _, err := c.Db.Exec("update models set status = $1 where id = $2", status, id) if err != nil { fmt.Println("Failed to update model status") fmt.Println(err) diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 6b48c3f..bd9ea3a 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -344,6 +344,7 @@ type Context struct { User *dbtypes.User Mode AnswerType Logger *log.Logger + Db *sql.DB } func (c Context) Error400(err error, message string, w http.ResponseWriter, path string, base string, data AnyMap) *Error { @@ -409,7 +410,7 @@ func (c *Context) requireAuth(w http.ResponseWriter, r *http.Request) bool { var LogoffError = errors.New("Invalid token!") -func (x Handle) createContext(mode AnswerType, r *http.Request) (*Context, error) { +func (x Handle) createContext(handler *Handle, mode AnswerType, r *http.Request) (*Context, error) { var token *string @@ -438,7 +439,7 @@ func (x Handle) createContext(mode AnswerType, r *http.Request) (*Context, error Prefix: r.URL.Path, }) - return &Context{token, user, mode, logger}, nil + return &Context{token, user, mode, logger, handler.Db}, nil } // TODO check if I can use http.Redirect @@ -516,6 +517,7 @@ func ErrorCode(err error, code int, data AnyMap) *Error { } func Error500(err error) *Error { + log.Warn("This function is deprecated please use the one provided by context") return ErrorCode(err, http.StatusInternalServerError, nil) } @@ -563,7 +565,7 @@ func NewHandler(db *sql.DB) *Handle { //TODO JSON //Login state - context, err := x.createContext(ans, r) + context, err := x.createContext(x, ans, r) if err != nil { Logoff(ans, w, r) return diff --git a/sql/models.sql b/sql/models.sql index 5024380..7003eb0 100644 --- a/sql/models.sql +++ b/sql/models.sql @@ -13,7 +13,8 @@ create table if not exists models ( width integer, height integer, - color_mode varchar (20) + color_mode varchar (20), + format varchar (20) ); -- drop table if exists model_data_point;