feat: closes #40

This commit is contained in:
Andre Henriques 2023-10-19 10:44:13 +01:00
parent f163e25fba
commit 2c3539b81a
4 changed files with 184 additions and 105 deletions

View File

@ -17,6 +17,9 @@ import (
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
) )
const EPOCH_PER_RUN = 20;
const MAX_EPOCH = 100
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = "" id = ""
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy) rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
@ -34,12 +37,13 @@ func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string
type ModelDefinitionStatus int type ModelDefinitionStatus int
const ( const (
MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3 MODEL_DEFINITION_STATUS_CANCELD_TRAINING = -4
MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1 MODEL_DEFINITION_STATUS_FAILED_TRAINING = -3
MODEL_DEFINITION_STATUS_INIT = 2 MODEL_DEFINITION_STATUS_PRE_INIT ModelDefinitionStatus = 1
MODEL_DEFINITION_STATUS_TRAINING = 3 MODEL_DEFINITION_STATUS_INIT = 2
MODEL_DEFINITION_STATUS_TRANIED = 4 MODEL_DEFINITION_STATUS_TRAINING = 3
MODEL_DEFINITION_STATUS_READY = 5 MODEL_DEFINITION_STATUS_TRANIED = 4
MODEL_DEFINITION_STATUS_READY = 5
) )
type LayerType int type LayerType int
@ -104,7 +108,8 @@ func generateCvs(c *Context, run_path string, model_id string) (count int, err e
return return
} }
func trainDefinition(c *Context, model *BaseModel, definition_id string) (accuracy float64, err error) { func trainDefinition(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
c.Logger.Warn("About to start training definition")
accuracy = 0 accuracy = 0
layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) layers, err := c.Db.Query("select layer_type, shape from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
@ -153,14 +158,20 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return return
} }
// Copy result around
result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{ if err = tmpl.Execute(f, AnyMap{
"Layers": got, "Layers": got,
"Size": got[0].Shape, "Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"), "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"RunPath": run_path, "RunPath": run_path,
"ColorMode": model.ImageMode, "ColorMode": model.ImageMode,
"Model": model, "Model": model,
"DefId": definition_id, "EPOCH_PER_RUN": EPOCH_PER_RUN,
"DefId": definition_id,
"LoadPrev": load_prev,
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
}); err != nil { }); err != nil {
return return
} }
@ -168,13 +179,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
// Run the command // Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output() out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).Output()
if err != nil { if err != nil {
c.Logger.Debug(string(out)) c.Logger.Debug(string(out))
return return
} }
// Copy result around
result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = os.MkdirAll(result_path, os.ModePerm); err != nil { if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return return
} }
@ -183,6 +191,10 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return return
} }
if err = exec.Command("cp", "-r", path.Join(run_path, "model.keras"), path.Join(result_path, "model.keras")).Run(); err != nil {
return
}
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val")) accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
if err != nil { if err != nil {
return return
@ -194,7 +206,7 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return return
} }
fmt.Println(string(accuracy_file_bytes)) c.Logger.Info("Model finished training!", "accuracy", accuracy)
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64) accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
if err != nil { if err != nil {
@ -205,8 +217,25 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string) (accura
return return
} }
func remove[T interface{}](lst []T, i int) []T {
lng := len(lst)
if i >= lng {
return []T{}
}
if i+1 >= lng {
return lst[:lng-1]
}
if i == 0 {
return lst[1:]
}
return append(lst[:i], lst[i+1:]...)
}
func trainModel(c *Context, model *BaseModel) { func trainModel(c *Context, model *BaseModel) {
definitionsRows, err := c.Db.Query("select id, target_accuracy from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id) definitionsRows, err := c.Db.Query("select id, target_accuracy, epoch from model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_INIT, model.Id)
if err != nil { if err != nil {
c.Logger.Error("Failed to trainModel!Err:") c.Logger.Error("Failed to trainModel!Err:")
c.Logger.Error(err) c.Logger.Error(err)
@ -218,13 +247,14 @@ func trainModel(c *Context, model *BaseModel) {
type row struct { type row struct {
id string id string
target_accuracy int target_accuracy int
epoch int
} }
definitions := []row{} definitions := []row{}
for definitionsRows.Next() { for definitionsRows.Next() {
var rowv row var rowv row
if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy); err != nil { if err = definitionsRows.Scan(&rowv.id, &rowv.target_accuracy, &rowv.epoch); err != nil {
c.Logger.Error("Failed to train Model Could not read definition from db!Err:") c.Logger.Error("Failed to train Model Could not read definition from db!Err:")
c.Logger.Error(err) c.Logger.Error(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING) ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
@ -239,30 +269,58 @@ func trainModel(c *Context, model *BaseModel) {
return return
} }
for _, def := range definitions { toTrain := len(definitions)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING) firstRound := true
accuracy, err := trainDefinition(c, model, def.id) var newDefinitions = []row{}
if err != nil { copy(newDefinitions, definitions)
c.Logger.Error("Failed to train definition!Err:") for {
c.Logger.Error(err) for i, def := range definitions {
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING) ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_TRAINING)
continue accuracy, err := trainDefinition(c, model, def.id, !firstRound)
if err != nil {
c.Logger.Error("Failed to train definition!Err:", "err", err)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toTrain = toTrain - 1
newDefinitions = remove(newDefinitions, i)
continue
}
def.epoch += EPOCH_PER_RUN
int_accuracy := int(accuracy * 100)
if int_accuracy >= def.target_accuracy {
c.Logger.Info("Found a definition that reaches target_accuracy!")
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2, epoch=$3 where id=$4", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.epoch, def.id)
if err != nil {
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
_, err = c.Db.Exec("update model_definition set status=$1 where id!=$2 and model_id=$3 and status!=$4", MODEL_DEFINITION_STATUS_CANCELD_TRAINING, def.id, model.Id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
if err != nil {
c.Logger.Error("Failed to train definition!Err:\n", "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
}
toTrain = 0
break
}
if def.epoch > MAX_EPOCH {
fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
toTrain = toTrain - 1
newDefinitions = remove(newDefinitions, i)
continue
}
} }
copy(definitions, newDefinitions)
int_accuracy := int(accuracy * 100) firstRound = false
if toTrain == 0 {
if int_accuracy < def.target_accuracy { break
fmt.Printf("Failed to train definition! Accuracy less %d < %d\n", int_accuracy, def.target_accuracy)
ModelDefinitionUpdateStatus(c, def.id, MODEL_DEFINITION_STATUS_FAILED_TRAINING)
continue
}
_, err = c.Db.Exec("update model_definition set accuracy=$1, status=$2 where id=$3", int_accuracy, MODEL_DEFINITION_STATUS_TRANIED, def.id)
if err != nil {
fmt.Printf("Failed to train definition!Err:\n")
fmt.Println(err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
return
} }
} }
@ -335,7 +393,7 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
} }
defer rows.Close() defer rows.Close()
base_path := path.Join("savedData", model.Id, "data") base_path := path.Join("savedData", model.Id, "data")
for rows.Next() { for rows.Next() {
var dataPointId string var dataPointId string
@ -343,13 +401,13 @@ func removeFailedDataPoints(db *sql.DB, model *BaseModel) (err error) {
if err != nil { if err != nil {
return return
} }
err = os.RemoveAll(path.Join(base_path, dataPointId + model.Format)) err = os.RemoveAll(path.Join(base_path, dataPointId+model.Format))
if err != nil { if err != nil {
return return
} }
} }
_, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id) _, err = db.Exec("delete from model_data_point as mdp using model_classes as mc where mdp.class_id = mc.id and mc.model_id=$1 and mdp.status=-1;", model.Id)
return return
} }
@ -484,56 +542,56 @@ func handleTrain(handle *Handle) {
}) })
handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { handle.Get("/model/epoch/update", func(w http.ResponseWriter, r *http.Request, c *Context) *Error {
// TODO check auth level // TODO check auth level
if c.Mode != NORMAL { if c.Mode != NORMAL {
// This should only handle normal requests // This should only handle normal requests
c.Logger.Warn("This function only works with normal") c.Logger.Warn("This function only works with normal")
return c.UnsafeErrorCode(nil, 400, nil) return c.UnsafeErrorCode(nil, 400, nil)
} }
f := r.URL.Query() f := r.URL.Query()
if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") { if !CheckId(f, "model_id") || !CheckId(f, "definition") || CheckEmpty(f, "epoch") {
c.Logger.Warn("Invalid: model_id or definition or epoch") c.Logger.Warn("Invalid: model_id or definition or epoch")
return c.UnsafeErrorCode(nil, 400, nil) return c.UnsafeErrorCode(nil, 400, nil)
} }
model_id := f.Get("model_id") model_id := f.Get("model_id")
def_id := f.Get("definition") def_id := f.Get("definition")
epoch, err := strconv.Atoi(f.Get("epoch")) epoch, err := strconv.Atoi(f.Get("epoch"))
if err != nil { if err != nil {
c.Logger.Warn("Epoch is not a number") c.Logger.Warn("Epoch is not a number")
// No need to improve message because this function is only called internaly // No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil) return c.UnsafeErrorCode(nil, 400, nil)
} }
rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id) rows, err := c.Db.Query("select md.status from model_definition as md where md.model_id=$1 and md.id=$2", model_id, def_id)
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
} }
defer rows.Close() defer rows.Close()
if !rows.Next() { if !rows.Next() {
c.Logger.Error("Could not get status of model definition") c.Logger.Error("Could not get status of model definition")
return c.Error500(nil) return c.Error500(nil)
} }
var status int var status int
err = rows.Scan(&status) err = rows.Scan(&status)
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
} }
if status != 3 { if status != 3 {
c.Logger.Warn("Definition not on status 3(training)", "status", status) c.Logger.Warn("Definition not on status 3(training)", "status", status)
// No need to improve message because this function is only called internaly // No need to improve message because this function is only called internaly
return c.UnsafeErrorCode(nil, 400, nil) return c.UnsafeErrorCode(nil, 400, nil)
} }
_, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id) _, err = c.Db.Exec("update model_definition set epoch_progress=$1 where id=$2", epoch, def_id)
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
} }
return nil return nil
}) })
} }

View File

@ -40,7 +40,6 @@ create table if not exists model_data_point (
status_message text status_message text
); );
-- drop table if exists model_definition;
-- drop table if exists model_definition; -- drop table if exists model_definition;
create table if not exists model_definition ( create table if not exists model_definition (
id uuid primary key default gen_random_uuid(), id uuid primary key default gen_random_uuid(),

View File

@ -434,19 +434,36 @@
{{/* TODO improve this */}} {{/* TODO improve this */}}
Training the model...<br/> Training the model...<br/>
{{/* TODO Add progress status on definitions */}} {{/* TODO Add progress status on definitions */}}
{{ range .Defs}} <table>
<div> <thead>
<div> <tr>
{{.Status}} <th>
</div> Status
<div> </th>
{{.EpochProgress}} <th>
</div> EpochProgress
<div> </th>
{{.Accuracy}} <th>
</div> Accuracy
</div> </th>
{{ end }} </tr>
</thead>
<tbody>
{{ range .Defs}}
<tr>
<td>
{{.Status}}
</td>
<td>
{{.EpochProgress}}
</td>
<td>
{{.Accuracy}}
</td>
</tr>
{{ end }}
</tbody>
</table>
{{/* TODO Add ability to stop training */}} {{/* TODO Add ability to stop training */}}
</div> </div>
{{/* Model Ready */}} {{/* Model Ready */}}

View File

@ -93,6 +93,10 @@ val_ds = list_ds.take(val_size)
dataset = prepare_dataset(train_ds) dataset = prepare_dataset(train_ds)
dataset_validation = prepare_dataset(val_ds) dataset_validation = prepare_dataset(val_ds)
{{ if .LoadPrev }}
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
{{ else }}
model = keras.Sequential([ model = keras.Sequential([
{{- range .Layers }} {{- range .Layers }}
{{- if eq .LayerType 1}} {{- if eq .LayerType 1}}
@ -106,13 +110,14 @@ model = keras.Sequential([
{{- end }} {{- end }}
{{- end }} {{- end }}
]) ])
{{ end }}
model.compile( model.compile(
loss=losses.SparseCategoricalCrossentropy(), loss=losses.SparseCategoricalCrossentropy(),
optimizer=tf.keras.optimizers.Adam(), optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy']) metrics=['accuracy'])
his = model.fit(dataset, validation_data= dataset_validation, epochs=50, callbacks=[NotifyServerCallback()]) his = model.fit(dataset, validation_data= dataset_validation, epochs={{.EPOCH_PER_RUN}}, callbacks=[NotifyServerCallback()])
acc = his.history["accuracy"] acc = his.history["accuracy"]
@ -120,6 +125,6 @@ f = open("accuracy.val", "w")
f.write(str(acc[-1])) f.write(str(acc[-1]))
f.close() f.close()
tf.saved_model.save(model, "model")
# model.save("model.keras", save_format="tf") tf.saved_model.save(model, "model")
model.save("model.keras")