diff --git a/logic/models/classes/list.go b/logic/models/classes/list.go
index a873953..6762ad9 100644
--- a/logic/models/classes/list.go
+++ b/logic/models/classes/list.go
@@ -38,10 +38,10 @@ func models_data_list_json(w http.ResponseWriter, r *http.Request, c *Context) *
}
type baserow struct {
- Id string
- File_Path string
- Model_Mode int
- Status int
+ Id string `json:"id"`
+ File_Path string `json:"file_path"`
+ Model_Mode int `json:"model_mode"`
+ Status int `json:"status"`
}
rows, err := utils.GetDbMultitple[baserow](c, "model_data_point where class_id=$1 limit 11 offset $2", id, page*10)
diff --git a/logic/models/edit.go b/logic/models/edit.go
index b4d8c03..0baa990 100644
--- a/logic/models/edit.go
+++ b/logic/models/edit.go
@@ -43,141 +43,9 @@ func handleJson(w http.ResponseWriter, r *http.Request, c *Context) *Error {
/*
- // Handle errors
- // All errors will be negative
- if model.Status < 0 {
- LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
- "Model": model,
- }))
- return nil
- }
-
switch model.Status {
- case READY:
- LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
- "Model": model,
- }))
case TRAINING:
-
- type defrow struct {
- Id string
- Status int
- EpochProgress int
- Epoch int
- Accuracy float64
- }
-
- defs := []defrow{}
-
- if model.Type == 2 {
- def_rows, err := c.Db.Query("select md.id, md.status, md.epoch, h.epoch_progress, h.accuracy from model_definition as md inner join exp_model_head as h on h.def_id = md.id where md.model_id=$1 order by md.created_on asc", model.Id)
- if err != nil {
- return c.Error500(err)
- }
- defer def_rows.Close()
-
- for def_rows.Next() {
- var def defrow
- err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy)
- if err != nil {
- return c.Error500(err)
- }
- defs = append(defs, def)
- }
- } else {
- def_rows, err := c.Db.Query("select id, status, epoch, epoch_progress, accuracy from model_definition where model_id=$1 order by created_on asc", model.Id)
- if err != nil {
- return c.Error500(err)
- }
- defer def_rows.Close()
-
- for def_rows.Next() {
- var def defrow
- err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy)
- if err != nil {
- return c.Error500(err)
- }
- defs = append(defs, def)
- }
- }
-
- type layerdef struct {
- id string
- LayerType int
- Shape string
- }
-
- layers := []layerdef{}
-
- for _, def := range defs {
- if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
- rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
- if err != nil {
- return c.Error500(err)
- }
- defer rows.Close()
-
- for rows.Next() {
- var layerdef layerdef
- err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape)
- if err != nil {
- return c.Error500(err)
- }
- layers = append(layers, layerdef)
- }
-
- if model.Type == 2 {
-
- type lastLayerType struct {
- Id string
- Range_start int
- Range_end int
- }
-
- var lastLayer lastLayerType
-
- err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id)
- if err != nil {
- return c.Error500(err)
- }
-
- layers = append(layers, layerdef{
- id: lastLayer.Id,
- LayerType: LAYER_DENSE,
- Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start+1),
- })
- }
-
- break
- }
- }
-
- sep_mod := 100
- if len(layers) > 8 {
- sep_mod = 100 - (len(layers)-8)*10
- }
-
- if sep_mod < 10 {
- sep_mod = 10
- }
-
- LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
- "Model": model,
- "Defs": defs,
- "Layers": layers,
- "SepMod": sep_mod,
- }))
- case PREPARING_ZIP_FILE:
- LoadBasedOnAnswer(c.Mode, w, "/models/edit.html", c.AddMap(AnyMap{
- "Model": model,
- }))
- default:
- fmt.Printf("Unkown Status: %d\n", model.Status)
- return Error500(nil)
}
-
-
- return nil
*/
}
@@ -230,6 +98,151 @@ func handleEdit(handle *Handle) {
})
})
+ handle.Get("/models/edit/definitions", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
+ if !CheckAuthLevel(1, w, r, c) {
+ return nil
+ }
+ if c.Mode != JSON {
+ return c.ErrorCode(nil, 400, AnyMap{})
+ }
+
+ id, err := GetIdFromUrl(r, "id")
+ if err != nil {
+ return c.SendJSONStatus(http.StatusNotFound, "Model not found")
+ }
+
+ model, err := GetBaseModel(c.Db, id)
+ if err == ModelNotFoundError {
+ return c.SendJSONStatus(http.StatusNotFound, "Model not found")
+ } else if err != nil {
+ return c.Error500(err)
+ }
+
+ type defrow struct {
+ Id string
+ Status int
+ EpochProgress int
+ Epoch int
+ Accuracy float64
+ }
+
+ defs := []defrow{}
+
+ if model.ModelType == 2 {
+ def_rows, err := c.Db.Query("select md.id, md.status, md.epoch, h.epoch_progress, h.accuracy from model_definition as md inner join exp_model_head as h on h.def_id = md.id where md.model_id=$1 order by md.created_on asc", model.Id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer def_rows.Close()
+
+ for def_rows.Next() {
+ var def defrow
+ err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defs = append(defs, def)
+ }
+ } else {
+ def_rows, err := c.Db.Query("select id, status, epoch, epoch_progress, accuracy from model_definition where model_id=$1 order by created_on asc", model.Id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer def_rows.Close()
+
+ for def_rows.Next() {
+ var def defrow
+ err = def_rows.Scan(&def.Id, &def.Status, &def.Epoch, &def.EpochProgress, &def.Accuracy)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defs = append(defs, def)
+ }
+ }
+
+ type layerdef struct {
+ id string
+ LayerType int `json:"layer_type"`
+ Shape string `json:"shape"`
+ }
+
+ layers := []layerdef{}
+
+ for _, def := range defs {
+ if def.Status == MODEL_DEFINITION_STATUS_TRAINING {
+ rows, err := c.Db.Query("select id, layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", def.Id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var layerdef layerdef
+ err = rows.Scan(&layerdef.id, &layerdef.LayerType, &layerdef.Shape)
+ if err != nil {
+ return c.Error500(err)
+ }
+ layers = append(layers, layerdef)
+ }
+
+ if model.ModelType == 2 {
+
+ type lastLayerType struct {
+ Id string
+ Range_start int
+ Range_end int
+ }
+
+ var lastLayer lastLayerType
+
+ err := GetDBOnce(c, &lastLayer, "exp_model_head where def_id=$1 and status=3;", def.Id)
+ if err != nil {
+ return c.Error500(err)
+ }
+
+ layers = append(layers, layerdef{
+ id: lastLayer.Id,
+ LayerType: LAYER_DENSE,
+ Shape: fmt.Sprintf("%d, 1", lastLayer.Range_end-lastLayer.Range_start+1),
+ })
+ }
+
+ break
+ }
+ }
+
+ type Definitions struct {
+ Id string `json:"id"`
+ Status int `json:"status"`
+ EpochProgress int `json:"epoch_progress"`
+ Epoch int `json:"epoch"`
+ Accuracy float64 `json:"accuracy"`
+ Layers *[]layerdef `json:"layers"`
+ }
+
+ defsToReturn := make([]Definitions, len(defs), len(defs))
+
+ setLayers := false
+
+ for i, def := range defs {
+ var lay *[]layerdef = nil
+ if def.Status == MODEL_DEFINITION_STATUS_TRAINING && !setLayers {
+ lay = &layers
+ setLayers = true
+ }
+ defsToReturn[i] = Definitions{
+ Id: def.Id,
+ Status: def.Status,
+ EpochProgress: def.EpochProgress,
+ Epoch: def.Epoch,
+ Accuracy: def.Accuracy,
+ Layers: lay,
+ }
+ }
+
+ return c.SendJSON(defsToReturn)
+ })
+
handle.Get("/models/edit", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) {
return nil
diff --git a/logic/models/run.go b/logic/models/run.go
index decb7c5..e79d470 100644
--- a/logic/models/run.go
+++ b/logic/models/run.go
@@ -39,7 +39,7 @@ func ReadJPG(scope *op.Scope, imagePath string, channels int64) *image.Image {
func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
order = 0
- err = nil
+ err = nil
tf_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "model"), []string{"serve"}, nil)
@@ -63,55 +63,55 @@ func runModelNormal(c *Context, model *BaseModel, def_id string, inputImage *tf.
}
func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Tensor) (order int, err error) {
-
- err = nil
- order = 0
- base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
+ err = nil
+ order = 0
+
+ base_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "base", "model"), []string{"serve"}, nil)
//results := base_model.Exec([]tf.Output{
- base_results := base_model.Exec([]tf.Output{
+ base_results := base_model.Exec([]tf.Output{
base_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
//base_model.Op("serving_default_rescaling_input", 0): inputImage,
- base_model.Op("serving_default_input_1", 0): inputImage,
+ base_model.Op("serving_default_input_1", 0): inputImage,
})
- type head struct {
- Id string
- Range_start int
- }
+ type head struct {
+ Id string
+ Range_start int
+ }
- heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
- if err != nil {
- return
- }
+ heads, err := GetDbMultitple[head](c, "exp_model_head where def_id=$1;", def_id)
+ if err != nil {
+ return
+ }
- var vmax float32 = 0.0
+ var vmax float32 = 0.0
- for _, element := range heads {
- head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)
+ for _, element := range heads {
+ head_model := tg.LoadModel(path.Join("savedData", model.Id, "defs", def_id, "head", element.Id, "model"), []string{"serve"}, nil)
- results := head_model.Exec([]tf.Output{
- head_model.Op("StatefulPartitionedCall", 0),
- }, map[tf.Output]*tf.Tensor{
- head_model.Op("serving_default_input_2", 0): base_results[0],
- })
+ results := head_model.Exec([]tf.Output{
+ head_model.Op("StatefulPartitionedCall", 0),
+ }, map[tf.Output]*tf.Tensor{
+ head_model.Op("serving_default_input_2", 0): base_results[0],
+ })
- var predictions = results[0].Value().([][]float32)[0]
+ var predictions = results[0].Value().([][]float32)[0]
- for i, v := range predictions {
- if v > vmax {
- order = element.Range_start + i
- vmax = v
- }
- }
- }
+ for i, v := range predictions {
+ if v > vmax {
+ order = element.Range_start + i
+ vmax = v
+ }
+ }
+ }
- // TODO runthe head model
+ // TODO runthe head model
- c.Logger.Info("Got", "heads", len(heads))
- return
+ c.Logger.Info("Got", "heads", len(heads))
+ return
}
func handleRun(handle *Handle) {
@@ -120,8 +120,123 @@ func handleRun(handle *Handle) {
return nil
}
if c.Mode == JSON {
- // TODO improve message
- return ErrorCode(nil, 400, nil)
+
+ read_form, err := r.MultipartReader()
+ if err != nil {
+ // TODO improve message
+ return ErrorCode(nil, 400, nil)
+ }
+
+ var id string
+ var file []byte
+
+ for {
+ part, err_part := read_form.NextPart()
+ if err_part == io.EOF {
+ break
+ } else if err_part != nil {
+ return c.JsonBadRequest("Invalid multipart data")
+ }
+ if part.FormName() == "id" {
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(part)
+ id = buf.String()
+ }
+ if part.FormName() == "file" {
+ buf := new(bytes.Buffer)
+ buf.ReadFrom(part)
+ file = buf.Bytes()
+ }
+ }
+
+ model, err := GetBaseModel(handle.Db, id)
+ if err == ModelNotFoundError {
+ return c.JsonBadRequest("Models not found");
+ } else if err != nil {
+ return c.Error500(err)
+ }
+
+ if model.Status != READY {
+ return c.JsonBadRequest("Model not ready to run images")
+ }
+
+ def := JustId{}
+ err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id)
+ if err == NotFoundError {
+ return c.JsonBadRequest("Could not find definition")
+ } else if err != nil {
+ return c.Error500(err)
+ }
+
+ def_id := def.Id
+
+ // TODO create a database table with tasks
+ run_path := path.Join("/tmp", model.Id, "runs")
+ os.MkdirAll(run_path, os.ModePerm)
+ img_path := path.Join(run_path, "img."+model.Format)
+
+ img_file, err := os.Create(img_path)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer img_file.Close()
+ img_file.Write(file)
+
+ if !testImgForModel(c, model, img_path) {
+ return c.JsonBadRequest("Provided image does not match the model")
+ }
+
+ root := tg.NewRoot()
+
+ var tf_img *image.Image = nil
+
+ switch model.Format {
+ case "png":
+ tf_img = ReadPNG(root, img_path, int64(model.ImageMode))
+ case "jpeg":
+ tf_img = ReadJPG(root, img_path, int64(model.ImageMode))
+ default:
+ panic("Not sure what to do with '" + model.Format + "'")
+ }
+
+ exec_results := tg.Exec(root, []tf.Output{tf_img.Value()}, nil, &tf.SessionOptions{})
+ inputImage, err := tf.NewTensor(exec_results[0].Value())
+ if err != nil {
+ return c.Error500(err)
+ }
+
+ vi := -1
+
+ if model.ModelType == 2 {
+ c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
+ vi, err = runModelExp(c, model, def_id, inputImage)
+ if err != nil {
+ return c.Error500(err)
+ }
+ } else {
+ c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
+ vi, err = runModelNormal(c, model, def_id, inputImage)
+ if err != nil {
+ return c.Error500(err)
+ }
+ }
+
+ os.RemoveAll(run_path)
+
+ rows, err := handle.Db.Query("select name from model_classes where model_id=$1 and class_order=$2;", model.Id, vi)
+ if err != nil {
+ return c.Error500(err)
+ }
+ if !rows.Next() {
+ return c.SendJSON(nil)
+ }
+
+ var name string
+ if err = rows.Scan(&name); err != nil {
+ return c.Error500(err)
+ }
+
+ return c.SendJSON(name)
}
read_form, err := r.MultipartReader()
@@ -220,21 +335,21 @@ func handleRun(handle *Handle) {
return Error500(err)
}
- vi := -1
+ vi := -1
- if model.ModelType == 2 {
- c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
- vi, err = runModelExp(c, model, def_id, inputImage)
- if err != nil {
- return c.Error500(err);
- }
- } else {
- c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
- vi, err = runModelNormal(c, model, def_id, inputImage)
- if err != nil {
- return c.Error500(err);
- }
- }
+ if model.ModelType == 2 {
+ c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
+ vi, err = runModelExp(c, model, def_id, inputImage)
+ if err != nil {
+ return c.Error500(err)
+ }
+ } else {
+ c.Logger.Info("Running model normal", "model", model.Id, "def", def_id)
+ vi, err = runModelNormal(c, model, def_id, inputImage)
+ if err != nil {
+ return c.Error500(err)
+ }
+ }
os.RemoveAll(run_path)
diff --git a/logic/models/train/reset.go b/logic/models/train/reset.go
index 4255321..1fde63a 100644
--- a/logic/models/train/reset.go
+++ b/logic/models/train/reset.go
@@ -10,53 +10,82 @@ import (
)
func handleRest(handle *Handle) {
- handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
- if !CheckAuthLevel(1, w, r, c) {
- return nil
- }
- if c.Mode == JSON {
- panic("handle JSON /models/train/reset")
- }
+ handle.Delete("/models/train/reset", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
+ if !CheckAuthLevel(1, w, r, c) {
+ return nil
+ }
+ if c.Mode == JSON {
+ var dat struct {
+ Id string `json:"id"`
+ }
+
+ if err := c.ToJSON(r, &dat); err != nil {
+ return err;
+ }
- f, err := MyParseForm(r)
- if err != nil {
- // TODO improve response
- return c.ErrorCode(nil, 400, c.AddMap(nil))
- }
+ model, err := GetBaseModel(c.Db, dat.Id)
+ if err == ModelNotFoundError {
+ return c.JsonBadRequest("Model not found");
+ } else if err != nil {
+ // TODO improve response
+ return c.Error500(err)
+ }
- if !CheckId(f, "id") {
- // TODO improve response
- return c.ErrorCode(nil, 400, c.AddMap(nil))
- }
+ if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING {
+ return c.JsonBadRequest("Model is not in status that be reset")
+ }
- id := f.Get("id")
+ os.RemoveAll(path.Join("savedData", model.Id, "defs"))
- model, err := GetBaseModel(handle.Db, id)
- if err == ModelNotFoundError {
+ _, err = c.Db.Exec("delete from model_definition where model_id=$1", model.Id)
+ if err != nil {
+ // TODO improve response
+ return c.Error500(err)
+ }
+
+ ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
+ return c.SendJSON(model.Id)
+ }
+
+ f, err := MyParseForm(r)
+ if err != nil {
+ // TODO improve response
+ return c.ErrorCode(nil, 400, c.AddMap(nil))
+ }
+
+ if !CheckId(f, "id") {
+ // TODO improve response
+ return c.ErrorCode(nil, 400, c.AddMap(nil))
+ }
+
+ id := f.Get("id")
+
+ model, err := GetBaseModel(handle.Db, id)
+ if err == ModelNotFoundError {
return c.ErrorCode(nil, http.StatusNotFound, AnyMap{
"NotFoundMessage": "Model not found",
"GoBackLink": "/models",
})
- } else if err != nil {
- // TODO improve response
- return c.Error500(err)
- }
+ } else if err != nil {
+ // TODO improve response
+ return c.Error500(err)
+ }
- if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING {
- // TODO improve response
- return c.ErrorCode(nil, 400, c.AddMap(nil))
- }
+ if model.Status != FAILED_PREPARING_TRAINING && model.Status != FAILED_TRAINING {
+ // TODO improve response
+ return c.ErrorCode(nil, 400, c.AddMap(nil))
+ }
- os.RemoveAll(path.Join("savedData", model.Id, "defs"))
+ os.RemoveAll(path.Join("savedData", model.Id, "defs"))
- _, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id)
- if err != nil {
- // TODO improve response
- return c.Error500(err)
- }
+ _, err = handle.Db.Exec("delete from model_definition where model_id=$1", model.Id)
+ if err != nil {
+ // TODO improve response
+ return c.Error500(err)
+ }
- ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
- Redirect("/models/edit?id=" + model.Id, c.Mode, w, r)
- return nil
- })
+ ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
+ Redirect("/models/edit?id="+model.Id, c.Mode, w, r)
+ return nil
+ })
}
diff --git a/logic/models/train/train.go b/logic/models/train/train.go
index 144cd57..029d3e8 100644
--- a/logic/models/train/train.go
+++ b/logic/models/train/train.go
@@ -42,9 +42,9 @@ func ModelDefinitionUpdateStatus(c *Context, id string, status ModelDefinitionSt
return
}
-func UpdateStatus (c *Context, table string, id string, status int) (err error) {
- _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
- return
+func UpdateStatus(c *Context, table string, id string, status int) (err error) {
+ _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
+ return
}
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
@@ -246,7 +246,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return
}
- UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
+ UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
@@ -283,7 +283,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got = append(got, layerrow{
LayerType: LAYER_DENSE,
- Shape: fmt.Sprintf("%d", exp.end-exp.start + 1),
+ Shape: fmt.Sprintf("%d", exp.end-exp.start+1),
ExpType: 2,
LayerNum: i,
})
@@ -625,14 +625,14 @@ func trainModelExp(c *Context, model *BaseModel) {
var rowv TrainModelRow
rowv.acuracy = 0
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
- failed("Failed to train Model Could not read definition from db!")
+ failed("Failed to train Model Could not read definition from db!")
return
}
definitions = append(definitions, rowv)
}
if len(definitions) == 0 {
- failed("No Definitions defined!")
+ failed("No Definitions defined!")
return
}
@@ -661,13 +661,13 @@ func trainModelExp(c *Context, model *BaseModel) {
c.Logger.Info("Found a definition that reaches target_accuracy!")
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
if err != nil {
- failed("Failed to train definition!")
+ failed("Failed to train definition!")
return
}
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
- failed("Failed to train definition!")
+ failed("Failed to train definition!")
return
}
@@ -684,7 +684,7 @@ func trainModelExp(c *Context, model *BaseModel) {
_, err = c.Db.Exec("update model_definition set accuracy=$1, epoch=$2, status=$3 where id=$4", accuracy, def.epoch, MODEL_DEFINITION_STATUS_PAUSED_TRAINING, def.id)
if err != nil {
- failed("Failed to train definition!")
+ failed("Failed to train definition!")
return
}
}
@@ -737,30 +737,30 @@ func trainModelExp(c *Context, model *BaseModel) {
rows, err := c.Db.Query("select id from model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
if err != nil {
- failed("DB: failed to read definition")
+ failed("DB: failed to read definition")
return
}
defer rows.Close()
if !rows.Next() {
- failed("All definitions failed to train!")
+ failed("All definitions failed to train!")
return
}
var id string
if err = rows.Scan(&id); err != nil {
- failed("Failed to read id")
+ failed("Failed to read id")
return
}
if _, err = c.Db.Exec("update model_definition set status=$1 where id=$2;", MODEL_DEFINITION_STATUS_READY, id); err != nil {
- failed("Failed to update model definition")
+ failed("Failed to update model definition")
return
}
to_delete, err := c.Db.Query("select id from model_definition where status != $1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
- failed("Failed to select model_definition to delete")
+ failed("Failed to select model_definition to delete")
return
}
defer to_delete.Close()
@@ -768,7 +768,7 @@ func trainModelExp(c *Context, model *BaseModel) {
for to_delete.Next() {
var id string
if to_delete.Scan(&id); err != nil {
- failed("Failed to scan the id of a model_definition to delete")
+ failed("Failed to scan the id of a model_definition to delete")
return
}
os.RemoveAll(path.Join("savedData", model.Id, "defs", id))
@@ -776,24 +776,24 @@ func trainModelExp(c *Context, model *BaseModel) {
// TODO Check if returning also works here
if _, err = c.Db.Exec("delete from model_definition where status!=$1 and model_id=$2;", MODEL_DEFINITION_STATUS_READY, model.Id); err != nil {
- failed("Failed to delete model_definition")
+ failed("Failed to delete model_definition")
return
}
- if err = splitModel(c, model); err != nil {
- failed("Failed to split the model")
+ if err = splitModel(c, model); err != nil {
+ failed("Failed to split the model")
return
- }
-
- // There should only be one def availabale
- def := JustId{}
+ }
- if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
- return
- }
-
- // Remove the base model
- c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
+ // There should only be one def availabale
+ def := JustId{}
+
+ if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
+ return
+ }
+
+ // Remove the base model
+ c.Logger.Warn("Removing base model for", "model", model.Id, "def", def.Id)
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model"))
os.RemoveAll(path.Join("savedData", model.Id, "defs", def.Id, "model.keras"))
@@ -802,17 +802,17 @@ func trainModelExp(c *Context, model *BaseModel) {
func splitModel(c *Context, model *BaseModel) (err error) {
- def := JustId{}
+ def := JustId{}
- if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
- return
- }
+ if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
+ return
+ }
- head := JustId{}
+ head := JustId{}
- if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
- return
- }
+ if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
+ return
+ }
// Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", def.Id)
@@ -821,7 +821,7 @@ func splitModel(c *Context, model *BaseModel) (err error) {
if err != nil {
return
}
- // TODO reneable it
+ // TODO reneable it
// defer os.RemoveAll(run_path)
// Create python script
@@ -838,9 +838,9 @@ func splitModel(c *Context, model *BaseModel) (err error) {
// Copy result around
result_path := path.Join(getDir(), "savedData", model.Id, "defs", def.Id)
-
- // TODO maybe move this to a select count(*)
- // Get only fixed lawers
+
+ // TODO maybe move this to a select count(*)
+ // Get only fixed lawers
layers, err := c.Db.Query("select exp_type from model_definition_layer where def_id=$1 and exp_type=$2 order by layer_order asc;", def.Id, 1)
if err != nil {
return
@@ -848,24 +848,24 @@ func splitModel(c *Context, model *BaseModel) (err error) {
defer layers.Close()
type layerrow struct {
- ExpType int
+ ExpType int
}
-
- count := -1
+
+ count := -1
for layers.Next() {
- count += 1
+ count += 1
}
- if count == -1 {
- err = errors.New("Can not get layers")
- return
- }
+ if count == -1 {
+ err = errors.New("Can not get layers")
+ return
+ }
- log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count)
+ log.Warn("Spliting model", "def", def.Id, "head", head.Id, "count", count)
- basePath := path.Join(result_path, "base")
- headPath := path.Join(result_path, "head", head.Id)
+ basePath := path.Join(result_path, "base")
+ headPath := path.Join(result_path, "head", head.Id)
if err = os.MkdirAll(basePath, os.ModePerm); err != nil {
return
@@ -876,10 +876,10 @@ func splitModel(c *Context, model *BaseModel) (err error) {
}
if err = tmpl.Execute(f, AnyMap{
- "SplitLen": count,
- "ModelPath": path.Join(result_path, "model.keras"),
- "BaseModelPath": basePath,
- "HeadModelPath": headPath,
+ "SplitLen": count,
+ "ModelPath": path.Join(result_path, "model.keras"),
+ "BaseModelPath": basePath,
+ "HeadModelPath": headPath,
}); err != nil {
return
}
@@ -892,7 +892,7 @@ func splitModel(c *Context, model *BaseModel) (err error) {
c.Logger.Info("Python finished running")
- return
+ return
}
func removeFailedDataPoints(c *Context, model *BaseModel) (err error) {
@@ -1089,7 +1089,7 @@ func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatu
// This generates a definition
func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy int, number_of_classes int, complexity int) *Error {
- c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity)
+ c.Logger.Info("Generating expandable new definition for model", "id", model.Id, "complexity", complexity)
var err error = nil
failed := func() *Error {
@@ -1133,13 +1133,13 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
// Create the blocks
loop := int((math.Log(float64(model.Width)) / math.Log(float64(10))))
- if model.Width < 50 && model.Height < 50 {
- loop = 0
- }
+ if model.Width < 50 && model.Height < 50 {
+ loop = 0
+ }
- log.Info("Size of the simple block", "loop", loop)
+ log.Info("Size of the simple block", "loop", loop)
- //loop = max(loop, 3)
+ //loop = max(loop, 3)
for i := 0; i < loop; i++ {
err = MakeLayerExpandable(c.Db, def_id, order, LAYER_SIMPLE_BLOCK, "", 1)
@@ -1165,9 +1165,9 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
loop = int((math.Log(float64(number_of_classes)) / math.Log(float64(10))) / 2)
- log.Info("Size of the dense layers", "loop", loop)
+ log.Info("Size of the dense layers", "loop", loop)
- // loop = max(loop, 3)
+ // loop = max(loop, 3)
for i := 0; i < loop; i++ {
err = MakeLayer(c.Db, def_id, order, LAYER_DENSE, fmt.Sprintf("%d,1", number_of_classes*(loop-i)))
@@ -1176,7 +1176,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
return failed()
}
}
-
+
_, err = CreateExpModelHead(c, def_id, 0, number_of_classes-1, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
return failed()
@@ -1190,6 +1190,7 @@ func generateExpandableDefinition(c *Context, model *BaseModel, target_accuracy
return nil
}
+// TODO make this json friendy
func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy int, number_of_models int) *Error {
cls, err := model_classes.ListClasses(c.Db, model.Id)
if err != nil {
@@ -1225,13 +1226,77 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy
return nil
}
+func handle_models_train_json(w http.ResponseWriter, r *http.Request, c *Context) *Error {
+ var dat struct {
+ Id string `json:"id"`
+ ModelType string `json:"model_type"`
+ NumberOfModels int `json:"number_of_models"`
+ Accuracy int `json:"accuracy"`
+ }
+
+ if err_ := c.ToJSON(r, &dat); err_ != nil {
+ return err_
+ }
+
+ if dat.Id == "" {
+ return c.JsonBadRequest("Please provide a id")
+ }
+
+ modelTypeId := 1
+ if dat.ModelType == "expandable" {
+ modelTypeId = 2
+ } else if dat.ModelType != "simple" {
+ return c.JsonBadRequest("Invalid model type!")
+ }
+
+ model, err := GetBaseModel(c.Db, dat.Id)
+ if err == ModelNotFoundError {
+ return c.JsonBadRequest("Model not found")
+ } else if err != nil {
+ return c.Error500(err)
+ }
+
+ if model.Status != CONFIRM_PRE_TRAINING {
+ return c.JsonBadRequest("Model in invalid status for training")
+ }
+
+ if modelTypeId == 2 {
+ full_error := generateExpandableDefinitions(c, model, dat.Accuracy, dat.NumberOfModels)
+ if full_error != nil {
+ return full_error
+ }
+ } else {
+ full_error := generateDefinitions(c, model, dat.Accuracy, dat.NumberOfModels)
+ if full_error != nil {
+ return full_error
+ }
+ }
+
+ if modelTypeId == 2 {
+ go trainModelExp(c, model)
+ } else {
+ go trainModel(c, model)
+ }
+
+ _, err = c.Db.Exec("update models set status = $1, model_type = $2 where id = $3", TRAINING, modelTypeId, model.Id)
+ if err != nil {
+ fmt.Println("Failed to update model status")
+ fmt.Println(err)
+ // TODO improve this response
+ return Error500(err)
+ }
+
+ return c.SendJSON(model.Id)
+}
+
func handleTrain(handle *Handle) {
handle.Post("/models/train", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
if !CheckAuthLevel(1, w, r, c) {
return nil
}
+
if c.Mode == JSON {
- panic("TODO /models/train JSON")
+ return handle_models_train_json(w, r, c)
}
r.ParseForm()
diff --git a/logic/utils/handler.go b/logic/utils/handler.go
index c2b7b29..a556263 100644
--- a/logic/utils/handler.go
+++ b/logic/utils/handler.go
@@ -387,6 +387,7 @@ func (c Context) SendJSONStatus(status int, dat any) *Error {
}
func (c Context) JsonBadRequest(dat any) *Error {
+ c.Logger.Warn("Request failed with a bad request", "dat", dat)
return c.SendJSONStatus(http.StatusBadRequest, dat)
}
@@ -620,6 +621,7 @@ func (x Handle) ReadFiles(pathTest string, baseFilePath string, fileType string,
})
}
+// TODO remove this
func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) {
http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) {
user_path := r.URL.Path[len(pathTest):]
@@ -656,6 +658,42 @@ func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes [
})
}
+func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) {
+ http.HandleFunc("/api" + pathTest, func(w http.ResponseWriter, r *http.Request) {
+ r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1)
+
+ user_path := r.URL.Path[len(pathTest):]
+
+ found := false
+ index := -1
+
+ for i, fileType := range fileTypes {
+ if strings.HasSuffix(user_path, fileType) {
+ found = true
+ index = i
+ break
+ }
+ }
+
+ if !found {
+ w.WriteHeader(http.StatusNotFound)
+ w.Write([]byte("File not found"))
+ return
+ }
+
+ bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path))
+ if err != nil {
+ fmt.Println(err)
+ w.WriteHeader(http.StatusInternalServerError)
+ w.Write([]byte("Failed to load file"))
+ return
+ }
+
+ w.Header().Set("Content-Type", contentTypes[index])
+ w.Write(bytes)
+ })
+}
+
func NewHandler(db *sql.DB) *Handle {
var gets []HandleFunc
diff --git a/main.go b/main.go
index 6bfa1ac..3aa7a6f 100644
--- a/main.go
+++ b/main.go
@@ -48,7 +48,7 @@ func main() {
handle.StaticFiles("/js/", ".js", "text/javascript")
handle.ReadFiles("/imgs/", "views", ".png", "image/png;")
handle.ReadTypesFiles("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"})
- handle.ReadTypesFiles("/api/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"})
+ handle.ReadTypesFilesApi("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"})
handle.GetHTML("/", AnswerTemplate("index.html", nil, 0))
diff --git a/webpage/src/app.html b/webpage/src/app.html
index 77a5ff5..7e570fb 100644
--- a/webpage/src/app.html
+++ b/webpage/src/app.html
@@ -4,6 +4,8 @@
+
+
%sveltekit.head%
diff --git a/webpage/src/lib/MessageSimple.svelte b/webpage/src/lib/MessageSimple.svelte
index 86d1d2e..8da5f0d 100644
--- a/webpage/src/lib/MessageSimple.svelte
+++ b/webpage/src/lib/MessageSimple.svelte
@@ -14,6 +14,11 @@
let timeout: number | undefined = undefined;
+ export function clear() {
+ if (timeout) clearTimeout(timeout);
+ message = undefined;
+ }
+
export function display(
msg: string,
options?: {
diff --git a/webpage/src/lib/requests.svelte.ts b/webpage/src/lib/requests.svelte.ts
index 37f52a5..bcfc445 100644
--- a/webpage/src/lib/requests.svelte.ts
+++ b/webpage/src/lib/requests.svelte.ts
@@ -16,7 +16,10 @@ export async function get(url: string) {
headers: headers,
});
- if (r.status !== 200) {
+ if (r.status === 401) {
+ userStore.user = undefined;
+ goto("/login")
+ } else if (r.status !== 200) {
throw r;
}
@@ -35,8 +38,11 @@ export async function post(url: string, body: any) {
headers: headers,
body: JSON.stringify(body),
});
-
- if (r.status !== 200) {
+
+ if (r.status === 401) {
+ userStore.user = undefined;
+ goto("/login")
+ } else if (r.status !== 200) {
throw r;
}
@@ -56,7 +62,10 @@ export async function rdelete(url: string, body: any) {
body: JSON.stringify(body),
});
- if (r.status !== 200) {
+ if (r.status === 401) {
+ userStore.user = undefined;
+ goto("/login")
+ } else if (r.status !== 200) {
throw r;
}
diff --git a/webpage/src/routes/models/edit/+page.svelte b/webpage/src/routes/models/edit/+page.svelte
index bb734fb..e478dfe 100644
--- a/webpage/src/routes/models/edit/+page.svelte
+++ b/webpage/src/routes/models/edit/+page.svelte
@@ -1,25 +1,25 @@
- {#await model}
-
- Model
-
- {:then m}
- {#if m}
-
- Model: {m.name}
-
- {:else}
-
- Model
-
- {/if}
- {/await}
+ {#await model}
+ Model
+ {:then m}
+ {#if m}
+
+ Model: {m.name}
+
+ {:else}
+ Model
+ {/if}
+ {/await}
-
-
-
-
-
-
- {#await model}
- Loading
- {:then m}
- {#if m.status == 1}
-
-
- { m.name }
-
-
-
- Preparing the model
-
-
- {:else if m.status == -1}
-
-
- {m.name}
-
-
-
- Failed to prepare model
-
+ {#await model}
+ Loading
+ {:then m}
+ {#if m.status == 1}
+
+
+ {m.name}
+
+
+ Preparing the model
+
+ {:else if m.status == -1}
+
+
+ {m.name}
+
+
+
Failed to prepare model
-
- TODO button delete
-
-
-
-
-
- {:else if m.status == 2 }
-
-
-
-
- {:else if m.status == -2 }
-
-
-
- {:else if m.status == 3 }
-
-
-
- Processing zip file...
-
- {:else if m.status == -3 || m.status == -4}
-
-
-
- {:else if m.status == 4}
-
-
-
-
- Training the model...
-
- {#await definitions}
- Loading
- {:then defs}
-
-
-
-
- Done Progress
- |
-
- Training Round Progress
- |
-
- Accuracy
- |
-
- Status
- |
-
-
-
- {#each defs as def}
-
-
- {def.epoch}
- |
-
- {def.epoch_progress}/20
- |
-
- {def.accuracy}%
- |
-
- {#if def.status == 2}
-
- {:else if [3,6,-3].includes(def.status) }
-
- {:else}
- {def.status}
- {/if}
- |
-
- {#if def.status == 3 && def.layers}
-
-
-
- |
-
- {/if}
- {/each}
-
-
- {/await}
- {{/* TODO Add ability to stop training */}}
-
- {:else if m.status == 5}
-
- TODO run model
-
-
- {:else}
-
- Unknown Status of the model.
-
- {/if}
- {/await}
+
+
+ {:else if m.status == 2}
+
+
+
+
+ {:else if m.status == -2}
+
+
+
+ {:else if m.status == 3}
+
+
+
+ Processing zip file...
+
+ {:else if m.status == -3 || m.status == -4}
+
+
+
+ {:else if m.status == 4}
+
+
+
+
+ Training the model...
+
+ {#await definitions}
+ Loading
+ {:then defs}
+
+
+
+ Done Progress |
+ Training Round Progress |
+ Accuracy |
+ Status |
+
+
+
+ {#each defs as def}
+
+
+ {def.epoch}
+ |
+
+ {def.epoch_progress}/20
+ |
+
+ {def.accuracy}%
+ |
+
+ {#if def.status == 2}
+
+ {:else if [3, 6, -3].includes(def.status)}
+
+ {:else}
+ {def.status}
+ {/if}
+ |
+
+ {#if def.status == 3 && def.layers}
+
+
+
+ |
+
+ {/if}
+ {/each}
+
+
+ {/await}
+
+
+ {:else if m.status == 5}
+
+
+
+ {:else}
+ Unknown Status of the model.
+ {/if}
+ {/await}
diff --git a/webpage/src/routes/models/edit/ModelData.svelte b/webpage/src/routes/models/edit/ModelData.svelte
index 5d06dc9..0a97a32 100644
--- a/webpage/src/routes/models/edit/ModelData.svelte
+++ b/webpage/src/routes/models/edit/ModelData.svelte
@@ -12,10 +12,12 @@
import MessageSimple from "src/lib/MessageSimple.svelte";
import { createEventDispatcher } from "svelte";
import ModelTable from "./ModelTable.svelte";
+ import TrainModel from "./TrainModel.svelte";
let { model } = $props<{model: Model}>();
let classes: Class[] = $state([]);
+ let has_data: boolean = $state(false);
let file: File | undefined = $state();
@@ -61,6 +63,7 @@
let data = await get(`models/edit/classes?id=${model.id}`);
classes = data.classes
numberOfInvalidImages = data.number_of_invalid_images;
+ has_data = data.has_data;
} catch {
return;
}
@@ -150,7 +153,7 @@
-
+
TODO
@@ -177,7 +180,7 @@
-
+
TODO
@@ -185,3 +188,5 @@
{/if}
+
+ dispatch('reload')} />
diff --git a/webpage/src/routes/models/edit/ModelTable.svelte b/webpage/src/routes/models/edit/ModelTable.svelte
index a2ab808..bc5ec37 100644
--- a/webpage/src/routes/models/edit/ModelTable.svelte
+++ b/webpage/src/routes/models/edit/ModelTable.svelte
@@ -1,15 +1,28 @@
+
+
{#if classes.length == 0}
@@ -56,8 +77,109 @@
{/each}
+
+
+
+
+ File Path |
+ Mode |
+
+
+ |
+
+
+ |
+
+
+
+ {#each image_list as image}
+
+
+ {#if image.file_path == 'id://'}
+ Managed
+ {:else}
+ {image.file_path}
+ {/if}
+ |
+
+ {#if image.mode == 2}
+ Testing
+ {:else}
+ Training
+ {/if}
+ |
+
+ {#if image.file_path == 'id://'}
+
+ {:else}
+ TODO img {image.file_path}
+ {/if}
+ |
+
+ {#if image.status == 1}
+
+ {:else}
+
+ {/if}
+ |
+
+ {/each}
+
+
+
+
+ {#if page > 0}
+
+ {/if}
+
+
+
+ {page}
+
+
+
+ {#if showNext}
+
+ {/if}
+
+
+
{/if}
diff --git a/webpage/src/routes/models/edit/RunModel.svelte b/webpage/src/routes/models/edit/RunModel.svelte
new file mode 100644
index 0000000..c069c95
--- /dev/null
+++ b/webpage/src/routes/models/edit/RunModel.svelte
@@ -0,0 +1,78 @@
+
+
diff --git a/webpage/src/routes/models/edit/TrainModel.svelte b/webpage/src/routes/models/edit/TrainModel.svelte
new file mode 100644
index 0000000..7d0a957
--- /dev/null
+++ b/webpage/src/routes/models/edit/TrainModel.svelte
@@ -0,0 +1,95 @@
+
+
+