diff --git a/logic/models/classes/list.go b/logic/models/classes/list.go index a873953..6762ad9 100644 --- a/logic/models/classes/list.go +++ b/logic/models/classes/list.go @@ -38,10 +38,10 @@ func models_data_list_json(w http.ResponseWriter, r *http.Request, c *Context) * } type baserow struct { - Id string - File_Path string - Model_Mode int - Status int + Id string `json:"id"` + File_Path string `json:"file_path"` + Model_Mode int `json:"model_mode"` + Status int `json:"status"` } rows, err := utils.GetDbMultitple[baserow](c, "model_data_point where class_id=$1 limit 11 offset $2", id, page*10) diff --git a/logic/models/edit.go b/logic/models/edit.go index b4d8c03..0baa990 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -43,141 +43,9 @@ func handleJson(w http.ResponseWriter, r *http.Request, c *Context) *Error { /* - // Handle errors - // All errors will be negative - if model.Status < 0 { - LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - })) - return nil - } - switch model.Status { - case READY: - LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - })) case TRAINING: - - type defrow struct { - Id string - Status int - EpochProgress int - Epoch int - Accuracy float64 - } - - defs := []defrow{} - - if model.Type == 2 { - def_rows, err := c.Db.Query("select md.id, md.status, md.epoch, h.epoch_progress, h.accuracy from model_definition as md inner join exp_model_head as h on h.def_id = md.id where md.model_id=$1 order by md.created_on asc", model.Id) - if err != nil { - return c.Error500(err) - } - defer def_rows.Close() - - for def_rows.Next() { - var def defrow - err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) - if err != nil { - return c.Error500(err) - } - defs = append(defs, def) - } - } else { - def_rows, err := c.Db.Query("select id, status, epoch, epoch_progress, accuracy from model_definition where model_id=$1 order by created_on asc", model.Id) - if err != nil { - return c.Error500(err) - } - defer def_rows.Close() - - for def_rows.Next() { - var def defrow - err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) - if err != nil { - return c.Error500(err) - } - defs = append(defs, def) - } - } - - type layerdef struct { - id string - LayerType int - Shape string - } - - layers := []layerdef{} - - for _, def := range defs { - if def.Status == MODEL_DEFINITION_STATUS_TRAINING { - rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id) - if err != nil { - return c.Error500(err) - } - defer rows.Close() - - for rows.Next() { - var layerdef layerdef - err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape) - if err != nil { - return c.Error500(err) - } - layers = append(layers, layerdef) - } - - if model.Type == 2 { - - type lastLayerType struct { - Id string - Range_start int - Range_end int - } - - var lastLayer lastLayerType - - err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id) - if err != nil { - return c.Error500(err) - } - - layers = append(layers, layerdef{ - id: lastLayer.Id, - LayerType: LAYER_DENSE, - Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start+1), - }) - } - - break - } - } - - sep_mod := 100 - if len(layers) > 8 { - sep_mod = 100 - (len(layers)-8)*10 - } - - if sep_mod < 10 { - sep_mod = 10 - } - - LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - "Defs": defs, - "Layers": layers, - "SepMod": sep_mod, - })) - case PREPARING_ZIP_FILE: - LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{ - "Model": model, - })) - default: - fmt.Printf("Unkown Status: %d\n", model.Status) - return Error500(nil) } - - - return nil */ } @@ -230,6 +98,151 @@ func handleEdit(handle *Handle) { }) }) + handle.Get("/models/edit/definitions", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if !CheckAuthLevel(1, w, r, c) { + return nil + } + if c.Mode != JSON { + return c.ErrorCode(nil, 400, AnyMap{}) + } + + id, err := GetIdFromUrl(r, "id") + if err != nil { + return c.SendJSONStatus(http.StatusNotFound, "Model not found") + } + + model, err := GetBaseModel(c.Db, id) + if err == ModelNotFoundError { + return c.SendJSONStatus(http.StatusNotFound, "Model not found") + } else if err != nil { + return c.Error500(err) + } + + type defrow struct { + Id string + Status int + EpochProgress int + Epoch int + Accuracy float64 + } + + defs := []defrow{} + + if model.ModelType == 2 { + def_rows, err := c.Db.Query("select md.id, md.status, md.epoch, h.epoch_progress, h.accuracy from model_definition as md inner join exp_model_head as h on h.def_id = md.id where md.model_id=$1 order by md.created_on asc", model.Id) + if err != nil { + return c.Error500(err) + } + defer def_rows.Close() + + for def_rows.Next() { + var def defrow + err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) + if err != nil { + return c.Error500(err) + } + defs = append(defs, def) + } + } else { + def_rows, err := c.Db.Query("select id, status, epoch, epoch_progress, accuracy from model_definition where model_id=$1 order by created_on asc", model.Id) + if err != nil { + return c.Error500(err) + } + defer def_rows.Close() + + for def_rows.Next() { + var def defrow + err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy) + if err != nil { + return c.Error500(err) + } + defs = append(defs, def) + } + } + + type layerdef struct { + id string + LayerType int `json:"layer_type"` + Shape string `json:"shape"` + } + + layers := []layerdef{} + + for _, def := range defs { + if def.Status == MODEL_DEFINITION_STATUS_TRAINING { + rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id) + if err != nil { + return c.Error500(err) + } + defer rows.Close() + + for rows.Next() { + var layerdef layerdef + err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape) + if err != nil { + return c.Error500(err) + } + layers = append(layers, layerdef) + } + + if model.ModelType == 2 { + + type lastLayerType struct { + Id string + Range_start int + Range_end int + } + + var lastLayer lastLayerType + + err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id) + if err != nil { + return c.Error500(err) + } + + layers = append(layers, layerdef{ + id: lastLayer.Id, + LayerType: LAYER_DENSE, + Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start+1), + }) + } + + break + } + } + + type Definitions struct { + Id string `json:"id"` + Status int `json:"status"` + EpochProgress int `json:"epoch_progress"` + Epoch int `json:"epoch"` + Accuracy float64 `json:"accuracy"` + Layers *[]layerdef `json:"layers"` + } + + defsToReturn := make([]Definitions, len(defs), len(defs)) + + setLayers := false + + for i, def := range defs { + var lay *[]layerdef = nil + if def.Status == MODEL_DEFINITION_STATUS_TRAINING && !setLayers { + lay = &layers + setLayers = true + } + defsToReturn[i] = Definitions{ + Id: def.Id, + Status: def.Status, + EpochProgress: def.EpochProgress, + Epoch: def.Epoch, + Accuracy: def.Accuracy, + Layers: lay, + } + } + + return c.SendJSON(defsToReturn) + }) + handle.Get("/models/edit", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { return nil diff --git a/logic/models/run.go b/logic/models/run.go index decb7c5..e79d470 100644 --- a/logic/models/run.go +++ b/logic/models/run.go @@ -39,7 +39,7 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image { func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) { order = 0 - err = nil + err = nil tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil) @@ -63,55 +63,55 @@ func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf. } func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) { - - err = nil - order = 0 - base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil) + err = nil + order = 0 + + base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil) //results := base_model.Exec([]tf.Output{ - base_results := base_model.Exec([]tf.Output{ + base_results := base_model.Exec([]tf.Output{ base_model.Op("StatefulPartitionedCall", 0), }, map[tf.Output]*tf.Tensor{ //base_model.Op("serving_default_rescaling_input", 0): inputImage, - base_model.Op("serving_default_input_1", 0): inputImage, + base_model.Op("serving_default_input_1", 0): inputImage, }) - type head struct { - Id string - Range_start int - } + type head struct { + Id string + Range_start int + } - heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id) - if err != nil { - return - } + heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id) + if err != nil { + return + } - var vmax float32 = 0.0 + var vmax float32 = 0.0 - for _, element := range heads { - head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil) + for _, element := range heads { + head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil) - results := head_model.Exec([]tf.Output{ - head_model.Op("StatefulPartitionedCall", 0), - }, map[tf.Output]*tf.Tensor{ - head_model.Op("serving_default_input_2", 0): base_results[0], - }) + results := head_model.Exec([]tf.Output{ + head_model.Op("StatefulPartitionedCall", 0), + }, map[tf.Output]*tf.Tensor{ + head_model.Op("serving_default_input_2", 0): base_results[0], + }) - var predictions = results[0].Value().([][]float32)[0] + var predictions = results[0].Value().([][]float32)[0] - for i, v := range predictions { - if v > vmax { - order = element.Range_start + i - vmax = v - } - } - } + for i, v := range predictions { + if v > vmax { + order = element.Range_start + i + vmax = v + } + } + } - // TODO runthe head model + // TODO runthe head model - c.Logger.Info("Got", "heads", len(heads)) - return + c.Logger.Info("Got", "heads", len(heads)) + return } func handleRun(handle *Handle) { @@ -120,8 +120,123 @@ func handleRun(handle *Handle) { return nil } if c.Mode == JSON { - // TODO improve message - return ErrorCode(nil, 400, nil) + + read_form, err := r.MultipartReader() + if err != nil { + // TODO improve message + return ErrorCode(nil, 400, nil) + } + + var id string + var file []byte + + for { + part, err_part := read_form.NextPart() + if err_part == io.EOF { + break + } else if err_part != nil { + return c.JsonBadRequest("Invalid multipart data") + } + if part.FormName() == "id" { + buf := new(bytes.Buffer) + buf.ReadFrom(part) + id = buf.String() + } + if part.FormName() == "file" { + buf := new(bytes.Buffer) + buf.ReadFrom(part) + file = buf.Bytes() + } + } + + model, err := GetBaseModel(handle.Db, id) + if err == ModelNotFoundError { + return c.JsonBadRequest("Models not found"); + } else if err != nil { + return c.Error500(err) + } + + if model.Status != READY { + return c.JsonBadRequest("Model not ready to run images") + } + + def := JustId{} + err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id) + if err == NotFoundError { + return c.JsonBadRequest("Could not find definition") + } else if err != nil { + return c.Error500(err) + } + + def_id := def.Id + + // TODO create a database table with tasks + run_path := path.Join("/tmp", model.Id, "runs") + os.MkdirAll(run_path, os.ModePerm) + img_path := path.Join(run_path, "img."+model.Format) + + img_file, err := os.Create(img_path) + if err != nil { + return c.Error500(err) + } + defer img_file.Close() + img_file.Write(file) + + if !testImgForModel(c, model, img_path) { + return c.JsonBadRequest("Provided image does not match the model") + } + + root := tg.NewRoot() + + var tf_img *image.Image = nil + + switch model.Format { + case "png": + tf_img = ReadPNG(root, img_path, int64(model.ImageMode)) + case "jpeg": + tf_img = ReadJPG(root, img_path, int64(model.ImageMode)) + default: + panic("Not sure what to do with '" + model.Format + "'") + } + + exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{}) + inputImage, err := tf.NewTensor(exec_results[0].Value()) + if err != nil { + return c.Error500(err) + } + + vi := -1 + + if model.ModelType == 2 { + c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) + vi, err = runModelExp(c, model, def_id, inputImage) + if err != nil { + return c.Error500(err) + } + } else { + c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) + vi, err = runModelNormal(c, model, def_id, inputImage) + if err != nil { + return c.Error500(err) + } + } + + os.RemoveAll(run_path) + + rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi) + if err != nil { + return c.Error500(err) + } + if !rows.Next() { + return c.SendJSON(nil) + } + + var name string + if err = rows.Scan(&name); err != nil { + return c.Error500(err) + } + + return c.SendJSON(name) } read_form, err := r.MultipartReader() @@ -220,21 +335,21 @@ func handleRun(handle *Handle) { return Error500(err) } - vi := -1 + vi := -1 - if model.ModelType == 2 { - c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) - vi, err = runModelExp(c, model, def_id, inputImage) - if err != nil { - return c.Error500(err); - } - } else { - c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) - vi, err = runModelNormal(c, model, def_id, inputImage) - if err != nil { - return c.Error500(err); - } - } + if model.ModelType == 2 { + c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) + vi, err = runModelExp(c, model, def_id, inputImage) + if err != nil { + return c.Error500(err) + } + } else { + c.Logger.Info("Running model normal", "model", model.Id, "def", def_id) + vi, err = runModelNormal(c, model, def_id, inputImage) + if err != nil { + return c.Error500(err) + } + } os.RemoveAll(run_path) diff --git a/logic/models/train/reset.go b/logic/models/train/reset.go index 4255321..1fde63a 100644 --- a/logic/models/train/reset.go +++ b/logic/models/train/reset.go @@ -10,53 +10,82 @@ import ( ) func handleRest(handle *Handle) { - handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { - if !CheckAuthLevel(1, w, r, c) { - return nil - } - if c.Mode == JSON { - panic("handle JSON /models/train/reset") - } + handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if !CheckAuthLevel(1, w, r, c) { + return nil + } + if c.Mode == JSON { + var dat struct { + Id string `json:"id"` + } + + if err := c.ToJSON(r, &dat); err != nil { + return err; + } - f, err := MyParseForm(r) - if err != nil { - // TODO improve response - return c.ErrorCode(nil, 400, c.AddMap(nil)) - } + model, err := GetBaseModel(c.Db, dat.Id) + if err == ModelNotFoundError { + return c.JsonBadRequest("Model not found"); + } else if err != nil { + // TODO improve response + return c.Error500(err) + } - if !CheckId(f, "id") { - // TODO improve response - return c.ErrorCode(nil, 400, c.AddMap(nil)) - } + if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING { + return c.JsonBadRequest("Model is not in status that be reset") + } - id := f.Get("id") + os.RemoveAll(path.Join("savedData", model.Id, "defs")) - model, err := GetBaseModel(handle.Db, id) - if err == ModelNotFoundError { + _, err = c.Db.Exec("delete from model_definition where model_id=$1", model.Id) + if err != nil { + // TODO improve response + return c.Error500(err) + } + + ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) + return c.SendJSON(model.Id) + } + + f, err := MyParseForm(r) + if err != nil { + // TODO improve response + return c.ErrorCode(nil, 400, c.AddMap(nil)) + } + + if !CheckId(f, "id") { + // TODO improve response + return c.ErrorCode(nil, 400, c.AddMap(nil)) + } + + id := f.Get("id") + + model, err := GetBaseModel(handle.Db, id) + if err == ModelNotFoundError { return c.ErrorCode(nil, http.StatusNotFound, AnyMap{ "NotFoundMessage": "Model not found", "GoBackLink": "/models", }) - } else if err != nil { - // TODO improve response - return c.Error500(err) - } + } else if err != nil { + // TODO improve response + return c.Error500(err) + } - if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING { - // TODO improve response - return c.ErrorCode(nil, 400, c.AddMap(nil)) - } + if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING { + // TODO improve response + return c.ErrorCode(nil, 400, c.AddMap(nil)) + } - os.RemoveAll(path.Join("savedData", model.Id, "defs")) + os.RemoveAll(path.Join("savedData", model.Id, "defs")) - _, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id) - if err != nil { - // TODO improve response - return c.Error500(err) - } + _, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id) + if err != nil { + // TODO improve response + return c.Error500(err) + } - ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) - Redirect("/models/edit?id=" + model.Id, c.Mode, w, r) - return nil - }) + ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) + Redirect("/models/edit?id="+model.Id, c.Mode, w, r) + return nil + }) } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 144cd57..029d3e8 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -42,9 +42,9 @@ func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionSt return } -func UpdateStatus (c *Context, table string, id string, status int) (err error) { - _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) - return +func UpdateStatus(c *Context, table string, id string, status int) (err error) { + _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) + return } func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) { @@ -246,7 +246,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load return } - UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING) + UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING) layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) if err != nil { @@ -283,7 +283,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load got = append(got, layerrow{ LayerType: LAYER_DENSE, - Shape: fmt.Sprintf("%d", exp.end-exp.start + 1), + Shape: fmt.Sprintf("%d", exp.end-exp.start+1), ExpType: 2, LayerNum: i, }) @@ -625,14 +625,14 @@ func trainModelExp(c *Context, model *BaseModel) { var rowv TrainModelRow rowv.acuracy = 0 if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil { - failed("Failed to train Model Could not read definition from db!") + failed("Failed to train Model Could not read definition from db!") return } definitions = append(definitions, rowv) } if len(definitions) == 0 { - failed("No Definitions defined!") + failed("No Definitions defined!") return } @@ -661,13 +661,13 @@ func trainModelExp(c *Context, model *BaseModel) { c.Logger.Info("Found a definition that reaches target_accuracy!") _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id) if err != nil { - failed("Failed to train definition!") + failed("Failed to train definition!") return } _, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) if err != nil { - failed("Failed to train definition!") + failed("Failed to train definition!") return } @@ -684,7 +684,7 @@ func trainModelExp(c *Context, model *BaseModel) { _, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id) if err != nil { - failed("Failed to train definition!") + failed("Failed to train definition!") return } } @@ -737,30 +737,30 @@ func trainModelExp(c *Context, model *BaseModel) { rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) if err != nil { - failed("DB: failed to read definition") + failed("DB: failed to read definition") return } defer rows.Close() if !rows.Next() { - failed("All definitions failed to train!") + failed("All definitions failed to train!") return } var id string if err = rows.Scan(&id); err != nil { - failed("Failed to read id") + failed("Failed to read id") return } if _, err = c.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil { - failed("Failed to update model definition") + failed("Failed to update model definition") return } to_delete, err := c.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id) if err != nil { - failed("Failed to select model_definition to delete") + failed("Failed to select model_definition to delete") return } defer to_delete.Close() @@ -768,7 +768,7 @@ func trainModelExp(c *Context, model *BaseModel) { for to_delete.Next() { var id string if to_delete.Scan(&id); err != nil { - failed("Failed to scan the id of a model_definition to delete") + failed("Failed to scan the id of a model_definition to delete") return } os.RemoveAll(path.Join("savedData", model.Id, "defs", id)) @@ -776,24 +776,24 @@ func trainModelExp(c *Context, model *BaseModel) { // TODO Check if returning also works here if _, err = c.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil { - failed("Failed to delete model_definition") + failed("Failed to delete model_definition") return } - if err = splitModel(c, model); err != nil { - failed("Failed to split the model") + if err = splitModel(c, model); err != nil { + failed("Failed to split the model") return - } - - // There should only be one def availabale - def := JustId{} + } - if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { - return - } - - // Remove the base model - c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id) + // There should only be one def availabale + def := JustId{} + + if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { + return + } + + // Remove the base model + c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id) os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model")) os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras")) @@ -802,17 +802,17 @@ func trainModelExp(c *Context, model *BaseModel) { func splitModel(c *Context, model *BaseModel) (err error) { - def := JustId{} + def := JustId{} - if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { - return - } + if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { + return + } - head := JustId{} + head := JustId{} - if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil { - return - } + if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil { + return + } // Generate run folder run_path := path.Join("/tmp", model.Id, "defs", def.Id) @@ -821,7 +821,7 @@ func splitModel(c *Context, model *BaseModel) (err error) { if err != nil { return } - // TODO reneable it + // TODO reneable it // defer os.RemoveAll(run_path) // Create python script @@ -838,9 +838,9 @@ func splitModel(c *Context, model *BaseModel) (err error) { // Copy result around result_path := path.Join(getDir(), "savedData", model.Id, "defs", def.Id) - - // TODO maybe move this to a select count(*) - // Get only fixed lawers + + // TODO maybe move this to a select count(*) + // Get only fixed lawers layers, err := c.Db.Query("select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;", def.Id, 1) if err != nil { return @@ -848,24 +848,24 @@ func splitModel(c *Context, model *BaseModel) (err error) { defer layers.Close() type layerrow struct { - ExpType int + ExpType int } - - count := -1 + + count := -1 for layers.Next() { - count += 1 + count += 1 } - if count == -1 { - err = errors.New("Can not get layers") - return - } + if count == -1 { + err = errors.New("Can not get layers") + return + } - log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count) + log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count) - basePath := path.Join(result_path, "base") - headPath := path.Join(result_path, "head", head.Id) + basePath := path.Join(result_path, "base") + headPath := path.Join(result_path, "head", head.Id) if err = os.MkdirAll(basePath, os.ModePerm); err != nil { return @@ -876,10 +876,10 @@ func splitModel(c *Context, model *BaseModel) (err error) { } if err = tmpl.Execute(f, AnyMap{ - "SplitLen": count, - "ModelPath": path.Join(result_path, "model.keras"), - "BaseModelPath": basePath, - "HeadModelPath": headPath, + "SplitLen": count, + "ModelPath": path.Join(result_path, "model.keras"), + "BaseModelPath": basePath, + "HeadModelPath": headPath, }); err != nil { return } @@ -892,7 +892,7 @@ func splitModel(c *Context, model *BaseModel) (err error) { c.Logger.Info("Python finished running") - return + return } func removeFailedDataPoints(c *Context, model *BaseModel) (err error) { @@ -1089,7 +1089,7 @@ func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatu // This generates a definition func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error { - c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity) + c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity) var err error = nil failed := func() *Error { @@ -1133,13 +1133,13 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy // Create the blocks loop := int((math.Log(float64(model.Width)) / math.Log(float64(10)))) - if model.Width < 50 && model.Height < 50 { - loop = 0 - } + if model.Width < 50 && model.Height < 50 { + loop = 0 + } - log.Info("Size of the simple block", "loop", loop) + log.Info("Size of the simple block", "loop", loop) - //loop = max(loop, 3) + //loop = max(loop, 3) for i := 0; i < loop; i++ { err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1) @@ -1165,9 +1165,9 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2) - log.Info("Size of the dense layers", "loop", loop) + log.Info("Size of the dense layers", "loop", loop) - // loop = max(loop, 3) + // loop = max(loop, 3) for i := 0; i < loop; i++ { err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i))) @@ -1176,7 +1176,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy return failed() } } - + _, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT) if err != nil { return failed() @@ -1190,6 +1190,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy return nil } +// TODO make this json friendy func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy int, number_of_models int) *Error { cls, err := model_classes.ListClasses(c.Db, model.Id) if err != nil { @@ -1225,13 +1226,77 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy return nil } +func handle_models_train_json(w http.ResponseWriter, r *http.Request, c *Context) *Error { + var dat struct { + Id string `json:"id"` + ModelType string `json:"model_type"` + NumberOfModels int `json:"number_of_models"` + Accuracy int `json:"accuracy"` + } + + if err_ := c.ToJSON(r, &dat); err_ != nil { + return err_ + } + + if dat.Id == "" { + return c.JsonBadRequest("Please provide a id") + } + + modelTypeId := 1 + if dat.ModelType == "expandable" { + modelTypeId = 2 + } else if dat.ModelType != "simple" { + return c.JsonBadRequest("Invalid model type!") + } + + model, err := GetBaseModel(c.Db, dat.Id) + if err == ModelNotFoundError { + return c.JsonBadRequest("Model not found") + } else if err != nil { + return c.Error500(err) + } + + if model.Status != CONFIRM_PRE_TRAINING { + return c.JsonBadRequest("Model in invalid status for training") + } + + if modelTypeId == 2 { + full_error := generateExpandableDefinitions(c, model, dat.Accuracy, dat.NumberOfModels) + if full_error != nil { + return full_error + } + } else { + full_error := generateDefinitions(c, model, dat.Accuracy, dat.NumberOfModels) + if full_error != nil { + return full_error + } + } + + if modelTypeId == 2 { + go trainModelExp(c, model) + } else { + go trainModel(c, model) + } + + _, err = c.Db.Exec("update models set status = $1, model_type = $2 where id = $3", TRAINING, modelTypeId, model.Id) + if err != nil { + fmt.Println("Failed to update model status") + fmt.Println(err) + // TODO improve this response + return Error500(err) + } + + return c.SendJSON(model.Id) +} + func handleTrain(handle *Handle) { handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { if !CheckAuthLevel(1, w, r, c) { return nil } + if c.Mode == JSON { - panic("TODO /models/train JSON") + return handle_models_train_json(w, r, c) } r.ParseForm() diff --git a/logic/utils/handler.go b/logic/utils/handler.go index c2b7b29..a556263 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -387,6 +387,7 @@ func (c Context) SendJSONStatus(status int, dat any) *Error { } func (c Context) JsonBadRequest(dat any) *Error { + c.Logger.Warn("Request failed with a bad request", "dat", dat) return c.SendJSONStatus(http.StatusBadRequest, dat) } @@ -620,6 +621,7 @@ func (x Handle) ReadFiles(pathTest string, baseFilePath string, fileType string, }) } +// TODO remove this func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { user_path := r.URL.Path[len(pathTest):] @@ -656,6 +658,42 @@ func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes [ }) } +func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { + http.HandleFunc("/api" + pathTest, func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1) + + user_path := r.URL.Path[len(pathTest):] + + found := false + index := -1 + + for i, fileType := range fileTypes { + if strings.HasSuffix(user_path, fileType) { + found = true + index = i + break + } + } + + if !found { + w.WriteHeader(http.StatusNotFound) + w.Write([]byte("File not found")) + return + } + + bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) + if err != nil { + fmt.Println(err) + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Failed to load file")) + return + } + + w.Header().Set("Content-Type", contentTypes[index]) + w.Write(bytes) + }) +} + func NewHandler(db *sql.DB) *Handle { var gets []HandleFunc diff --git a/main.go b/main.go index 6bfa1ac..3aa7a6f 100644 --- a/main.go +++ b/main.go @@ -48,7 +48,7 @@ func main() { handle.StaticFiles("/js/", ".js", "text/javascript") handle.ReadFiles("/imgs/", "views", ".png", "image/png;") handle.ReadTypesFiles("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"}) - handle.ReadTypesFiles("/api/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"}) + handle.ReadTypesFilesApi("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"}) handle.GetHTML("/", AnswerTemplate("index.html", nil, 0)) diff --git a/webpage/src/app.html b/webpage/src/app.html index 77a5ff5..7e570fb 100644 --- a/webpage/src/app.html +++ b/webpage/src/app.html @@ -4,6 +4,8 @@ + + %sveltekit.head% diff --git a/webpage/src/lib/MessageSimple.svelte b/webpage/src/lib/MessageSimple.svelte index 86d1d2e..8da5f0d 100644 --- a/webpage/src/lib/MessageSimple.svelte +++ b/webpage/src/lib/MessageSimple.svelte @@ -14,6 +14,11 @@ let timeout: number | undefined = undefined; + export function clear() { + if (timeout) clearTimeout(timeout); + message = undefined; + } + export function display( msg: string, options?: { diff --git a/webpage/src/lib/requests.svelte.ts b/webpage/src/lib/requests.svelte.ts index 37f52a5..bcfc445 100644 --- a/webpage/src/lib/requests.svelte.ts +++ b/webpage/src/lib/requests.svelte.ts @@ -16,7 +16,10 @@ export async function get(url: string) { headers: headers, }); - if (r.status !== 200) { + if (r.status === 401) { + userStore.user = undefined; + goto("/login") + } else if (r.status !== 200) { throw r; } @@ -35,8 +38,11 @@ export async function post(url: string, body: any) { headers: headers, body: JSON.stringify(body), }); - - if (r.status !== 200) { + + if (r.status === 401) { + userStore.user = undefined; + goto("/login") + } else if (r.status !== 200) { throw r; } @@ -56,7 +62,10 @@ export async function rdelete(url: string, body: any) { body: JSON.stringify(body), }); - if (r.status !== 200) { + if (r.status === 401) { + userStore.user = undefined; + goto("/login") + } else if (r.status !== 200) { throw r; } diff --git a/webpage/src/routes/models/edit/+page.svelte b/webpage/src/routes/models/edit/+page.svelte index bb734fb..e478dfe 100644 --- a/webpage/src/routes/models/edit/+page.svelte +++ b/webpage/src/routes/models/edit/+page.svelte @@ -1,25 +1,25 @@ - {#await model} - - Model - - {:then m} - {#if m} - - Model: {m.name} - - {:else} - - Model - - {/if} - {/await} + {#await model} + Model + {:then m} + {#if m} + + Model: {m.name} + + {:else} + Model + {/if} + {/await} - - - - - -
- {#await model} - Loading - {:then m} - {#if m.status == 1} -
-

- { m.name } -

- -

- Preparing the model -

-
- {:else if m.status == -1} -
-

- {m.name} -

- -

- Failed to prepare model -

+ {#await model} + Loading + {:then m} + {#if m.status == 1} +
+

+ {m.name} +

+ +

Preparing the model

+
+ {:else if m.status == -1} +
+

+ {m.name} +

+ +

Failed to prepare model

-
- TODO button delete -
- - -
- - {:else if m.status == 2 } - - - - - {:else if m.status == -2 } - - - - {:else if m.status == 3 } - -
- - Processing zip file... -
- {:else if m.status == -3 || m.status == -4} - -
- Failed Prepare for training.
-
-
+ + {:else if m.status == 2} + + + + + {:else if m.status == -2} + + + + {:else if m.status == 3} + +
+ + Processing zip file... +
+ {:else if m.status == -3 || m.status == -4} + +
+ Failed Prepare for training.
+
+ + + + + {:else if m.status == 4} + + +
+ + Training the model...
+ + {#await definitions} + Loading + {:then defs} + + + + + + + + + + + {#each defs as def} + + + + + + + {#if def.status == 3 && def.layers} + + + + {/if} + {/each} + +
Done Progress Training Round Progress Accuracy Status
+ {def.epoch} + + {def.epoch_progress}/20 + + {def.accuracy}% + + {#if def.status == 2} + + {:else if [3, 6, -3].includes(def.status)} + + {:else} + {def.status} + {/if} +
+ + {#each def.layers as layer, i} + {@const sep_mod = + def.layers.length > 8 + ? Math.max(10, 100 - (def.layers.length - 8) * 10) + : 100} + {#if layer.layer_type == 1} + + {:else if layer.layer_type == 4} + + + {:else if layer.layer_type == 3} + + + {:else if layer.layer_type == 2} + + + {:else} +
+ {layer.layer_type} + {layer.shape} +
+ {/if} + {/each} + +
+ {/await} + +
+ {:else if m.status == 5} + + + + {:else} +

Unknown Status of the model.

+ {/if} + {/await}
diff --git a/webpage/src/routes/models/edit/ModelData.svelte b/webpage/src/routes/models/edit/ModelData.svelte index 5d06dc9..0a97a32 100644 --- a/webpage/src/routes/models/edit/ModelData.svelte +++ b/webpage/src/routes/models/edit/ModelData.svelte @@ -12,10 +12,12 @@ import MessageSimple from "src/lib/MessageSimple.svelte"; import { createEventDispatcher } from "svelte"; import ModelTable from "./ModelTable.svelte"; + import TrainModel from "./TrainModel.svelte"; let { model } = $props<{model: Model}>(); let classes: Class[] = $state([]); + let has_data: boolean = $state(false); let file: File | undefined = $state(); @@ -61,6 +63,7 @@ let data = await get(`models/edit/classes?id=${model.id}`); classes = data.classes numberOfInvalidImages = data.number_of_invalid_images; + has_data = data.has_data; } catch { return; } @@ -150,7 +153,7 @@
- +
TODO @@ -177,7 +180,7 @@
- +
TODO @@ -185,3 +188,5 @@ {/if}
+ + dispatch('reload')} /> diff --git a/webpage/src/routes/models/edit/ModelTable.svelte b/webpage/src/routes/models/edit/ModelTable.svelte index a2ab808..bc5ec37 100644 --- a/webpage/src/routes/models/edit/ModelTable.svelte +++ b/webpage/src/routes/models/edit/ModelTable.svelte @@ -1,15 +1,28 @@ + + {#if classes.length == 0} @@ -56,8 +77,109 @@ {/each} +
+ + + + + + + + + + + {#each image_list as image} + + + + + + + {/each} + +
File Path Mode + + + +
+ {#if image.file_path == 'id://'} + Managed + {:else} + {image.file_path} + {/if} + + {#if image.mode == 2} + Testing + {:else} + Training + {/if} + + {#if image.file_path == 'id://'} + + {:else} + TODO img {image.file_path} + {/if} + + {#if image.status == 1} + + {:else} + + {/if} +
+
+
+ {#if page > 0} + + {/if} +
+ +
+ {page} +
+ +
+ {#if showNext} + + {/if} +
+
+
{/if} diff --git a/webpage/src/routes/models/edit/RunModel.svelte b/webpage/src/routes/models/edit/RunModel.svelte new file mode 100644 index 0000000..c069c95 --- /dev/null +++ b/webpage/src/routes/models/edit/RunModel.svelte @@ -0,0 +1,78 @@ + +
+
+ +
+ Run image through them model and get the result +
+ + + + + Upload image + +
+ + Image selected + +
+
+
+ + + {#if run} + {#if !result} +
+

+ The class was not found +

+
+ {:else} +
+

+ Result +

+ The image was classified as {result} +
+ {/if} + {/if} + diff --git a/webpage/src/routes/models/edit/TrainModel.svelte b/webpage/src/routes/models/edit/TrainModel.svelte new file mode 100644 index 0000000..7d0a957 --- /dev/null +++ b/webpage/src/routes/models/edit/TrainModel.svelte @@ -0,0 +1,95 @@ + + +
+ {#if has_data} + {#if number_of_invalid_images > 0} +

+ There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip + These images will be delete when the model trains. +

+ {/if} + + +
+ Model Type +
+ +
+ + +
+
+ +
+ + +
+ +
+ + +
+ + + + {:else} +

To train the model please provide data to the model first

+ {/if} +