added the ability to expand the models

This commit is contained in:
Andre Henriques 2024-04-08 14:17:13 +01:00
parent 274d7d22aa
commit de0b430467
15 changed files with 1086 additions and 197 deletions

View File

@ -60,6 +60,7 @@ func HandleList(handle *Handle) {
max_len := min(11, len(rows))
c.ShowMessage = false;
return c.SendJSON(ReturnType{
ImageList: rows[0:max_len],
Page: page,

View File

@ -156,7 +156,7 @@ func processZipFileExpand(c *Context, model *BaseModel) {
failed := func(msg string) {
c.Logger.Error(msg, "err", err)
ModelUpdateStatus(c, model.Id, READY_FAILED)
ModelUpdateStatus(c, model.Id, READY_ALTERATION_FAILED)
}
reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip"))
@ -202,8 +202,19 @@ func processZipFileExpand(c *Context, model *BaseModel) {
ids := map[string]string{}
var baseOrder struct {
Order int `db:"class_order"`
}
err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id)
if err != nil {
failed("Failed to get the last class_order")
}
base := baseOrder.Order + 1
for i, name := range training {
id, err := model_classes.CreateClass(c.Db, model.Id, i, name)
id, err := model_classes.CreateClass(c.Db, model.Id, base + i, name)
if err != nil {
failed(fmt.Sprintf("Failed to create class '%s' on db\n", name))
return
@ -416,7 +427,7 @@ func handleDataUpload(handle *Handle) {
delete_path := "base_data.zip"
if model.Status == READY_FAILED {
if model.Status == READY_ALTERATION_FAILED {
delete_path = "expand_data.zip"
} else if model.Status != FAILED_PREPARING_ZIP_FILE {
return c.JsonBadRequest("Model not in the correct status")
@ -427,7 +438,7 @@ func handleDataUpload(handle *Handle) {
return c.Error500(err)
}
if model.Status != READY_FAILED {
if model.Status != READY_ALTERATION_FAILED {
err = os.RemoveAll(path.Join("savedData", model.Id, "data"))
if err != nil {
return c.Error500(err)
@ -436,7 +447,7 @@ func handleDataUpload(handle *Handle) {
c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding")
}
if model.Status != READY_FAILED {
if model.Status != READY_ALTERATION_FAILED {
_, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id)
if err != nil {
return c.Error500(err)
@ -448,7 +459,7 @@ func handleDataUpload(handle *Handle) {
}
}
if model.Status != READY_FAILED {
if model.Status != READY_ALTERATION_FAILED {
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
} else {
ModelUpdateStatus(c, model.Id, READY)

View File

@ -29,7 +29,7 @@ func deleteModelJSON(c *Context, id string) *Error {
func handleDelete(handle *Handle) {
handle.Delete("/models/delete", func(c *Context) *Error {
if c.CheckAuthLevel(1) {
if !c.CheckAuthLevel(1) {
return nil
}
var dat struct {
@ -66,6 +66,10 @@ func handleDelete(handle *Handle) {
case READY:
fallthrough
case READY_RETRAIN_FAILED:
fallthrough
case READY_ALTERATION_FAILED:
fallthrough
case CONFIRM_PRE_TRAINING:
if dat.Name == nil {
return c.JsonBadRequest("Provided name does not match the model name")

View File

@ -41,6 +41,7 @@ func handleEdit(handle *Handle) {
NumberOfInvalidImages int `json:"number_of_invalid_images"`
}
c.ShowMessage = false;
return c.SendJSON(ReturnType{
Classes: cls,
HasData: has_data,
@ -179,7 +180,8 @@ func handleEdit(handle *Handle) {
Layers: lay,
}
}
c.ShowMessage = false;
return c.SendJSON(defsToReturn)
})
@ -188,10 +190,6 @@ func handleEdit(handle *Handle) {
return nil
}
if !c.CheckAuthLevel(1) {
return nil
}
id, err := GetIdFromUrl(c, "id")
if err != nil {
return c.JsonBadRequest("Model not found")
@ -215,7 +213,8 @@ func handleEdit(handle *Handle) {
} else if err != nil {
return c.Error500(err)
}
c.ShowMessage = false
return c.SendJSON(model)
})
}

View File

@ -87,6 +87,8 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
return
}
c.Logger.Info("test", "count", len(heads))
var vmax float32 = 0.0
for _, element := range heads {
@ -95,12 +97,14 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
results := head_model.Exec([]tf.Output{
head_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
head_model.Op("serving_default_input_2", 0): base_results[0],
head_model.Op("serving_default_head_input", 0): base_results[0],
})
var predictions = results[0].Value().([][]float32)[0]
for i, v := range predictions {
c.Logger.Info("predictions", "class", i, "preds", v)
if v > vmax {
order = element.Range_start + i
vmax = v
@ -111,7 +115,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
// TODO runthe head model
confidence = vmax
c.Logger.Info("Got", "heads", len(heads))
c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", vmax)
return
}
@ -155,7 +159,7 @@ func handleRun(handle *Handle) {
return c.Error500(err)
}
if model.Status != READY {
if model.Status != READY && model.Status != READY_RETRAIN && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model not ready to run images")
}

View File

@ -36,13 +36,16 @@ func getDir() string {
return dir
}
// This function creates a new model_definition
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = ""
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return id, errors.New("Something wrong!")
}
@ -72,17 +75,14 @@ func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id)
var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1;", model_id)
if err != nil {
return
}
defer classes.Close()
if !classes.Next() {
return
}
if err = classes.Scan(&count); err != nil {
return
}
count = co.Count
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
if err != nil {
@ -121,19 +121,14 @@ func setModelClassStatus(c *Context, status ModelClassStatus, filter string, arg
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
defer classes.Close()
if !classes.Next() {
return
}
if err = classes.Scan(&count); err != nil {
return
}
count = co.Count
if count == 0 {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
@ -214,7 +209,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
if err != nil {
return
}
defer removeAll(run_path, err)
classCount, err := generateCvs(c, run_path, model.Id)
if err != nil {
@ -283,55 +277,125 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
func removeAll(path string, err error) {
if err != nil {
os.RemoveAll(path)
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
c.Logger.Info("test here", "count", co)
count_re = co.Count
count := co.Count
if count == 0 {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
if err != nil {
return
} else if doPanic {
return 0, errors.New("No model classes available")
}
return generateCvsExpandExp(c, run_path, model_id, offset, true)
}
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
defer data.Close()
f, err := os.Create(path.Join(run_path, "train.csv"))
if err != nil {
return
}
defer f.Close()
f.Write([]byte("Id,Index\n"))
count = 0
for data.Next() {
var id string
var class_order int
var file_path string
if err = data.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(class_order-offset) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
count += 1
}
//
// This is to load some extra data so that the model has more things to train on
//
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count)
if err != nil {
return
}
defer data_other.Close()
for data_other.Next() {
var id string
var class_order int
var file_path string
if err = data_other.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(-1) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
}
return
}
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0
c.Logger.Warn("About to start training definition")
c.Logger.Warn("About to retrain model")
// Get untrained models heads
// Status = 2 (INIT)
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
type ExpHead struct {
Id string
Start int `db:"range_start"`
End int `db:"range_end"`
}
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil {
return
}
defer rows.Close()
type ExpHead struct {
id string
start int
end int
}
exp := ExpHead{}
if rows.Next() {
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
return
}
} else {
} else if len(heads) == 0 {
log.Error("Failed to get the exp head of the model")
err = errors.New("Failed to get the exp head of the model")
return
}
if rows.Next() {
} else if len(heads) != 1 {
log.Error("This training function can only train one model at the time")
err = errors.New("This training function can only train one model at the time")
return
}
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
exp := heads[0]
c.Logger.Info("Got exp head", "head", exp)
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return
}
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
@ -348,8 +412,178 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got := []layerrow{}
remove_top_count := 1
i := 1
var last *layerrow = nil
got_2 := false
var first *layerrow = nil
for layers.Next() {
var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
return
}
// Keep track of the first layer so we can keep the size of the image
if first == nil {
first = &row
}
row.LayerNum = i
row.Shape = shapeToSize(row.Shape)
if row.ExpType == 2 {
if !got_2 {
got = append(got, *last)
got_2 = true
}
got = append(got, row)
}
last = &row
i += 1
}
got = append(got, layerrow{
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2,
LayerNum: i,
})
c.Logger.Info("Got layers", "layers", got)
// Generate run folder
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
classCount, err := generateCvsExpandExp(c, run_path, model.Id, exp.Start, false)
if err != nil {
return
}
c.Logger.Info("Generated cvs", "classCount", classCount)
// TODO update the run script
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
if err != nil {
return
}
defer f.Close()
c.Logger.Info("About to run python!")
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
if err != nil {
return
}
// Copy result around
result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{
"Layers": got,
"Size": first.Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.Id,
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN,
"LoadPrev": load_prev,
"BaseModel": path.Join(getDir(), result_path, "base", "model.keras"),
"LastModelRunPath": path.Join(getDir(), result_path, "head", exp.Id, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
"Depth": classCount,
"StartPoint": 0,
}); err != nil {
return
}
// Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
if err != nil {
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
return
}
c.Logger.Info("Python finished running")
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return
}
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
if err != nil {
return
}
defer accuracy_file.Close()
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
if err != nil {
return
}
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
if err != nil {
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0
c.Logger.Warn("About to start training definition")
// Get untrained models heads
type ExpHead struct {
Id string
Start int `db:"range_start"`
End int `db:"range_end"`
}
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil {
return
} else if len(heads) == 0 {
log.Error("Failed to get the exp head of the model")
return
} else if len(heads) != 1 {
log.Error("This training function can only train one model at the time")
err = errors.New("This training function can only train one model at the time")
return
}
exp := heads[0]
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return
}
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
return
}
defer layers.Close()
type layerrow struct {
LayerType int
Shape string
ExpType int
LayerNum int
}
got := []layerrow{}
i := 1
for layers.Next() {
@ -358,9 +592,6 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return
}
row.LayerNum = i
if row.ExpType == 2 {
remove_top_count += 1
}
row.Shape = shapeToSize(row.Shape)
got = append(got, row)
i += 1
@ -368,19 +599,18 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got = append(got, layerrow{
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.end-exp.start+1),
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2,
LayerNum: i,
})
// Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
defer removeAll(run_path, err)
classCount, err := generateCvsExp(c, run_path, model.Id, false)
if err != nil {
@ -408,16 +638,14 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
"Layers": got,
"Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.id,
"HeadId": exp.Id,
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN,
"DefId": definition_id,
"LoadPrev": load_prev,
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path),
"RemoveTopCount": remove_top_count,
"Depth": classCount,
"StartPoint": 0,
}); err != nil {
@ -453,6 +681,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
@ -762,6 +991,12 @@ func trainModelExp(c *Context, model *BaseModel) {
return
}
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2;", MODEL_HEAD_STATUS_READY, def.Id)
if err != nil {
failed("Failed to train definition!")
return
}
finished = true
break
}
@ -823,17 +1058,17 @@ func trainModelExp(c *Context, model *BaseModel) {
}
}
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
var dat JustId
err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("All definitions failed to train! And Failed to set class status")
return
}
failed("All definitions failed to train!")
return
} else if err != nil {
@ -863,10 +1098,23 @@ func trainModelExp(c *Context, model *BaseModel) {
}
if err = splitModel(c, model); err != nil {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to split the model! And Failed to set class status")
return
}
failed("Failed to split the model")
return
}
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
// There should only be one def availabale
def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
@ -884,25 +1132,22 @@ func trainModelExp(c *Context, model *BaseModel) {
func splitModel(c *Context, model *BaseModel) (err error) {
def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
return
}
head := JustId{}
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
return
}
// Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", def.Id)
run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-split")
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
defer removeAll(run_path, err)
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
@ -970,8 +1215,8 @@ func splitModel(c *Context, model *BaseModel) (err error) {
return
}
os.RemoveAll(run_path)
c.Logger.Info("Python finished running")
return
}
@ -1141,6 +1386,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
}
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
if err != nil {
@ -1306,6 +1552,244 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy
return nil
}
func ResetClasses(c *Context, model *BaseModel) {
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
if err != nil {
c.Logger.Error("Error while reseting the classes", "error", err)
}
}
func trainExpandable(c *Context, model *BaseModel) {
var err error = nil
failed := func(msg string) {
c.Logger.Error(msg, "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
ResetClasses(c, model)
}
var definitions TrainModelRowUsables
definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
failed("Failed to get definitions")
return
}
if len(definitions) != 1 {
failed("There should only be one definition available!")
return
}
firstRound := true
def := definitions[0]
epoch := 0
for {
acc, err := trainDefinitionExp(c, model, def.Id, !firstRound)
if err != nil {
failed("Failed to train definition!")
return
}
epoch += EPOCH_PER_RUN
if float64(acc*100) >= float64(def.Acuracy) {
c.Logger.Info("Found a definition that reaches target_accuracy!")
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
if err != nil {
failed("Failed to train definition!")
return
}
break
} else if def.Epoch > MAX_EPOCH {
failed(fmt.Sprintf("Failed to train definition! Accuracy less %f < %d\n", acc*100, def.TargetAccuracy))
return
}
}
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
ModelUpdateStatus(c, model.Id, READY)
}
func trainRetrain(c *Context, model *BaseModel, defId string) {
var err error
failed := func() {
ResetClasses(c, model)
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
return
}
// This is something I have to check
acc, err := trainDefinitionExpandExp(c, model, defId, false)
if err != nil {
c.Logger.Error("Failed to retrain the model", "err", err)
failed()
return
}
c.Logger.Info("Retrained model", "accuracy", acc)
// TODO check accuracy
err = UpdateStatus(c, "models", model.Id, READY)
if err != nil {
failed()
return
}
c.Logger.Info("model updaded")
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
if err != nil {
c.Logger.Error("Error while updating the classes", "error", err)
failed()
return
}
}
func handleRetrain(c *Context) *Error {
var err error = nil
if !c.CheckAuthLevel(1) {
return nil
}
var dat JustId
if err_ := c.ToJSON(&dat); err_ != nil {
return err_
}
if dat.Id == "" {
return c.JsonBadRequest("Please provide a id")
}
model, err := GetBaseModel(c.Db, dat.Id)
if err == ModelNotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.Error500(err)
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model in invalid status for re-training")
}
c.Logger.Info("Expanding definitions for models", "id", model.Id)
classesUpdated := false
failed := func() *Error {
if classesUpdated {
ResetClasses(c, model)
}
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
// TODO improve this response
return c.Error500(err)
}
var def struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
}
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
if err != nil {
return failed()
}
type C struct {
Id string
ClassOrder int `db:"class_order"`
}
err = c.StartTx()
if err != nil {
return failed()
}
classes, err := GetDbMultitple[C](
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
MODEL_CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
if len(classes) == 0 {
c.Logger.Error("No classes are available!")
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
err = c.CommitTx()
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
classesUpdated = true
}
_, err = CreateExpModelHead(c, def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
return failed()
}
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil {
fmt.Println("Failed to update model status")
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
}
return c.SendJSON(model.Id)
}
func handleTrain(handle *Handle) {
handle.Post("/models/train", func(c *Context) *Error {
if !c.CheckAuthLevel(1) {
@ -1374,6 +1858,8 @@ func handleTrain(handle *Handle) {
return c.SendJSON(model.Id)
})
handle.Post("/model/train/retrain", handleRetrain)
handle.Get("/model/epoch/update", func(c *Context) *Error {
f := c.R.URL.Query()

View File

@ -23,13 +23,16 @@ const (
FAILED_PREPARING_ZIP_FILE = -2
FAILED_PREPARING = -1
PREPARING = 1
CONFIRM_PRE_TRAINING = 2
PREPARING_ZIP_FILE = 3
TRAINING = 4
READY = 5
READY_ALTERATION = 6
READY_FAILED = -6
PREPARING = 1
CONFIRM_PRE_TRAINING = 2
PREPARING_ZIP_FILE = 3
TRAINING = 4
READY = 5
READY_ALTERATION = 6
READY_ALTERATION_FAILED = -6
READY_RETRAIN = 7
READY_RETRAIN_FAILED = -7
)
type ModelDefinitionStatus int
@ -62,6 +65,16 @@ const (
MODEL_CLASS_STATUS_TRAINED = 3
)
type ModelHeadStatus int
const (
MODEL_HEAD_STATUS_PRE_INIT ModelHeadStatus = 1
MODEL_HEAD_STATUS_INIT = 2
MODEL_HEAD_STATUS_TRAINING = 3
MODEL_HEAD_STATUS_TRAINED = 4
MODEL_HEAD_STATUS_READY = 5
)
var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {

View File

@ -89,6 +89,7 @@ func (x *Handle) handleGets(context *Context) {
return
}
}
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context)
}
@ -99,6 +100,7 @@ func (x *Handle) handlePosts(context *Context) {
return
}
}
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context)
}
@ -109,6 +111,7 @@ func (x *Handle) handleDeletes(context *Context) {
return
}
}
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context)
}
@ -127,12 +130,58 @@ func (c *Context) CheckAuthLevel(authLevel int) bool {
}
type Context struct {
Token *string
User *dbtypes.User
Logger *log.Logger
Db *sql.DB
Writer http.ResponseWriter
R *http.Request
Token *string
User *dbtypes.User
Logger *log.Logger
Db *sql.DB
Writer http.ResponseWriter
R *http.Request
Tx *sql.Tx
ShowMessage bool
}
func (c Context) Prepare(str string) (*sql.Stmt, error) {
if c.Tx == nil {
return c.Db.Prepare(str)
}
return c.Tx.Prepare(str)
}
var TransactionAlreadyStarted = errors.New("Transaction already started")
var TransactionNotStarted = errors.New("Transaction not started")
func (c *Context) StartTx() error {
if c.Tx != nil {
return TransactionAlreadyStarted
}
var err error = nil
c.Tx, err = c.Db.Begin()
return err
}
func (c *Context) CommitTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Commit()
if err != nil {
return err
}
c.Tx = nil
return nil
}
func (c *Context) RollbackTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Rollback()
if err != nil {
return err
}
c.Tx = nil
return nil
}
func (c Context) ToJSON(dat any) *Error {
@ -200,14 +249,14 @@ func (c *Context) GetModelFromId(id_path string) (*BaseModel, *Error) {
return nil, c.Error500(err)
}
return model, nil
return model, nil
}
func ModelUpdateStatus(c *Context, id string, status int) {
_, err := c.Db.Exec("update models set status=$1 where id=$2;", status, id)
if err != nil {
c.Logger.Error("Failed to update model status", "err", err)
c.Logger.Warn("TODO Maybe handle better")
c.Logger.Error("Failed to update model status", "err", err)
c.Logger.Warn("TODO Maybe handle better")
}
}
@ -270,6 +319,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
Db: handler.Db,
Writer: w,
R: r,
ShowMessage: true,
}, nil
}
@ -278,7 +328,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
return nil, errors.Join(err, LogoffError)
}
return &Context{token, user, logger, handler.Db, w, r}, nil
return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil
}
func contextlessLogoff(w http.ResponseWriter) {
@ -457,20 +507,19 @@ func NewHandler(db *sql.DB) *Handle {
if r.Method == "GET" {
x.handleGets(context)
return
}
if r.Method == "POST" {
} else if r.Method == "POST" {
x.handlePosts(context)
return
}
if r.Method == "DELETE" {
} else if r.Method == "DELETE" {
x.handleDeletes(context)
return
}
if r.Method == "OPTIONS" {
return
}
panic("TODO handle method: " + r.Method)
} else if r.Method == "OPTIONS" {
// do nothing
} else {
panic("TODO handle method: " + r.Method)
}
if context.ShowMessage {
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
}
})
return x

View File

@ -189,6 +189,7 @@ type JustId struct { Id string }
type Generic struct{ reflect.Type }
var NotFoundError = errors.New("Not found")
var CouldNotInsert = errors.New("Could not insert")
func generateQuery(t reflect.Type) (query string, nargs int) {
nargs = t.NumField()
@ -200,6 +201,10 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
if !ok {
name = field.Name;
}
if name == "__nil__" {
continue
}
query += strings.ToLower(name) + ","
}
@ -214,7 +219,13 @@ func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([
query, nargs := generateQuery(t)
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
if err != nil {
return nil, err
}
@ -251,6 +262,43 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
return nil
}
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
t := reflect.TypeOf(store).Elem()
query, nargs := generateQuery(t)
query2 := ""
for i := 0; i < nargs; i += 1 {
query2 += fmt.Sprintf("$%d,", i)
}
// Remove last quotation
query2 = query2[0 : len(query2)-1]
val := reflect.ValueOf(store).Elem()
scan_args := make([]interface{}, nargs);
for i := 0; i < nargs; i++ {
valueField := val.Field(i)
scan_args[i] = valueField.Interface()
}
rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return "", CouldNotInsert
}
err = rows.Scan(&id)
if err != nil {
return
}
return
}
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem()

View File

@ -0,0 +1,223 @@
import tensorflow as tf
import random
import pandas as pd
from tensorflow import keras
from tensorflow.data import AUTOTUNE
from keras import layers, losses, optimizers
import requests
import numpy as np
class NotifyServerCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, log, *args, **kwargs):
requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
DATA_DIR = "{{ .DataDir }}"
image_size = ({{ .Size }})
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
keys = tf.constant(df['Id'].dropna())
values = tf.constant(list(map(int, df['Index'].dropna())))
depth = {{ .Depth }}
diff = {{ .StartPoint }}
table = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=keys,
values=values,
),
default_value=tf.constant(-1),
name="Indexes"
)
DATA_DIR_PREPARE = DATA_DIR + "/"
# based on https://www.tensorflow.org/tutorials/load_data/images
def pathToLabel(path):
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
{{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpeg", "")
{{ else }}
ERROR
{{ end }}
num = table.lookup(tf.strings.as_string([path]))
return tf.cond(
tf.math.equal(num, tf.constant(-1)),
lambda: tf.zeros([depth]),
lambda: tf.one_hot(table.lookup(tf.strings.as_string([path])) - diff, depth)[0]
)
old_model = keras.models.load_model("{{ .BaseModel }}")
def decode_image(img):
{{ if eq .Model.Format "png" }}
img = tf.io.decode_png(img, channels={{.ColorMode}})
{{ else if eq .Model.Format "jpeg" }}
img = tf.io.decode_jpeg(img, channels={{.ColorMode}})
{{ else }}
ERROR
{{ end }}
return tf.image.resize(img, image_size)
def process_path(path):
label = pathToLabel(path)
img = tf.io.read_file(path)
img = decode_image(img)
return img, label
def configure_for_performance(ds: tf.data.Dataset, size: int, shuffle: bool) -> tf.data.Dataset:
# ds = ds.cache()
if shuffle:
ds = ds.shuffle(buffer_size=size)
ds = ds.batch(batch_size)
ds = ds.prefetch(AUTOTUNE)
return ds
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
ds = configure_for_performance(ds, size, False)
return ds
def filterDataset(path):
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
{{ if eq .Model.Format "png" }}
path = tf.strings.regex_replace(path, ".png", "")
{{ else if eq .Model.Format "jpeg" }}
path = tf.strings.regex_replace(path, ".jpeg", "")
{{ else }}
ERROR
{{ end }}
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
seed = random.randint(0, 100000000)
batch_size = 64
# Read all the files from the direcotry
list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
list_ds = list_ds.filter(filterDataset)
image_count = len(list(list_ds.as_numpy_iterator()))
list_ds = list_ds.shuffle(image_count, seed=seed)
val_size = int(image_count * 0.3)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
dataset = prepare_dataset(train_ds, image_count)
dataset_validation = prepare_dataset(val_ds, val_size)
track = 0
def addBlock(
b_size: int,
filter_size: int,
kernel_size: int = 3,
top: bool = True,
pooling_same: bool = False,
pool_func=layers.MaxPool2D
):
global track
model = keras.Sequential(
name=f"{track}-{b_size}-{filter_size}-{kernel_size}"
)
track += 1
for _ in range(b_size):
model.add(layers.Conv2D(
filter_size,
kernel_size,
padding="same"
))
model.add(layers.ReLU())
if top:
if pooling_same:
model.add(pool_func(padding="same", strides=(1, 1)))
else:
model.add(pool_func())
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.4))
return model
# Proccess old data
new_data = old_model.predict(dataset)
labels = np.concatenate([y for _, y in dataset], axis=0)
new_data = tf.data.Dataset.from_tensor_slices(
(new_data, labels))
new_data = configure_for_performance(new_data, batch_size, True)
new_data_val = old_model.predict(dataset_validation)
labels_val = np.concatenate([y for _, y in dataset_validation], axis=0)
new_data_val = tf.data.Dataset.from_tensor_slices(
(new_data_val, labels_val))
new_data_val = configure_for_performance(new_data_val, batch_size, True)
{{ if .LoadPrev }}
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
{{ else }}
model = keras.Sequential()
{{- range .Layers }}
{{- if eq .LayerType 1}}
model.add(layers.Rescaling(1./255))
{{- else if eq .LayerType 2 }}
model.add(layers.Dense({{ .Shape }}, activation="sigmoid"))
{{- else if eq .LayerType 3}}
model.add(layers.Flatten())
{{- else if eq .LayerType 4}}
model.add(addBlock(2, 128, 3, pool_func=layers.AveragePooling2D))
{{- else }}
ERROR
{{- end }}
{{- end }}
{{ end }}
model.layers[0]._name = "head"
model.compile(
loss=losses.BinaryCrossentropy(from_logits=False),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
his = model.fit(
new_data,
validation_data=new_data_val,
epochs={{.EPOCH_PER_RUN}},
callbacks=[
NotifyServerCallback(),
tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)
],
use_multiprocessing=True
)
acc = his.history["accuracy"]
f = open("accuracy.val", "w")
f.write(str(acc[-1]))
f.close()
tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
model.save("{{ .SaveModelPath }}/model.keras")

View File

@ -17,7 +17,7 @@ split_len = {{ .SplitLen }}
bottom_input = Input(model.input_shape[1:])
bottom_output = bottom_input
top_input = Input(model.layers[split_len + 1].input_shape[1:])
top_input = Input(model.layers[split_len + 1].input_shape[1:], name="head_input")
top_output = top_input
for i, layer in enumerate(model.layers):

View File

@ -146,14 +146,7 @@
<!-- TODO improve message -->
<h2 class="text-center">Failed to prepare model</h2>
<div>TODO button delete</div>
<!--form hx-delete="/models/delete">
<input type="hidden" name="id" value="{{ .Model.Id }}" />
<button class="danger">
Delete
</button>
</form-->
<DeleteModel model={m} />
</div>
<!-- PRE TRAINING STATUS -->
{:else if m.status == 2}
@ -288,7 +281,7 @@
{/await}
<!-- TODO Add ability to stop training -->
</div>
{:else if [5, 6, -6].includes(m.status)}
{:else if [5, 6, -6, 7, -7].includes(m.status)}
<BaseModelInfo model={m} />
<RunModel model={m} />
{#if m.status == 6}
@ -299,6 +292,13 @@
{#if m.status == -6}
<DeleteZip model={m} on:reload={getModel} expand />
{/if}
{#if m.status == -7}
<form>
<!-- TODO add more info about the failure -->
Failed to train the model!
Try to retrain
</form>
{/if}
{#if m.model_type == 2}
<ModelData model={m} on:reload={getModel} />
{/if}

View File

@ -142,15 +142,17 @@
</FileUpload>
</fieldset>
<MessageSimple bind:this={uploadImage} />
{#await uploading}
<button disabled>
Uploading
</button>
{:then}
<button>
Add
</button>
{/await}
{#if file}
{#await uploading}
<button disabled>
Uploading
</button>
{:then}
<button>
Add
</button>
{/await}
{/if}
</form>
</div>
<div class="content" class:selected={isActive("create-class")}>
@ -190,4 +192,6 @@
{/if}
</div>
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
{#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)}
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
{/if}

View File

@ -106,6 +106,15 @@
class:selected={isActive(item.name)}
>
{item.name}
{#if model.model_type == 2}
{#if item.status == 1}
<span class="bi bi-book" style="color: orange;" />
{:else if item.status == 2}
<span class="bi bi-book" style="color: green;" />
{:else if item.status == 3}
<span class="bi bi-check" style="color: green;" />
{/if}
{/if}
</button>
{/each}
</div>
@ -170,15 +179,17 @@
</FileUpload>
</fieldset>
<MessageSimple bind:this={uploadImage} />
{#await uploading}
<button disabled>
Uploading
</button>
{:then}
<button>
Add
</button>
{/await}
{#if file}
{#await uploading}
<button disabled>
Uploading
</button>
{:then}
<button>
Add
</button>
{/await}
{/if}
</form>
</div>
{/if}

View File

@ -23,6 +23,7 @@
let messages: MessageSimple;
async function submit() {
messages.clear();
submitted = true;
try {
await post('models/train', {
@ -34,62 +35,97 @@
if (e instanceof Response) {
messages.display(await e.json());
} else {
messages.display("Could not start the training of the model");
messages.display('Could not start the training of the model');
}
}
}
async function submitRetrain() {
messages.clear();
submitted = true;
try {
await post('model/train/retrain', { id: model.id });
dispatch('reload');
} catch (e) {
if (e instanceof Response) {
messages.display(await e.json());
} else {
messages.display('Could not start the training of the model');
}
}
}
</script>
<form class:submitted on:submit|preventDefault={submit}>
{#if has_data}
{#if number_of_invalid_images > 0}
<p class="danger">
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
These images will be delete when the model trains.
</p>
{#if model.status == 2}
<form class:submitted on:submit|preventDefault={submit}>
{#if has_data}
{#if number_of_invalid_images > 0}
<p class="danger">
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
These images will be delete when the model trains.
</p>
{/if}
<MessageSimple bind:this={messages} />
<!-- TODO expading mode -->
<fieldset>
<legend> Model Type </legend>
<div class="input-radial">
<input
id="model_type_simple"
value="simple"
name="model_type"
type="radio"
bind:group={data.model_type}
/>
<label for="model_type_simple">Simple</label><br />
<input
id="model_type_expandable"
value="expandable"
name="model_type"
bind:group={data.model_type}
type="radio"
/>
<label for="model_type_expandable">Expandable</label>
</div>
</fieldset>
<!-- TODO allow more models to be created -->
<fieldset>
<label for="number_of_models">Number of Models</label>
<input
id="number_of_models"
type="number"
name="number_of_models"
bind:value={data.number_of_models}
/>
</fieldset>
<!-- TODO to Change the acc -->
<fieldset>
<label for="accuracy">Target accuracy</label>
<input id="accuracy" type="number" name="accuracy" bind:value={data.accuracy} />
</fieldset>
<!-- TODO allow to chose the base of the model -->
<!-- TODO allow to change the shape of the model -->
<button> Train </button>
{:else}
<h2>To train the model please provide data to the model first</h2>
{/if}
<MessageSimple bind:this={messages} />
<!-- TODO expading mode -->
<fieldset>
<legend> Model Type </legend>
<div class="input-radial">
<input
id="model_type_simple"
value="simple"
name="model_type"
type="radio"
bind:group={data.model_type}
/>
<label for="model_type_simple">Simple</label><br />
<input
id="model_type_expandable"
value="expandable"
name="model_type"
bind:group={data.model_type}
type="radio"
/>
<label for="model_type_expandable">Expandable</label>
</div>
</fieldset>
<!-- TODO allow more models to be created -->
<fieldset>
<label for="number_of_models">Number of Models</label>
<input
id="number_of_models"
type="number"
name="number_of_models"
bind:value={data.number_of_models}
/>
</fieldset>
<!-- TODO to Change the acc -->
<fieldset>
<label for="accuracy">Target accuracy</label>
<input id="accuracy" type="number" name="accuracy" bind:value={data.accuracy} />
</fieldset>
<!-- TODO allow to chose the base of the model -->
<!-- TODO allow to change the shape of the model -->
<button> Train </button>
{:else}
<h2>To train the model please provide data to the model first</h2>
{/if}
</form>
</form>
{:else}
<form class:submitted on:submit|preventDefault={submitRetrain}>
{#if has_data}
<h2>
This model has new classes and can be expanded
</h2>
{#if number_of_invalid_images > 0}
<p class="danger">
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
These images will be delete when the model trains.
</p>
{/if}
<MessageSimple bind:this={messages} />
<button> Retrain </button>
{:else}
<h2>To train the model please provide data to the model first</h2>
{/if}
</form>
{/if}