added the ability to expand the models

This commit is contained in:
Andre Henriques 2024-04-08 14:17:13 +01:00
parent 274d7d22aa
commit de0b430467
15 changed files with 1086 additions and 197 deletions

View File

@ -60,6 +60,7 @@ func HandleList(handle *Handle) {
max_len := min(11, len(rows)) max_len := min(11, len(rows))
c.ShowMessage = false;
return c.SendJSON(ReturnType{ return c.SendJSON(ReturnType{
ImageList: rows[0:max_len], ImageList: rows[0:max_len],
Page: page, Page: page,

View File

@ -156,7 +156,7 @@ func processZipFileExpand(c *Context, model *BaseModel) {
failed := func(msg string) { failed := func(msg string) {
c.Logger.Error(msg, "err", err) c.Logger.Error(msg, "err", err)
ModelUpdateStatus(c, model.Id, READY_FAILED) ModelUpdateStatus(c, model.Id, READY_ALTERATION_FAILED)
} }
reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip")) reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip"))
@ -202,8 +202,19 @@ func processZipFileExpand(c *Context, model *BaseModel) {
ids := map[string]string{} ids := map[string]string{}
var baseOrder struct {
Order int `db:"class_order"`
}
err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id)
if err != nil {
failed("Failed to get the last class_order")
}
base := baseOrder.Order + 1
for i, name := range training { for i, name := range training {
id, err := model_classes.CreateClass(c.Db, model.Id, i, name) id, err := model_classes.CreateClass(c.Db, model.Id, base + i, name)
if err != nil { if err != nil {
failed(fmt.Sprintf("Failed to create class '%s' on db\n", name)) failed(fmt.Sprintf("Failed to create class '%s' on db\n", name))
return return
@ -416,7 +427,7 @@ func handleDataUpload(handle *Handle) {
delete_path := "base_data.zip" delete_path := "base_data.zip"
if model.Status == READY_FAILED { if model.Status == READY_ALTERATION_FAILED {
delete_path = "expand_data.zip" delete_path = "expand_data.zip"
} else if model.Status != FAILED_PREPARING_ZIP_FILE { } else if model.Status != FAILED_PREPARING_ZIP_FILE {
return c.JsonBadRequest("Model not in the correct status") return c.JsonBadRequest("Model not in the correct status")
@ -427,7 +438,7 @@ func handleDataUpload(handle *Handle) {
return c.Error500(err) return c.Error500(err)
} }
if model.Status != READY_FAILED { if model.Status != READY_ALTERATION_FAILED {
err = os.RemoveAll(path.Join("savedData", model.Id, "data")) err = os.RemoveAll(path.Join("savedData", model.Id, "data"))
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
@ -436,7 +447,7 @@ func handleDataUpload(handle *Handle) {
c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding") c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding")
} }
if model.Status != READY_FAILED { if model.Status != READY_ALTERATION_FAILED {
_, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id) _, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id)
if err != nil { if err != nil {
return c.Error500(err) return c.Error500(err)
@ -448,7 +459,7 @@ func handleDataUpload(handle *Handle) {
} }
} }
if model.Status != READY_FAILED { if model.Status != READY_ALTERATION_FAILED {
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING) ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
} else { } else {
ModelUpdateStatus(c, model.Id, READY) ModelUpdateStatus(c, model.Id, READY)

View File

@ -29,7 +29,7 @@ func deleteModelJSON(c *Context, id string) *Error {
func handleDelete(handle *Handle) { func handleDelete(handle *Handle) {
handle.Delete("/models/delete", func(c *Context) *Error { handle.Delete("/models/delete", func(c *Context) *Error {
if c.CheckAuthLevel(1) { if !c.CheckAuthLevel(1) {
return nil return nil
} }
var dat struct { var dat struct {
@ -66,6 +66,10 @@ func handleDelete(handle *Handle) {
case READY: case READY:
fallthrough fallthrough
case READY_RETRAIN_FAILED:
fallthrough
case READY_ALTERATION_FAILED:
fallthrough
case CONFIRM_PRE_TRAINING: case CONFIRM_PRE_TRAINING:
if dat.Name == nil { if dat.Name == nil {
return c.JsonBadRequest("Provided name does not match the model name") return c.JsonBadRequest("Provided name does not match the model name")

View File

@ -41,6 +41,7 @@ func handleEdit(handle *Handle) {
NumberOfInvalidImages int `json:"number_of_invalid_images"` NumberOfInvalidImages int `json:"number_of_invalid_images"`
} }
c.ShowMessage = false;
return c.SendJSON(ReturnType{ return c.SendJSON(ReturnType{
Classes: cls, Classes: cls,
HasData: has_data, HasData: has_data,
@ -180,6 +181,7 @@ func handleEdit(handle *Handle) {
} }
} }
c.ShowMessage = false;
return c.SendJSON(defsToReturn) return c.SendJSON(defsToReturn)
}) })
@ -188,10 +190,6 @@ func handleEdit(handle *Handle) {
return nil return nil
} }
if !c.CheckAuthLevel(1) {
return nil
}
id, err := GetIdFromUrl(c, "id") id, err := GetIdFromUrl(c, "id")
if err != nil { if err != nil {
return c.JsonBadRequest("Model not found") return c.JsonBadRequest("Model not found")
@ -216,6 +214,7 @@ func handleEdit(handle *Handle) {
return c.Error500(err) return c.Error500(err)
} }
c.ShowMessage = false
return c.SendJSON(model) return c.SendJSON(model)
}) })
} }

View File

@ -87,6 +87,8 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
return return
} }
c.Logger.Info("test", "count", len(heads))
var vmax float32 = 0.0 var vmax float32 = 0.0
for _, element := range heads { for _, element := range heads {
@ -95,12 +97,14 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
results := head_model.Exec([]tf.Output{ results := head_model.Exec([]tf.Output{
head_model.Op("StatefulPartitionedCall", 0), head_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{ }, map[tf.Output]*tf.Tensor{
head_model.Op("serving_default_input_2", 0): base_results[0], head_model.Op("serving_default_head_input", 0): base_results[0],
}) })
var predictions = results[0].Value().([][]float32)[0] var predictions = results[0].Value().([][]float32)[0]
for i, v := range predictions { for i, v := range predictions {
c.Logger.Info("predictions", "class", i, "preds", v)
if v > vmax { if v > vmax {
order = element.Range_start + i order = element.Range_start + i
vmax = v vmax = v
@ -111,7 +115,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
// TODO runthe head model // TODO runthe head model
confidence = vmax confidence = vmax
c.Logger.Info("Got", "heads", len(heads)) c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", vmax)
return return
} }
@ -155,7 +159,7 @@ func handleRun(handle *Handle) {
return c.Error500(err) return c.Error500(err)
} }
if model.Status != READY { if model.Status != READY && model.Status != READY_RETRAIN && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model not ready to run images") return c.JsonBadRequest("Model not ready to run images")
} }

View File

@ -36,13 +36,16 @@ func getDir() string {
return dir return dir
} }
// This function creates a new model_definition
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = "" id = ""
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy) rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
if err != nil { if err != nil {
return return
} }
defer rows.Close() defer rows.Close()
if !rows.Next() { if !rows.Next() {
return id, errors.New("Something wrong!") return id, errors.New("Something wrong!")
} }
@ -72,17 +75,14 @@ func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) { func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id) var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1;", model_id)
if err != nil { if err != nil {
return return
} }
defer classes.Close() count = co.Count
if !classes.Next() {
return
}
if err = classes.Scan(&count); err != nil {
return
}
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING) data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
if err != nil { if err != nil {
@ -121,19 +121,14 @@ func setModelClassStatus(c *Context, status ModelClassStatus, filter string, arg
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) { func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING) var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil { if err != nil {
return return
} }
defer classes.Close() count = co.Count
if !classes.Next() {
return
}
if err = classes.Scan(&count); err != nil {
return
}
if count == 0 { if count == 0 {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN) err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
@ -214,7 +209,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
if err != nil { if err != nil {
return return
} }
defer removeAll(run_path, err)
classCount, err := generateCvs(c, run_path, model.Id) classCount, err := generateCvs(c, run_path, model.Id)
if err != nil { if err != nil {
@ -283,55 +277,125 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
return return
} }
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy) c.Logger.Info("Model finished training!", "accuracy", accuracy)
return return
} }
func removeAll(path string, err error) { func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
if err != nil {
os.RemoveAll(path) var co struct {
Count int `db:"count(*)"`
} }
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
c.Logger.Info("test here", "count", co)
count_re = co.Count
count := co.Count
if count == 0 {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
if err != nil {
return
} else if doPanic {
return 0, errors.New("No model classes available")
}
return generateCvsExpandExp(c, run_path, model_id, offset, true)
}
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
defer data.Close()
f, err := os.Create(path.Join(run_path, "train.csv"))
if err != nil {
return
}
defer f.Close()
f.Write([]byte("Id,Index\n"))
count = 0
for data.Next() {
var id string
var class_order int
var file_path string
if err = data.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(class_order-offset) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
count += 1
}
//
// This is to load some extra data so that the model has more things to train on
//
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count)
if err != nil {
return
}
defer data_other.Close()
for data_other.Next() {
var id string
var class_order int
var file_path string
if err = data_other.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(-1) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
}
return
} }
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) { func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0 accuracy = 0
c.Logger.Warn("About to start training definition") c.Logger.Warn("About to retrain model")
// Get untrained models heads // Get untrained models heads
// Status = 2 (INIT) type ExpHead struct {
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id) Id string
Start int `db:"range_start"`
End int `db:"range_end"`
}
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil { if err != nil {
return return
} } else if len(heads) == 0 {
defer rows.Close()
type ExpHead struct {
id string
start int
end int
}
exp := ExpHead{}
if rows.Next() {
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
return
}
} else {
log.Error("Failed to get the exp head of the model") log.Error("Failed to get the exp head of the model")
err = errors.New("Failed to get the exp head of the model")
return return
} } else if len(heads) != 1 {
if rows.Next() {
log.Error("This training function can only train one model at the time") log.Error("This training function can only train one model at the time")
err = errors.New("This training function can only train one model at the time") err = errors.New("This training function can only train one model at the time")
return return
} }
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING) exp := heads[0]
c.Logger.Info("Got exp head", "head", exp)
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return
}
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id) layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil { if err != nil {
@ -348,8 +412,178 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got := []layerrow{} got := []layerrow{}
remove_top_count := 1 i := 1
var last *layerrow = nil
got_2 := false
var first *layerrow = nil
for layers.Next() {
var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
return
}
// Keep track of the first layer so we can keep the size of the image
if first == nil {
first = &row
}
row.LayerNum = i
row.Shape = shapeToSize(row.Shape)
if row.ExpType == 2 {
if !got_2 {
got = append(got, *last)
got_2 = true
}
got = append(got, row)
}
last = &row
i += 1
}
got = append(got, layerrow{
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2,
LayerNum: i,
})
c.Logger.Info("Got layers", "layers", got)
// Generate run folder
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
classCount, err := generateCvsExpandExp(c, run_path, model.Id, exp.Start, false)
if err != nil {
return
}
c.Logger.Info("Generated cvs", "classCount", classCount)
// TODO update the run script
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
if err != nil {
return
}
defer f.Close()
c.Logger.Info("About to run python!")
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
if err != nil {
return
}
// Copy result around
result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{
"Layers": got,
"Size": first.Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.Id,
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN,
"LoadPrev": load_prev,
"BaseModel": path.Join(getDir(), result_path, "base", "model.keras"),
"LastModelRunPath": path.Join(getDir(), result_path, "head", exp.Id, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
"Depth": classCount,
"StartPoint": 0,
}); err != nil {
return
}
// Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
if err != nil {
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
return
}
c.Logger.Info("Python finished running")
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return
}
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
if err != nil {
return
}
defer accuracy_file.Close()
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
if err != nil {
return
}
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
if err != nil {
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0
c.Logger.Warn("About to start training definition")
// Get untrained models heads
type ExpHead struct {
Id string
Start int `db:"range_start"`
End int `db:"range_end"`
}
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil {
return
} else if len(heads) == 0 {
log.Error("Failed to get the exp head of the model")
return
} else if len(heads) != 1 {
log.Error("This training function can only train one model at the time")
err = errors.New("This training function can only train one model at the time")
return
}
exp := heads[0]
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return
}
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
return
}
defer layers.Close()
type layerrow struct {
LayerType int
Shape string
ExpType int
LayerNum int
}
got := []layerrow{}
i := 1 i := 1
for layers.Next() { for layers.Next() {
@ -358,9 +592,6 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return return
} }
row.LayerNum = i row.LayerNum = i
if row.ExpType == 2 {
remove_top_count += 1
}
row.Shape = shapeToSize(row.Shape) row.Shape = shapeToSize(row.Shape)
got = append(got, row) got = append(got, row)
i += 1 i += 1
@ -368,19 +599,18 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got = append(got, layerrow{ got = append(got, layerrow{
LayerType: LAYER_DENSE, LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.end-exp.start+1), Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2, ExpType: 2,
LayerNum: i, LayerNum: i,
}) })
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", definition_id) run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
err = os.MkdirAll(run_path, os.ModePerm) err = os.MkdirAll(run_path, os.ModePerm)
if err != nil { if err != nil {
return return
} }
defer removeAll(run_path, err)
classCount, err := generateCvsExp(c, run_path, model.Id, false) classCount, err := generateCvsExp(c, run_path, model.Id, false)
if err != nil { if err != nil {
@ -408,16 +638,14 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
"Layers": got, "Layers": got,
"Size": got[0].Shape, "Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"), "DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.id, "HeadId": exp.Id,
"RunPath": run_path, "RunPath": run_path,
"ColorMode": model.ImageMode, "ColorMode": model.ImageMode,
"Model": model, "Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN, "EPOCH_PER_RUN": EPOCH_PER_RUN,
"DefId": definition_id,
"LoadPrev": load_prev, "LoadPrev": load_prev,
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"), "LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path), "SaveModelPath": path.Join(getDir(), result_path),
"RemoveTopCount": remove_top_count,
"Depth": classCount, "Depth": classCount,
"StartPoint": 0, "StartPoint": 0,
}); err != nil { }); err != nil {
@ -453,6 +681,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return return
} }
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy) c.Logger.Info("Model finished training!", "accuracy", accuracy)
return return
} }
@ -762,6 +991,12 @@ func trainModelExp(c *Context, model *BaseModel) {
return return
} }
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2;", MODEL_HEAD_STATUS_READY, def.Id)
if err != nil {
failed("Failed to train definition!")
return
}
finished = true finished = true
break break
} }
@ -823,17 +1058,17 @@ func trainModelExp(c *Context, model *BaseModel) {
} }
} }
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
var dat JustId var dat JustId
err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED) err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
if err == NotFoundError { if err == NotFoundError {
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("All definitions failed to train! And Failed to set class status")
return
}
failed("All definitions failed to train!") failed("All definitions failed to train!")
return return
} else if err != nil { } else if err != nil {
@ -863,10 +1098,23 @@ func trainModelExp(c *Context, model *BaseModel) {
} }
if err = splitModel(c, model); err != nil { if err = splitModel(c, model); err != nil {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to split the model! And Failed to set class status")
return
}
failed("Failed to split the model") failed("Failed to split the model")
return return
} }
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
// There should only be one def availabale // There should only be one def availabale
def := JustId{} def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
@ -884,25 +1132,22 @@ func trainModelExp(c *Context, model *BaseModel) {
func splitModel(c *Context, model *BaseModel) (err error) { func splitModel(c *Context, model *BaseModel) (err error) {
def := JustId{} def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil { if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
return return
} }
head := JustId{} head := JustId{}
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil { if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
return return
} }
// Generate run folder // Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", def.Id) run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-split")
err = os.MkdirAll(run_path, os.ModePerm) err = os.MkdirAll(run_path, os.ModePerm)
if err != nil { if err != nil {
return return
} }
defer removeAll(run_path, err)
// Create python script // Create python script
f, err := os.Create(path.Join(run_path, "run.py")) f, err := os.Create(path.Join(run_path, "run.py"))
@ -970,8 +1215,8 @@ func splitModel(c *Context, model *BaseModel) (err error) {
return return
} }
os.RemoveAll(run_path)
c.Logger.Info("Python finished running") c.Logger.Info("Python finished running")
return return
} }
@ -1141,6 +1386,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
} }
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) { func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status) rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
if err != nil { if err != nil {
@ -1306,6 +1552,244 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy
return nil return nil
} }
// ResetClasses moves every class of the given model that is currently in the
// TRAINING state back to TO_TRAIN, so an aborted or failed training run can
// be retried from a clean slate.
//
// The error is only logged (not returned) because this is called from failure
// paths that cannot do anything more with a second error.
func ResetClasses(c *Context, model *BaseModel) {
	_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
	if err != nil {
		c.Logger.Error("Error while resetting the classes", "error", err)
	}
}
// trainExpandable trains the single READY definition of an expandable model
// until it reaches its accuracy target. On any failure the model is flagged
// FAILED_TRAINING and its in-progress classes are reset to TO_TRAIN.
func trainExpandable(c *Context, model *BaseModel) {
	var err error = nil

	failed := func(msg string) {
		c.Logger.Error(msg, "err", err)
		ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
		ResetClasses(c, model)
	}

	var definitions TrainModelRowUsables

	definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
	if err != nil {
		failed("Failed to get definitions")
		return
	}
	if len(definitions) != 1 {
		failed("There should only be one definition available!")
		return
	}

	firstRound := true
	def := definitions[0]
	epoch := 0

	for {
		// Assign with `=` so the failure handler above sees this error;
		// `acc, err :=` here would shadow the outer err inside the loop scope
		// and the handler would always log nil.
		var acc float64
		acc, err = trainDefinitionExp(c, model, def.Id, !firstRound)
		if err != nil {
			failed("Failed to train definition!")
			return
		}
		// Flip after the first pass so subsequent rounds load the previously
		// saved weights instead of training from scratch every iteration.
		firstRound = false

		epoch += EPOCH_PER_RUN

		if float64(acc*100) >= float64(def.Acuracy) {
			c.Logger.Info("Found a definition that reaches target_accuracy!")
			_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
			if err != nil {
				failed("Failed to train definition!")
				return
			}
			break
		} else if def.Epoch > MAX_EPOCH {
			// NOTE(review): def.Epoch is never updated inside this loop; if it
			// starts <= MAX_EPOCH and the target is never reached this loops
			// forever — confirm whether this should compare the local `epoch`.
			failed(fmt.Sprintf("Failed to train definition! Accuracy less %f < %d\n", acc*100, def.TargetAccuracy))
			return
		}
	}

	// Set the class status to trained
	err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
	if err != nil {
		failed("Failed to set class status")
		return
	}

	ModelUpdateStatus(c, model.Id, READY)
}
// trainRetrain runs one retraining pass over the expanded head of the given
// definition, then marks the model READY and promotes its TRAINING classes to
// TRAINED. On any failure the classes are reset and the model is flagged
// READY_RETRAIN_FAILED. Intended to run in its own goroutine.
func trainRetrain(c *Context, model *BaseModel, defId string) {
	var err error
	failed := func() {
		ResetClasses(c, model)
		ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
		c.Logger.Error("Failed to retrain", "err", err)
	}

	// This is something I have to check
	acc, err := trainDefinitionExpandExp(c, model, defId, false)
	if err != nil {
		// failed() logs err; no need to log it twice here.
		failed()
		return
	}

	c.Logger.Info("Retrained model", "accuracy", acc)

	// TODO check accuracy against the definition's target accuracy

	err = UpdateStatus(c, "models", model.Id, READY)
	if err != nil {
		failed()
		return
	}

	c.Logger.Info("model updated")

	// Promote the classes that were just trained.
	_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
	if err != nil {
		c.Logger.Error("Error while updating the classes", "error", err)
		failed()
		return
	}
}
// handleRetrain starts a retraining run for a model whose class set was
// expanded. It validates the request, moves the TO_TRAIN classes to TRAINING
// inside a transaction, creates a new expandable head covering those classes,
// marks the model READY_RETRAIN and launches the training goroutine.
// Returns the model id as JSON on success.
func handleRetrain(c *Context) *Error {
	var err error = nil
	if !c.CheckAuthLevel(1) {
		return nil
	}

	var dat JustId
	if err_ := c.ToJSON(&dat); err_ != nil {
		return err_
	}

	if dat.Id == "" {
		return c.JsonBadRequest("Please provide a id")
	}

	model, err := GetBaseModel(c.Db, dat.Id)
	if err == ModelNotFoundError {
		return c.JsonBadRequest("Model not found")
	} else if err != nil {
		return c.Error500(err)
	} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
		return c.JsonBadRequest("Model in invalid status for re-training")
	}

	c.Logger.Info("Expanding definitions for models", "id", model.Id)

	classesUpdated := false
	// Shared failure path: undo the class-status change (if it happened),
	// flag the model and answer 500.
	failed := func() *Error {
		if classesUpdated {
			ResetClasses(c, model)
		}
		ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
		c.Logger.Error("Failed to retrain", "err", err)
		// TODO improve this response
		return c.Error500(err)
	}
	// rollback aborts the open transaction, then runs the shared failure path.
	rollback := func() *Error {
		if _err := c.RollbackTx(); _err != nil {
			c.Logger.Error("Two errors happened rollback failed", "err", _err)
		}
		return failed()
	}

	var def struct {
		Id             string
		TargetAccuracy int `db:"target_accuracy"`
	}

	err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
	if err != nil {
		return failed()
	}

	type C struct {
		Id         string
		ClassOrder int `db:"class_order"`
	}

	err = c.StartTx()
	if err != nil {
		return failed()
	}

	classes, err := GetDbMultitple[C](
		c,
		"model_classes where model_id=$1 and status=$2 order by class_order asc",
		model.Id,
		MODEL_CLASS_STATUS_TO_TRAIN,
	)
	if err != nil {
		return rollback()
	}
	if len(classes) == 0 {
		c.Logger.Error("No classes are available!")
		// Give the failure path a real error so the 500 is not built from nil.
		err = errors.New("no classes are available")
		return rollback()
	}

	// Update the classes: TO_TRAIN -> TRAINING, committed before training starts.
	{
		stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
		err = err2
		if err != nil {
			return rollback()
		}
		defer stmt.Close()

		_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
		if err != nil {
			return rollback()
		}

		err = c.CommitTx()
		if err != nil {
			return rollback()
		}
		classesUpdated = true
	}

	// The new head covers the class_order range of the classes being added.
	_, err = CreateExpModelHead(c, def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT)
	if err != nil {
		return failed()
	}

	// Mark the model as retraining BEFORE launching the goroutine; doing it
	// afterwards races with trainRetrain, whose final status write could be
	// clobbered by this update.
	_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
	if err != nil {
		return failed()
	}

	go trainRetrain(c, model, def.Id)

	return c.SendJSON(model.Id)
}
func handleTrain(handle *Handle) { func handleTrain(handle *Handle) {
handle.Post("/models/train", func(c *Context) *Error { handle.Post("/models/train", func(c *Context) *Error {
if !c.CheckAuthLevel(1) { if !c.CheckAuthLevel(1) {
@ -1374,6 +1858,8 @@ func handleTrain(handle *Handle) {
return c.SendJSON(model.Id) return c.SendJSON(model.Id)
}) })
handle.Post("/model/train/retrain", handleRetrain)
handle.Get("/model/epoch/update", func(c *Context) *Error { handle.Get("/model/epoch/update", func(c *Context) *Error {
f := c.R.URL.Query() f := c.R.URL.Query()

View File

@ -29,7 +29,10 @@ const (
TRAINING = 4 TRAINING = 4
READY = 5 READY = 5
READY_ALTERATION = 6 READY_ALTERATION = 6
READY_FAILED = -6 READY_ALTERATION_FAILED = -6
READY_RETRAIN = 7
READY_RETRAIN_FAILED = -7
) )
type ModelDefinitionStatus int type ModelDefinitionStatus int
@ -62,6 +65,16 @@ const (
MODEL_CLASS_STATUS_TRAINED = 3 MODEL_CLASS_STATUS_TRAINED = 3
) )
// ModelHeadStatus describes the lifecycle state of an expandable model "head"
// (the classifier trained on top of an existing base model).
type ModelHeadStatus int

// Head lifecycle values. NOTE(review): only the first constant is typed
// ModelHeadStatus; the rest are untyped ints. They are interchangeable in
// practice, but consider typing them all (or using iota) for consistency.
const (
	MODEL_HEAD_STATUS_PRE_INIT ModelHeadStatus = 1
	MODEL_HEAD_STATUS_INIT = 2
	MODEL_HEAD_STATUS_TRAINING = 3
	MODEL_HEAD_STATUS_TRAINED = 4
	MODEL_HEAD_STATUS_READY = 5
)
var ModelNotFoundError = errors.New("Model not found error") var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {

View File

@ -89,6 +89,7 @@ func (x *Handle) handleGets(context *Context) {
return return
} }
} }
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context) handleError(&Error{404, "Endpoint not found"}, context)
} }
@ -99,6 +100,7 @@ func (x *Handle) handlePosts(context *Context) {
return return
} }
} }
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context) handleError(&Error{404, "Endpoint not found"}, context)
} }
@ -109,6 +111,7 @@ func (x *Handle) handleDeletes(context *Context) {
return return
} }
} }
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context) handleError(&Error{404, "Endpoint not found"}, context)
} }
@ -133,6 +136,52 @@ type Context struct {
Db *sql.DB Db *sql.DB
Writer http.ResponseWriter Writer http.ResponseWriter
R *http.Request R *http.Request
Tx *sql.Tx
ShowMessage bool
}
// Prepare compiles str into a SQL statement. When a transaction has been
// opened with StartTx the statement is prepared on that transaction,
// otherwise on the plain database handle.
func (c Context) Prepare(str string) (*sql.Stmt, error) {
	if c.Tx != nil {
		return c.Tx.Prepare(str)
	}
	return c.Db.Prepare(str)
}
// TransactionAlreadyStarted is returned by StartTx when the context already
// holds an open transaction.
var TransactionAlreadyStarted = errors.New("Transaction already started")

// TransactionNotStarted is returned by CommitTx and RollbackTx when no
// transaction is currently open on the context.
var TransactionNotStarted = errors.New("Transaction not started")
// StartTx opens a database transaction on the context. It fails with
// TransactionAlreadyStarted if a previous transaction is still open;
// otherwise it stores the new transaction in c.Tx and returns Begin's error.
func (c *Context) StartTx() error {
	if c.Tx != nil {
		return TransactionAlreadyStarted
	}
	tx, err := c.Db.Begin()
	c.Tx = tx
	return err
}
// CommitTx commits the transaction opened by StartTx and clears it from the
// context so a new one may be started. It returns TransactionNotStarted when
// no transaction is open, or the commit error (leaving c.Tx set) on failure.
func (c *Context) CommitTx() error {
	if c.Tx == nil {
		return TransactionNotStarted
	}
	if err := c.Tx.Commit(); err != nil {
		return err
	}
	c.Tx = nil
	return nil
}
// RollbackTx aborts the transaction opened by StartTx and clears it from the
// context. It returns TransactionNotStarted when no transaction is open, or
// the rollback error (leaving c.Tx set) on failure.
func (c *Context) RollbackTx() error {
	if c.Tx == nil {
		return TransactionNotStarted
	}
	err := c.Tx.Rollback()
	if err != nil {
		return err
	}
	c.Tx = nil
	return nil
}
func (c Context) ToJSON(dat any) *Error { func (c Context) ToJSON(dat any) *Error {
@ -270,6 +319,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
Db: handler.Db, Db: handler.Db,
Writer: w, Writer: w,
R: r, R: r,
ShowMessage: true,
}, nil }, nil
} }
@ -278,7 +328,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
return nil, errors.Join(err, LogoffError) return nil, errors.Join(err, LogoffError)
} }
return &Context{token, user, logger, handler.Db, w, r}, nil return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil
} }
func contextlessLogoff(w http.ResponseWriter) { func contextlessLogoff(w http.ResponseWriter) {
@ -457,20 +507,19 @@ func NewHandler(db *sql.DB) *Handle {
if r.Method == "GET" { if r.Method == "GET" {
x.handleGets(context) x.handleGets(context)
return } else if r.Method == "POST" {
}
if r.Method == "POST" {
x.handlePosts(context) x.handlePosts(context)
return } else if r.Method == "DELETE" {
}
if r.Method == "DELETE" {
x.handleDeletes(context) x.handleDeletes(context)
return } else if r.Method == "OPTIONS" {
} // do nothing
if r.Method == "OPTIONS" { } else {
return
}
panic("TODO handle method: " + r.Method) panic("TODO handle method: " + r.Method)
}
if context.ShowMessage {
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
}
}) })
return x return x

View File

@ -189,6 +189,7 @@ type JustId struct { Id string }
type Generic struct{ reflect.Type } type Generic struct{ reflect.Type }
var NotFoundError = errors.New("Not found") var NotFoundError = errors.New("Not found")
// CouldNotInsert is returned when an insert statement produced no row,
// i.e. the "returning" clause yielded nothing to scan.
var CouldNotInsert = errors.New("Could not insert")
func generateQuery(t reflect.Type) (query string, nargs int) { func generateQuery(t reflect.Type) (query string, nargs int) {
nargs = t.NumField() nargs = t.NumField()
@ -200,6 +201,10 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
if !ok { if !ok {
name = field.Name; name = field.Name;
} }
if name == "__nil__" {
continue
}
query += strings.ToLower(name) + "," query += strings.ToLower(name) + ","
} }
@ -214,7 +219,13 @@ func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([
query, nargs := generateQuery(t) query, nargs := generateQuery(t)
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -251,6 +262,43 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
return nil return nil
} }
// InsertReturnId inserts the struct pointed to by store into tablename and
// returns the value of the returnName column for the newly created row
// (via PostgreSQL's "returning" clause).
//
// Column names come from generateQuery (same `db` tags as the read helpers),
// and the struct's field values become the statement arguments in field
// declaration order. Returns CouldNotInsert when no row comes back.
//
// NOTE(review): this queries c.Db directly, so it bypasses a transaction
// opened with c.StartTx — confirm whether it should go through c.Prepare
// like GetDbMultitple does.
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
	t := reflect.TypeOf(store).Elem()
	query, nargs := generateQuery(t)

	if nargs == 0 {
		// Nothing to insert; also prevents the slice-out-of-range below.
		return "", CouldNotInsert
	}

	// Build the "$1,$2,...,$n" placeholder list. PostgreSQL placeholders are
	// 1-based, so start at $1 — the previous code emitted $0, which fails at
	// runtime with "there is no parameter $0".
	query2 := ""
	for i := 1; i <= nargs; i += 1 {
		query2 += fmt.Sprintf("$%d,", i)
	}
	// Drop the trailing comma
	query2 = query2[0 : len(query2)-1]

	// Collect field values in the same order generateQuery emits columns.
	val := reflect.ValueOf(store).Elem()
	scan_args := make([]interface{}, nargs)
	for i := 0; i < nargs; i++ {
		valueField := val.Field(i)
		scan_args[i] = valueField.Interface()
	}

	rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...)
	if err != nil {
		return
	}
	defer rows.Close()

	if !rows.Next() {
		return "", CouldNotInsert
	}

	err = rows.Scan(&id)
	if err != nil {
		return
	}
	return
}
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error { func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem() t := reflect.TypeOf(store).Elem()

View File

@ -0,0 +1,223 @@
import tensorflow as tf
import random
import pandas as pd
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers
import requests
import numpy as np
class NotifyServerCallback(tf.keras.callbacks.Callback):
    # Keras callback that reports per-epoch progress back to the Go server.
    def on_epoch_end(self, epoch, log, *args, **kwargs):
        # epoch is 0-based, the server expects 1-based; {{.HeadId}} is filled
        # in when the Go server renders this template.
        requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
# Template-rendered run configuration.
DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }})

# train.csv pairs image ids ('Id') with integer class indexes ('Index')
# for the classes trained in this run.
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
keys = tf.constant(df['Id'].dropna())
values = tf.constant(list(map(int, df['Index'].dropna())))

# depth: number of classes this head outputs; diff: offset subtracted from
# looked-up indexes before one-hot encoding (see pathToLabel below).
depth = {{ .Depth }}
diff = {{ .StartPoint }}

# Static image-id -> class-index lookup; ids absent from train.csv map to -1.
table = tf.lookup.StaticHashTable(
    initializer=tf.lookup.KeyValueTensorInitializer(
        keys=keys,
        values=values,
    ),
    default_value=tf.constant(-1),
    name="Indexes"
)

DATA_DIR_PREPARE = DATA_DIR + "/"
# based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path):
    """Map an image file path to its one-hot label tensor.

    Strips the data-dir prefix and the image extension so what remains is
    the image id, then looks the id up in `table`. Unknown ids (lookup
    result -1) yield an all-zero vector of length `depth`; known ids yield
    a one-hot vector of the index shifted down by `diff`.
    """
    path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
    # Extension branch is chosen by the Go server at template-render time.
    # NOTE(review): ".png"/".jpeg" are regex patterns where '.' matches any
    # character — consider escaping the dot.
    {{ if eq .Model.Format "png" }}
    path = tf.strings.regex_replace(path, ".png", "")
    {{ else if eq .Model.Format "jpeg" }}
    path = tf.strings.regex_replace(path, ".jpeg", "")
    {{ else }}
    ERROR
    {{ end }}
    num = table.lookup(tf.strings.as_string([path]))
    # -1 is the table's default value, i.e. the id was not found.
    return tf.cond(
        tf.math.equal(num, tf.constant(-1)),
        lambda: tf.zeros([depth]),
        lambda: tf.one_hot(table.lookup(tf.strings.as_string([path])) - diff, depth)[0]
    )
# Previously trained model rendered in by the server; used further down to
# pre-compute the features the new head trains on.
old_model = keras.models.load_model("{{ .BaseModel }}")
def decode_image(img):
    """Decode raw image bytes and resize to `image_size`.

    The decode branch is selected at template-render time from the model's
    stored format; {{.ColorMode}} fixes the number of channels.
    """
    {{ if eq .Model.Format "png" }}
    img = tf.io.decode_png(img, channels={{.ColorMode}})
    {{ else if eq .Model.Format "jpeg" }}
    img = tf.io.decode_jpeg(img, channels={{.ColorMode}})
    {{ else }}
    ERROR
    {{ end }}
    return tf.image.resize(img, image_size)
def process_path(path):
    """Read the image file at `path` and return (image tensor, one-hot label)."""
    label = pathToLabel(path)
    raw = tf.io.read_file(path)
    return decode_image(raw), label
def configure_for_performance(ds: tf.data.Dataset, size: int, shuffle: bool) -> tf.data.Dataset:
    """Apply the standard input-pipeline tail: optional shuffle of `size`
    elements, batching with the global `batch_size`, and prefetching."""
    # ds = ds.cache()
    if shuffle:
        ds = ds.shuffle(buffer_size=size)
    return ds.batch(batch_size).prefetch(AUTOTUNE)
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
    """Turn a dataset of file paths into batched (image, label) pairs
    (no shuffling; see configure_for_performance)."""
    mapped = ds.map(process_path, num_parallel_calls=AUTOTUNE)
    return configure_for_performance(mapped, size, False)
def filterDataset(path):
    """Predicate: keep only files whose id appears in the lookup table.

    Mirrors the path normalization done in pathToLabel and returns a scalar
    boolean tensor (True when lookup != -1).
    """
    path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
    # Extension branch chosen at template-render time.
    {{ if eq .Model.Format "png" }}
    path = tf.strings.regex_replace(path, ".png", "")
    {{ else if eq .Model.Format "jpeg" }}
    path = tf.strings.regex_replace(path, ".jpeg", "")
    {{ else }}
    ERROR
    {{ end }}
    return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
seed = random.randint(0, 100000000)
batch_size = 64

# Read all the files from the directory
list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
# Drop files that do not belong to the classes trained in this run.
list_ds = list_ds.filter(filterDataset)

image_count = len(list(list_ds.as_numpy_iterator()))
list_ds = list_ds.shuffle(image_count, seed=seed)

# 70/30 train/validation split.
val_size = int(image_count * 0.3)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)

dataset = prepare_dataset(train_ds, image_count)
dataset_validation = prepare_dataset(val_ds, val_size)
# Monotonically increasing counter used only to give each block a unique name.
track = 0

def addBlock(
        b_size: int,
        filter_size: int,
        kernel_size: int = 3,
        top: bool = True,
        pooling_same: bool = False,
        pool_func=layers.MaxPool2D
):
    """Build one convolutional block as a nested keras.Sequential.

    The block is `b_size` repetitions of Conv2D(filter_size, kernel_size,
    padding="same") + ReLU, optionally followed by a pooling layer, then
    BatchNormalization, LeakyReLU and Dropout(0.4).

    pooling_same=True pools with "same" padding and stride (1, 1), keeping
    the spatial size; pool_func selects the pooling layer class.
    """
    global track
    model = keras.Sequential(
        name=f"{track}-{b_size}-{filter_size}-{kernel_size}"
    )
    track += 1

    for _ in range(b_size):
        model.add(layers.Conv2D(
            filter_size,
            kernel_size,
            padding="same"
        ))
        model.add(layers.ReLU())

    if top:
        if pooling_same:
            model.add(pool_func(padding="same", strides=(1, 1)))
        else:
            model.add(pool_func())

    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.4))

    return model
# Process old data: run every training/validation image through the loaded
# old model once, so the new head trains on its outputs instead of raw images.
new_data = old_model.predict(dataset)
labels = np.concatenate([y for _, y in dataset], axis=0)
new_data = tf.data.Dataset.from_tensor_slices(
    (new_data, labels))
new_data = configure_for_performance(new_data, batch_size, True)

new_data_val = old_model.predict(dataset_validation)
labels_val = np.concatenate([y for _, y in dataset_validation], axis=0)
new_data_val = tf.data.Dataset.from_tensor_slices(
    (new_data_val, labels_val))
new_data_val = configure_for_performance(new_data_val, batch_size, True)
{{ if .LoadPrev }}
# Continue from the model saved by the previous run of this head.
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
{{ else }}
# Fresh head: build it layer-by-layer from the templated layer list.
model = keras.Sequential()
{{- range .Layers }}
{{- if eq .LayerType 1}}
model.add(layers.Rescaling(1./255))
{{- else if eq .LayerType 2 }}
model.add(layers.Dense({{ .Shape }}, activation="sigmoid"))
{{- else if eq .LayerType 3}}
model.add(layers.Flatten())
{{- else if eq .LayerType 4}}
model.add(addBlock(2, 128, 3, pool_func=layers.AveragePooling2D))
{{- else }}
ERROR
{{- end }}
{{- end }}
{{ end }}
# NOTE(review): presumably matched against the "head_input" of the split
# base model elsewhere in the pipeline — confirm.
model.layers[0]._name = "head"

model.compile(
    loss=losses.BinaryCrossentropy(from_logits=False),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

his = model.fit(
    new_data,
    validation_data=new_data_val,
    epochs={{.EPOCH_PER_RUN}},
    callbacks=[
        # Report per-epoch accuracy to the Go server.
        NotifyServerCallback(),
        tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)
    ],
    use_multiprocessing=True
)

# Persist the final training accuracy for the server to read back.
acc = his.history["accuracy"]
f = open("accuracy.val", "w")
f.write(str(acc[-1]))
f.close()

# Save both the SavedModel directory and the single-file .keras format.
tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
model.save("{{ .SaveModelPath }}/model.keras")

View File

@ -17,7 +17,7 @@ split_len = {{ .SplitLen }}
bottom_input = Input(model.input_shape[1:]) bottom_input = Input(model.input_shape[1:])
bottom_output = bottom_input bottom_output = bottom_input
top_input = Input(model.layers[split_len + 1].input_shape[1:]) top_input = Input(model.layers[split_len + 1].input_shape[1:], name="head_input")
top_output = top_input top_output = top_input
for i, layer in enumerate(model.layers): for i, layer in enumerate(model.layers):

View File

@ -146,14 +146,7 @@
<!-- TODO improve message --> <!-- TODO improve message -->
<h2 class="text-center">Failed to prepare model</h2> <h2 class="text-center">Failed to prepare model</h2>
<div>TODO button delete</div> <DeleteModel model={m} />
<!--form hx-delete="/models/delete">
<input type="hidden" name="id" value="{{ .Model.Id }}" />
<button class="danger">
Delete
</button>
</form-->
</div> </div>
<!-- PRE TRAINING STATUS --> <!-- PRE TRAINING STATUS -->
{:else if m.status == 2} {:else if m.status == 2}
@ -288,7 +281,7 @@
{/await} {/await}
<!-- TODO Add ability to stop training --> <!-- TODO Add ability to stop training -->
</div> </div>
{:else if [5, 6, -6].includes(m.status)} {:else if [5, 6, -6, 7, -7].includes(m.status)}
<BaseModelInfo model={m} /> <BaseModelInfo model={m} />
<RunModel model={m} /> <RunModel model={m} />
{#if m.status == 6} {#if m.status == 6}
@ -299,6 +292,13 @@
{#if m.status == -6} {#if m.status == -6}
<DeleteZip model={m} on:reload={getModel} expand /> <DeleteZip model={m} on:reload={getModel} expand />
{/if} {/if}
{#if m.status == -7}
<form>
<!-- TODO add more info about the failure -->
Failed to train the model!
Try to retrain
</form>
{/if}
{#if m.model_type == 2} {#if m.model_type == 2}
<ModelData model={m} on:reload={getModel} /> <ModelData model={m} on:reload={getModel} />
{/if} {/if}

View File

@ -142,6 +142,7 @@
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={uploadImage} /> <MessageSimple bind:this={uploadImage} />
{#if file}
{#await uploading} {#await uploading}
<button disabled> <button disabled>
Uploading Uploading
@ -151,6 +152,7 @@
Add Add
</button> </button>
{/await} {/await}
{/if}
</form> </form>
</div> </div>
<div class="content" class:selected={isActive("create-class")}> <div class="content" class:selected={isActive("create-class")}>
@ -190,4 +192,6 @@
{/if} {/if}
</div> </div>
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} /> {#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)}
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
{/if}

View File

@ -106,6 +106,15 @@
class:selected={isActive(item.name)} class:selected={isActive(item.name)}
> >
{item.name} {item.name}
{#if model.model_type == 2}
{#if item.status == 1}
<span class="bi bi-book" style="color: orange;" />
{:else if item.status == 2}
<span class="bi bi-book" style="color: green;" />
{:else if item.status == 3}
<span class="bi bi-check" style="color: green;" />
{/if}
{/if}
</button> </button>
{/each} {/each}
</div> </div>
@ -170,6 +179,7 @@
</FileUpload> </FileUpload>
</fieldset> </fieldset>
<MessageSimple bind:this={uploadImage} /> <MessageSimple bind:this={uploadImage} />
{#if file}
{#await uploading} {#await uploading}
<button disabled> <button disabled>
Uploading Uploading
@ -179,6 +189,7 @@
Add Add
</button> </button>
{/await} {/await}
{/if}
</form> </form>
</div> </div>
{/if} {/if}

View File

@ -23,6 +23,7 @@
let messages: MessageSimple; let messages: MessageSimple;
async function submit() { async function submit() {
messages.clear();
submitted = true; submitted = true;
try { try {
await post('models/train', { await post('models/train', {
@ -34,13 +35,29 @@
if (e instanceof Response) { if (e instanceof Response) {
messages.display(await e.json()); messages.display(await e.json());
} else { } else {
messages.display("Could not start the training of the model"); messages.display('Could not start the training of the model');
}
}
}
async function submitRetrain() {
messages.clear();
submitted = true;
try {
await post('model/train/retrain', { id: model.id });
dispatch('reload');
} catch (e) {
if (e instanceof Response) {
messages.display(await e.json());
} else {
messages.display('Could not start the training of the model');
} }
} }
} }
</script> </script>
<form class:submitted on:submit|preventDefault={submit}> {#if model.status == 2}
<form class:submitted on:submit|preventDefault={submit}>
{#if has_data} {#if has_data}
{#if number_of_invalid_images > 0} {#if number_of_invalid_images > 0}
<p class="danger"> <p class="danger">
@ -92,4 +109,23 @@
{:else} {:else}
<h2>To train the model please provide data to the model first</h2> <h2>To train the model please provide data to the model first</h2>
{/if} {/if}
</form> </form>
{:else}
<form class:submitted on:submit|preventDefault={submitRetrain}>
{#if has_data}
<h2>
This model has new classes and can be expanded
</h2>
{#if number_of_invalid_images > 0}
<p class="danger">
There are {number_of_invalid_images} images that were loaded that do not have the correct format.
These images will be deleted when the model trains.
</p>
{/if}
<MessageSimple bind:this={messages} />
<button> Retrain </button>
{:else}
<h2>To train the model please provide data to the model first</h2>
{/if}
</form>
{/if}