diff --git a/logic/models/train/train.go b/logic/models/train/train.go
index 42d9487..d5d24dc 100644
--- a/logic/models/train/train.go
+++ b/logic/models/train/train.go
@@ -17,6 +17,9 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
+const EPOCH_PER_RUN = 20;
+const MAX_EPOCH = 100
+
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = ""
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
@@ -34,12 +37,13 @@ func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string
type ModelDefinitionStatus int
const (
- MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
- MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
- MODEL_DEFINITION_STATUS_INIT = 2
- MODEL_DEFINITION_STATUS_TRAINING = 3
- MODEL_DEFINITION_STATUS_TRANIED = 4
- MODEL_DEFINITION_STATUS_READY = 5
+ MODEL_DEFINITION_STATUS_CANCELD_TRAINING = -4
+ MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
+ MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
+ MODEL_DEFINITION_STATUS_INIT = 2
+ MODEL_DEFINITION_STATUS_TRAINING = 3
+ MODEL_DEFINITION_STATUS_TRANIED = 4
+ MODEL_DEFINITION_STATUS_READY = 5
)
type LayerType int
@@ -104,7 +108,8 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e
return
}
-func trainDefinition(c *Context, model *BaseModel, definition_id string) (accuracy float64, err error) {
+func trainDefinition(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
+ c.Logger.Warn("About to start training definition")
accuracy = 0
layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
@@ -153,14 +158,20 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return
}
+ // Copy result around
+ result_path := path.Join("savedData", model.Id, "defs", definition_id)
+
if err = tmpl.Execute(f, AnyMap{
- "Layers": got,
- "Size": got[0].Shape,
- "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
- "RunPath": run_path,
- "ColorMode": model.ImageMode,
- "Model": model,
- "DefId": definition_id,
+ "Layers": got,
+ "Size": got[0].Shape,
+ "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
+ "RunPath": run_path,
+ "ColorMode": model.ImageMode,
+ "Model": model,
+ "EPOCH_PER_RUN": EPOCH_PER_RUN,
+ "DefId": definition_id,
+ "LoadPrev": load_prev,
+ "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
}); err != nil {
return
}
@@ -168,13 +179,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
// Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output()
if err != nil {
- c.Logger.Debug(string(out))
+ c.Logger.Debug(string(out))
return
}
- // Copy result around
- result_path := path.Join("savedData", model.Id, "defs", definition_id)
-
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return
}
@@ -183,6 +191,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return
}
+ if err = exec.Command("cp", "-r", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run(); err != nil {
+ return
+ }
+
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
if err != nil {
return
@@ -194,7 +206,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return
}
- fmt.Println(string(accuracy_file_bytes))
+ c.Logger.Info("Model finished training!", "accuracy", accuracy)
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
if err != nil {
@@ -205,8 +217,25 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return
}
+func remove[T interface{}](lst []T, i int) []T {
+ lng := len(lst)
+ if i >= lng {
+ return []T{}
+ }
+
+ if i+1 >= lng {
+ return lst[:lng-1]
+ }
+
+ if i == 0 {
+ return lst[1:]
+ }
+
+ return append(lst[:i], lst[i+1:]...)
+}
+
func trainModel(c *Context, model *BaseModel) {
- definitionsRows, err := c.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
+ definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
if err != nil {
c.Logger.Error("Failed to trainModel!Err:")
c.Logger.Error(err)
@@ -218,13 +247,14 @@ func trainModel(c *Context, model *BaseModel) {
type row struct {
id string
target_accuracy int
+ epoch int
}
definitions := []row{}
for definitionsRows.Next() {
var rowv row
- if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy); err != nil {
+ if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
c.Logger.Error("Failed to train Model Could not read definition from db!Err:")
c.Logger.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
@@ -239,30 +269,58 @@ func trainModel(c *Context, model *BaseModel) {
return
}
- for _, def := range definitions {
- ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
- accuracy, err := trainDefinition(c, model, def.id)
- if err != nil {
- c.Logger.Error("Failed to train definition!Err:")
- c.Logger.Error(err)
- ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
- continue
+ toTrain := len(definitions)
+ firstRound := true
+ var newDefinitions = []row{}
+ copy(newDefinitions, definitions)
+ for {
+ for i, def := range definitions {
+ ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
+ accuracy, err := trainDefinition(c, model, def.id, !firstRound)
+ if err != nil {
+ c.Logger.Error("Failed to train definition!Err:", "err", err)
+ ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ toTrain = toTrain - 1
+ newDefinitions = remove(newDefinitions, i)
+ continue
+ }
+ def.epoch += EPOCH_PER_RUN
+
+ int_accuracy := int(accuracy * 100)
+
+ if int_accuracy >= def.target_accuracy {
+ c.Logger.Info("Found a definition that reaches target_accuracy!")
+ _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
+ if err != nil {
+ c.Logger.Error("Failed to train definition!Err:\n", "err", err)
+ ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
+ return
+ }
+
+ _, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ if err != nil {
+ c.Logger.Error("Failed to train definition!Err:\n", "err", err)
+ ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
+ return
+ }
+
+ toTrain = 0
+ break
+ }
+
+ if def.epoch > MAX_EPOCH {
+ fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
+ ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
+ toTrain = toTrain - 1
+ newDefinitions = remove(newDefinitions, i)
+ continue
+ }
+
}
-
- int_accuracy := int(accuracy * 100)
-
- if int_accuracy < def.target_accuracy {
- fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
- ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
- continue
- }
-
- _, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id)
- if err != nil {
- fmt.Printf("Failed to train definition!Err:\n")
- fmt.Println(err)
- ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
- return
+ copy(definitions, newDefinitions)
+ firstRound = false
+ if toTrain == 0 {
+ break
}
}
@@ -335,7 +393,7 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
}
defer rows.Close()
- base_path := path.Join("savedData", model.Id, "data")
+ base_path := path.Join("savedData", model.Id, "data")
for rows.Next() {
var dataPointId string
@@ -343,13 +401,13 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
if err != nil {
return
}
- err = os.RemoveAll(path.Join(base_path, dataPointId + model.Format))
- if err != nil {
- return
- }
+ err = os.RemoveAll(path.Join(base_path, dataPointId+model.Format))
+ if err != nil {
+ return
+ }
}
- _, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id)
+ _, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id)
return
}
@@ -484,56 +542,56 @@ func handleTrain(handle *Handle) {
})
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
- // TODO check auth level
+ // TODO check auth level
if c.Mode != NORMAL {
- // This should only handle normal requests
- c.Logger.Warn("This function only works with normal")
- return c.UnsafeErrorCode(nil, 400, nil)
+ // This should only handle normal requests
+ c.Logger.Warn("This function only works with normal")
+ return c.UnsafeErrorCode(nil, 400, nil)
}
- f := r.URL.Query()
+ f := r.URL.Query()
- if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
- c.Logger.Warn("Invalid: model_id or definition or epoch")
- return c.UnsafeErrorCode(nil, 400, nil)
- }
+ if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
+ c.Logger.Warn("Invalid: model_id or definition or epoch")
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
model_id := f.Get("model_id")
def_id := f.Get("definition")
- epoch, err := strconv.Atoi(f.Get("epoch"))
- if err != nil {
- c.Logger.Warn("Epoch is not a number")
- // No need to improve message because this function is only called internaly
- return c.UnsafeErrorCode(nil, 400, nil)
- }
+ epoch, err := strconv.Atoi(f.Get("epoch"))
+ if err != nil {
+ c.Logger.Warn("Epoch is not a number")
+ // No need to improve message because this function is only called internaly
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
- rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
- if err != nil {
- return c.Error500(err)
- }
- defer rows.Close()
+ rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
+ if err != nil {
+ return c.Error500(err)
+ }
+ defer rows.Close()
- if !rows.Next() {
- c.Logger.Error("Could not get status of model definition")
- return c.Error500(nil)
- }
+ if !rows.Next() {
+ c.Logger.Error("Could not get status of model definition")
+ return c.Error500(nil)
+ }
- var status int
- err = rows.Scan(&status)
- if err != nil {
- return c.Error500(err)
- }
+ var status int
+ err = rows.Scan(&status)
+ if err != nil {
+ return c.Error500(err)
+ }
- if status != 3 {
- c.Logger.Warn("Definition not on status 3(training)", "status", status)
- // No need to improve message because this function is only called internaly
- return c.UnsafeErrorCode(nil, 400, nil)
- }
+ if status != 3 {
+ c.Logger.Warn("Definition not on status 3(training)", "status", status)
+ // No need to improve message because this function is only called internaly
+ return c.UnsafeErrorCode(nil, 400, nil)
+ }
- _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
- if err != nil {
- return c.Error500(err)
- }
+ _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
+ if err != nil {
+ return c.Error500(err)
+ }
return nil
})
}
diff --git a/sql/models.sql b/sql/models.sql
index 943923e..ba72479 100644
--- a/sql/models.sql
+++ b/sql/models.sql
@@ -40,7 +40,6 @@ create table if not exists model_data_point (
status_message text
);
--- drop table if exists model_definition;
-- drop table if exists model_definition;
create table if not exists model_definition (
id uuid primary key default gen_random_uuid(),
diff --git a/views/models/edit.html b/views/models/edit.html
index d180e6d..fbc5f5a 100644
--- a/views/models/edit.html
+++ b/views/models/edit.html
@@ -434,19 +434,36 @@
{{/* TODO improve this */}}
Training the model...
{{/* TODO Add progress status on definitions */}}
- {{ range .Defs}}
-
+ Status + | ++ EpochProgress + | ++ Accuracy + | +
---|---|---|
+ {{.Status}} + | ++ {{.EpochProgress}} + | ++ {{.Accuracy}} + | +