added the ability to expand the models
This commit is contained in:
parent
274d7d22aa
commit
de0b430467
@ -60,6 +60,7 @@ func HandleList(handle *Handle) {
|
||||
|
||||
max_len := min(11, len(rows))
|
||||
|
||||
c.ShowMessage = false;
|
||||
return c.SendJSON(ReturnType{
|
||||
ImageList: rows[0:max_len],
|
||||
Page: page,
|
||||
|
@ -156,7 +156,7 @@ func processZipFileExpand(c *Context, model *BaseModel) {
|
||||
|
||||
failed := func(msg string) {
|
||||
c.Logger.Error(msg, "err", err)
|
||||
ModelUpdateStatus(c, model.Id, READY_FAILED)
|
||||
ModelUpdateStatus(c, model.Id, READY_ALTERATION_FAILED)
|
||||
}
|
||||
|
||||
reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip"))
|
||||
@ -202,8 +202,19 @@ func processZipFileExpand(c *Context, model *BaseModel) {
|
||||
|
||||
ids := map[string]string{}
|
||||
|
||||
var baseOrder struct {
|
||||
Order int `db:"class_order"`
|
||||
}
|
||||
|
||||
err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id)
|
||||
if err != nil {
|
||||
failed("Failed to get the last class_order")
|
||||
}
|
||||
|
||||
base := baseOrder.Order + 1
|
||||
|
||||
for i, name := range training {
|
||||
id, err := model_classes.CreateClass(c.Db, model.Id, i, name)
|
||||
id, err := model_classes.CreateClass(c.Db, model.Id, base + i, name)
|
||||
if err != nil {
|
||||
failed(fmt.Sprintf("Failed to create class '%s' on db\n", name))
|
||||
return
|
||||
@ -416,7 +427,7 @@ func handleDataUpload(handle *Handle) {
|
||||
|
||||
delete_path := "base_data.zip"
|
||||
|
||||
if model.Status == READY_FAILED {
|
||||
if model.Status == READY_ALTERATION_FAILED {
|
||||
delete_path = "expand_data.zip"
|
||||
} else if model.Status != FAILED_PREPARING_ZIP_FILE {
|
||||
return c.JsonBadRequest("Model not in the correct status")
|
||||
@ -427,7 +438,7 @@ func handleDataUpload(handle *Handle) {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != READY_FAILED {
|
||||
if model.Status != READY_ALTERATION_FAILED {
|
||||
err = os.RemoveAll(path.Join("savedData", model.Id, "data"))
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
@ -436,7 +447,7 @@ func handleDataUpload(handle *Handle) {
|
||||
c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding")
|
||||
}
|
||||
|
||||
if model.Status != READY_FAILED {
|
||||
if model.Status != READY_ALTERATION_FAILED {
|
||||
_, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
@ -448,7 +459,7 @@ func handleDataUpload(handle *Handle) {
|
||||
}
|
||||
}
|
||||
|
||||
if model.Status != READY_FAILED {
|
||||
if model.Status != READY_ALTERATION_FAILED {
|
||||
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
|
||||
} else {
|
||||
ModelUpdateStatus(c, model.Id, READY)
|
||||
|
@ -29,7 +29,7 @@ func deleteModelJSON(c *Context, id string) *Error {
|
||||
|
||||
func handleDelete(handle *Handle) {
|
||||
handle.Delete("/models/delete", func(c *Context) *Error {
|
||||
if c.CheckAuthLevel(1) {
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
var dat struct {
|
||||
@ -66,6 +66,10 @@ func handleDelete(handle *Handle) {
|
||||
|
||||
case READY:
|
||||
fallthrough
|
||||
case READY_RETRAIN_FAILED:
|
||||
fallthrough
|
||||
case READY_ALTERATION_FAILED:
|
||||
fallthrough
|
||||
case CONFIRM_PRE_TRAINING:
|
||||
if dat.Name == nil {
|
||||
return c.JsonBadRequest("Provided name does not match the model name")
|
||||
|
@ -41,6 +41,7 @@ func handleEdit(handle *Handle) {
|
||||
NumberOfInvalidImages int `json:"number_of_invalid_images"`
|
||||
}
|
||||
|
||||
c.ShowMessage = false;
|
||||
return c.SendJSON(ReturnType{
|
||||
Classes: cls,
|
||||
HasData: has_data,
|
||||
@ -180,6 +181,7 @@ func handleEdit(handle *Handle) {
|
||||
}
|
||||
}
|
||||
|
||||
c.ShowMessage = false;
|
||||
return c.SendJSON(defsToReturn)
|
||||
})
|
||||
|
||||
@ -188,10 +190,6 @@ func handleEdit(handle *Handle) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
id, err := GetIdFromUrl(c, "id")
|
||||
if err != nil {
|
||||
return c.JsonBadRequest("Model not found")
|
||||
@ -216,6 +214,7 @@ func handleEdit(handle *Handle) {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
c.ShowMessage = false
|
||||
return c.SendJSON(model)
|
||||
})
|
||||
}
|
||||
|
@ -87,6 +87,8 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("test", "count", len(heads))
|
||||
|
||||
var vmax float32 = 0.0
|
||||
|
||||
for _, element := range heads {
|
||||
@ -95,12 +97,14 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
results := head_model.Exec([]tf.Output{
|
||||
head_model.Op("StatefulPartitionedCall", 0),
|
||||
}, map[tf.Output]*tf.Tensor{
|
||||
head_model.Op("serving_default_input_2", 0): base_results[0],
|
||||
head_model.Op("serving_default_head_input", 0): base_results[0],
|
||||
})
|
||||
|
||||
var predictions = results[0].Value().([][]float32)[0]
|
||||
|
||||
|
||||
for i, v := range predictions {
|
||||
c.Logger.Info("predictions", "class", i, "preds", v)
|
||||
if v > vmax {
|
||||
order = element.Range_start + i
|
||||
vmax = v
|
||||
@ -111,7 +115,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
||||
// TODO runthe head model
|
||||
confidence = vmax
|
||||
|
||||
c.Logger.Info("Got", "heads", len(heads))
|
||||
c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", vmax)
|
||||
return
|
||||
}
|
||||
|
||||
@ -155,7 +159,7 @@ func handleRun(handle *Handle) {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
if model.Status != READY {
|
||||
if model.Status != READY && model.Status != READY_RETRAIN && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION && model.Status != READY_ALTERATION_FAILED {
|
||||
return c.JsonBadRequest("Model not ready to run images")
|
||||
}
|
||||
|
||||
|
@ -36,13 +36,16 @@ func getDir() string {
|
||||
return dir
|
||||
}
|
||||
|
||||
// This function creates a new model_definition
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
id = ""
|
||||
|
||||
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return id, errors.New("Something wrong!")
|
||||
}
|
||||
@ -72,17 +75,14 @@ func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type
|
||||
|
||||
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
|
||||
|
||||
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id)
|
||||
var co struct {
|
||||
Count int `db:"count(*)"`
|
||||
}
|
||||
err = GetDBOnce(c, &co, "model_classes where model_id=$1;", model_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer classes.Close()
|
||||
if !classes.Next() {
|
||||
return
|
||||
}
|
||||
if err = classes.Scan(&count); err != nil {
|
||||
return
|
||||
}
|
||||
count = co.Count
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
|
||||
if err != nil {
|
||||
@ -121,19 +121,14 @@ func setModelClassStatus(c *Context, status ModelClassStatus, filter string, arg
|
||||
|
||||
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
|
||||
|
||||
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||
var co struct {
|
||||
Count int `db:"count(*)"`
|
||||
}
|
||||
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer classes.Close()
|
||||
|
||||
if !classes.Next() {
|
||||
return
|
||||
}
|
||||
|
||||
if err = classes.Scan(&count); err != nil {
|
||||
return
|
||||
}
|
||||
count = co.Count
|
||||
|
||||
if count == 0 {
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
|
||||
@ -214,7 +209,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer removeAll(run_path, err)
|
||||
|
||||
classCount, err := generateCvs(c, run_path, model.Id)
|
||||
if err != nil {
|
||||
@ -283,55 +277,125 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
||||
return
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
return
|
||||
}
|
||||
|
||||
func removeAll(path string, err error) {
|
||||
if err != nil {
|
||||
os.RemoveAll(path)
|
||||
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
|
||||
|
||||
var co struct {
|
||||
Count int `db:"count(*)"`
|
||||
}
|
||||
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Logger.Info("test here", "count", co)
|
||||
count_re = co.Count
|
||||
count := co.Count
|
||||
|
||||
if count == 0 {
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
|
||||
if err != nil {
|
||||
return
|
||||
} else if doPanic {
|
||||
return 0, errors.New("No model classes available")
|
||||
}
|
||||
return generateCvsExpandExp(c, run_path, model_id, offset, true)
|
||||
}
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer data.Close()
|
||||
|
||||
f, err := os.Create(path.Join(run_path, "train.csv"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
f.Write([]byte("Id,Index\n"))
|
||||
|
||||
count = 0
|
||||
|
||||
for data.Next() {
|
||||
var id string
|
||||
var class_order int
|
||||
var file_path string
|
||||
if err = data.Scan(&id, &class_order, &file_path); err != nil {
|
||||
return
|
||||
}
|
||||
if file_path == "id://" {
|
||||
f.Write([]byte(id + "," + strconv.Itoa(class_order-offset) + "\n"))
|
||||
} else {
|
||||
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||
}
|
||||
count += 1
|
||||
}
|
||||
|
||||
//
|
||||
// This is to load some extra data so that the model has more things to train on
|
||||
//
|
||||
|
||||
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer data_other.Close()
|
||||
|
||||
for data_other.Next() {
|
||||
var id string
|
||||
var class_order int
|
||||
var file_path string
|
||||
if err = data_other.Scan(&id, &class_order, &file_path); err != nil {
|
||||
return
|
||||
}
|
||||
if file_path == "id://" {
|
||||
f.Write([]byte(id + "," + strconv.Itoa(-1) + "\n"))
|
||||
} else {
|
||||
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
accuracy = 0
|
||||
|
||||
c.Logger.Warn("About to start training definition")
|
||||
c.Logger.Warn("About to retrain model")
|
||||
|
||||
// Get untrained models heads
|
||||
|
||||
// Status = 2 (INIT)
|
||||
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||
type ExpHead struct {
|
||||
Id string
|
||||
Start int `db:"range_start"`
|
||||
End int `db:"range_end"`
|
||||
}
|
||||
|
||||
// status = 2 (INIT) 3 (TRAINING)
|
||||
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
type ExpHead struct {
|
||||
id string
|
||||
start int
|
||||
end int
|
||||
}
|
||||
|
||||
exp := ExpHead{}
|
||||
|
||||
if rows.Next() {
|
||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
} else if len(heads) == 0 {
|
||||
log.Error("Failed to get the exp head of the model")
|
||||
err = errors.New("Failed to get the exp head of the model")
|
||||
return
|
||||
}
|
||||
|
||||
if rows.Next() {
|
||||
} else if len(heads) != 1 {
|
||||
log.Error("This training function can only train one model at the time")
|
||||
err = errors.New("This training function can only train one model at the time")
|
||||
return
|
||||
}
|
||||
|
||||
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
|
||||
exp := heads[0]
|
||||
|
||||
c.Logger.Info("Got exp head", "head", exp)
|
||||
|
||||
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
@ -348,8 +412,178 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
|
||||
got := []layerrow{}
|
||||
|
||||
remove_top_count := 1
|
||||
i := 1
|
||||
var last *layerrow = nil
|
||||
got_2 := false
|
||||
|
||||
var first *layerrow = nil
|
||||
|
||||
for layers.Next() {
|
||||
var row = layerrow{}
|
||||
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Keep track of the first layer so we can keep the size of the image
|
||||
if first == nil {
|
||||
first = &row
|
||||
}
|
||||
|
||||
row.LayerNum = i
|
||||
row.Shape = shapeToSize(row.Shape)
|
||||
if row.ExpType == 2 {
|
||||
if !got_2 {
|
||||
got = append(got, *last)
|
||||
got_2 = true
|
||||
}
|
||||
got = append(got, row)
|
||||
}
|
||||
last = &row
|
||||
i += 1
|
||||
}
|
||||
|
||||
got = append(got, layerrow{
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
|
||||
ExpType: 2,
|
||||
LayerNum: i,
|
||||
})
|
||||
|
||||
c.Logger.Info("Got layers", "layers", got)
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
|
||||
|
||||
err = os.MkdirAll(run_path, os.ModePerm)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
classCount, err := generateCvsExpandExp(c, run_path, model.Id, exp.Start, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Generated cvs", "classCount", classCount)
|
||||
|
||||
// TODO update the run script
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
c.Logger.Info("About to run python!")
|
||||
|
||||
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Copy result around
|
||||
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
||||
|
||||
if err = tmpl.Execute(f, AnyMap{
|
||||
"Layers": got,
|
||||
"Size": first.Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||
"HeadId": exp.Id,
|
||||
"RunPath": run_path,
|
||||
"ColorMode": model.ImageMode,
|
||||
"Model": model,
|
||||
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
||||
"LoadPrev": load_prev,
|
||||
"BaseModel": path.Join(getDir(), result_path, "base", "model.keras"),
|
||||
"LastModelRunPath": path.Join(getDir(), result_path, "head", exp.Id, "model.keras"),
|
||||
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
|
||||
"Depth": classCount,
|
||||
"StartPoint": 0,
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Run the command
|
||||
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
|
||||
if err != nil {
|
||||
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Python finished running")
|
||||
|
||||
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer accuracy_file.Close()
|
||||
|
||||
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
return
|
||||
}
|
||||
|
||||
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
accuracy = 0
|
||||
|
||||
c.Logger.Warn("About to start training definition")
|
||||
|
||||
// Get untrained models heads
|
||||
|
||||
type ExpHead struct {
|
||||
Id string
|
||||
Start int `db:"range_start"`
|
||||
End int `db:"range_end"`
|
||||
}
|
||||
|
||||
// status = 2 (INIT) 3 (TRAINING)
|
||||
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
} else if len(heads) == 0 {
|
||||
log.Error("Failed to get the exp head of the model")
|
||||
return
|
||||
} else if len(heads) != 1 {
|
||||
log.Error("This training function can only train one model at the time")
|
||||
err = errors.New("This training function can only train one model at the time")
|
||||
return
|
||||
}
|
||||
|
||||
exp := heads[0]
|
||||
|
||||
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer layers.Close()
|
||||
|
||||
type layerrow struct {
|
||||
LayerType int
|
||||
Shape string
|
||||
ExpType int
|
||||
LayerNum int
|
||||
}
|
||||
|
||||
got := []layerrow{}
|
||||
i := 1
|
||||
|
||||
for layers.Next() {
|
||||
@ -358,9 +592,6 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
return
|
||||
}
|
||||
row.LayerNum = i
|
||||
if row.ExpType == 2 {
|
||||
remove_top_count += 1
|
||||
}
|
||||
row.Shape = shapeToSize(row.Shape)
|
||||
got = append(got, row)
|
||||
i += 1
|
||||
@ -368,19 +599,18 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
|
||||
got = append(got, layerrow{
|
||||
LayerType: LAYER_DENSE,
|
||||
Shape: fmt.Sprintf("%d", exp.end-exp.start+1),
|
||||
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
|
||||
ExpType: 2,
|
||||
LayerNum: i,
|
||||
})
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
|
||||
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
|
||||
|
||||
err = os.MkdirAll(run_path, os.ModePerm)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer removeAll(run_path, err)
|
||||
|
||||
classCount, err := generateCvsExp(c, run_path, model.Id, false)
|
||||
if err != nil {
|
||||
@ -408,16 +638,14 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
"Layers": got,
|
||||
"Size": got[0].Shape,
|
||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||
"HeadId": exp.id,
|
||||
"HeadId": exp.Id,
|
||||
"RunPath": run_path,
|
||||
"ColorMode": model.ImageMode,
|
||||
"Model": model,
|
||||
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
||||
"DefId": definition_id,
|
||||
"LoadPrev": load_prev,
|
||||
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||
"SaveModelPath": path.Join(getDir(), result_path),
|
||||
"RemoveTopCount": remove_top_count,
|
||||
"Depth": classCount,
|
||||
"StartPoint": 0,
|
||||
}); err != nil {
|
||||
@ -453,6 +681,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
||||
return
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
return
|
||||
}
|
||||
@ -762,6 +991,12 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2;", MODEL_HEAD_STATUS_READY, def.Id)
|
||||
if err != nil {
|
||||
failed("Failed to train definition!")
|
||||
return
|
||||
}
|
||||
|
||||
finished = true
|
||||
break
|
||||
}
|
||||
@ -823,17 +1058,17 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
}
|
||||
}
|
||||
|
||||
// Set the class status to trained
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
failed("Failed to set class status")
|
||||
return
|
||||
}
|
||||
|
||||
var dat JustId
|
||||
|
||||
err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
|
||||
if err == NotFoundError {
|
||||
// Set the class status to trained
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
failed("All definitions failed to train! And Failed to set class status")
|
||||
return
|
||||
}
|
||||
|
||||
failed("All definitions failed to train!")
|
||||
return
|
||||
} else if err != nil {
|
||||
@ -863,10 +1098,23 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
}
|
||||
|
||||
if err = splitModel(c, model); err != nil {
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
failed("Failed to split the model! And Failed to set class status")
|
||||
return
|
||||
}
|
||||
|
||||
failed("Failed to split the model")
|
||||
return
|
||||
}
|
||||
|
||||
// Set the class status to trained
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
failed("Failed to set class status")
|
||||
return
|
||||
}
|
||||
|
||||
// There should only be one def availabale
|
||||
def := JustId{}
|
||||
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
||||
@ -884,25 +1132,22 @@ func trainModelExp(c *Context, model *BaseModel) {
|
||||
func splitModel(c *Context, model *BaseModel) (err error) {
|
||||
|
||||
def := JustId{}
|
||||
|
||||
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
head := JustId{}
|
||||
|
||||
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model.Id, "defs", def.Id)
|
||||
run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-split")
|
||||
|
||||
err = os.MkdirAll(run_path, os.ModePerm)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer removeAll(run_path, err)
|
||||
|
||||
// Create python script
|
||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||
@ -970,8 +1215,8 @@ func splitModel(c *Context, model *BaseModel) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
c.Logger.Info("Python finished running")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -1141,6 +1386,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
|
||||
}
|
||||
|
||||
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
|
||||
|
||||
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
||||
|
||||
if err != nil {
|
||||
@ -1306,6 +1552,244 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy
|
||||
return nil
|
||||
}
|
||||
|
||||
func ResetClasses(c *Context, model *BaseModel) {
|
||||
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Error while reseting the classes", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func trainExpandable(c *Context, model *BaseModel) {
|
||||
var err error = nil
|
||||
|
||||
failed := func(msg string) {
|
||||
c.Logger.Error(msg, "err", err)
|
||||
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||
ResetClasses(c, model)
|
||||
}
|
||||
|
||||
var definitions TrainModelRowUsables
|
||||
|
||||
definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
|
||||
if err != nil {
|
||||
failed("Failed to get definitions")
|
||||
return
|
||||
}
|
||||
if len(definitions) != 1 {
|
||||
failed("There should only be one definition available!")
|
||||
return
|
||||
}
|
||||
|
||||
firstRound := true
|
||||
def := definitions[0]
|
||||
epoch := 0
|
||||
|
||||
for {
|
||||
acc, err := trainDefinitionExp(c, model, def.Id, !firstRound)
|
||||
if err != nil {
|
||||
failed("Failed to train definition!")
|
||||
return
|
||||
}
|
||||
epoch += EPOCH_PER_RUN
|
||||
|
||||
if float64(acc*100) >= float64(def.Acuracy) {
|
||||
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
||||
|
||||
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
failed("Failed to train definition!")
|
||||
return
|
||||
}
|
||||
break
|
||||
} else if def.Epoch > MAX_EPOCH {
|
||||
failed(fmt.Sprintf("Failed to train definition! Accuracy less %f < %d\n", acc*100, def.TargetAccuracy))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Set the class status to trained
|
||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
failed("Failed to set class status")
|
||||
return
|
||||
}
|
||||
|
||||
ModelUpdateStatus(c, model.Id, READY)
|
||||
}
|
||||
|
||||
func trainRetrain(c *Context, model *BaseModel, defId string) {
|
||||
var err error
|
||||
|
||||
failed := func() {
|
||||
ResetClasses(c, model)
|
||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||
c.Logger.Error("Failed to retrain", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// This is something I have to check
|
||||
acc, err := trainDefinitionExpandExp(c, model, defId, false)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to retrain the model", "err", err)
|
||||
failed()
|
||||
return
|
||||
|
||||
}
|
||||
c.Logger.Info("Retrained model", "accuracy", acc)
|
||||
|
||||
// TODO check accuracy
|
||||
|
||||
err = UpdateStatus(c, "models", model.Id, READY)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("model updaded")
|
||||
|
||||
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Error while updating the classes", "error", err)
|
||||
failed()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func handleRetrain(c *Context) *Error {
|
||||
var err error = nil
|
||||
if !c.CheckAuthLevel(1) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var dat JustId
|
||||
|
||||
if err_ := c.ToJSON(&dat); err_ != nil {
|
||||
return err_
|
||||
}
|
||||
|
||||
if dat.Id == "" {
|
||||
return c.JsonBadRequest("Please provide a id")
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c.Db, dat.Id)
|
||||
if err == ModelNotFoundError {
|
||||
return c.JsonBadRequest("Model not found")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
||||
return c.JsonBadRequest("Model in invalid status for re-training")
|
||||
}
|
||||
|
||||
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
||||
|
||||
classesUpdated := false
|
||||
|
||||
failed := func() *Error {
|
||||
if classesUpdated {
|
||||
ResetClasses(c, model)
|
||||
}
|
||||
|
||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||
c.Logger.Error("Failed to retrain", "err", err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
var def struct {
|
||||
Id string
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
}
|
||||
|
||||
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
type C struct {
|
||||
Id string
|
||||
ClassOrder int `db:"class_order"`
|
||||
}
|
||||
|
||||
err = c.StartTx()
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
classes, err := GetDbMultitple[C](
|
||||
c,
|
||||
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
||||
model.Id,
|
||||
MODEL_CLASS_STATUS_TO_TRAIN,
|
||||
)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
if len(classes) == 0 {
|
||||
c.Logger.Error("No classes are available!")
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
//Update the classes
|
||||
{
|
||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||
err = err2
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
err = c.CommitTx()
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
|
||||
classesUpdated = true
|
||||
}
|
||||
|
||||
_, err = CreateExpModelHead(c, def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT)
|
||||
if err != nil {
|
||||
return failed()
|
||||
}
|
||||
|
||||
go trainRetrain(c, model, def.Id)
|
||||
|
||||
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
||||
fmt.Println(err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
return c.SendJSON(model.Id)
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
handle.Post("/models/train", func(c *Context) *Error {
|
||||
if !c.CheckAuthLevel(1) {
|
||||
@ -1374,6 +1858,8 @@ func handleTrain(handle *Handle) {
|
||||
return c.SendJSON(model.Id)
|
||||
})
|
||||
|
||||
handle.Post("/model/train/retrain", handleRetrain)
|
||||
|
||||
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
||||
f := c.R.URL.Query()
|
||||
|
||||
|
@ -23,13 +23,16 @@ const (
|
||||
FAILED_PREPARING_ZIP_FILE = -2
|
||||
FAILED_PREPARING = -1
|
||||
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
TRAINING = 4
|
||||
READY = 5
|
||||
READY_ALTERATION = 6
|
||||
READY_FAILED = -6
|
||||
PREPARING = 1
|
||||
CONFIRM_PRE_TRAINING = 2
|
||||
PREPARING_ZIP_FILE = 3
|
||||
TRAINING = 4
|
||||
READY = 5
|
||||
READY_ALTERATION = 6
|
||||
READY_ALTERATION_FAILED = -6
|
||||
|
||||
READY_RETRAIN = 7
|
||||
READY_RETRAIN_FAILED = -7
|
||||
)
|
||||
|
||||
type ModelDefinitionStatus int
|
||||
@ -62,6 +65,16 @@ const (
|
||||
MODEL_CLASS_STATUS_TRAINED = 3
|
||||
)
|
||||
|
||||
type ModelHeadStatus int
|
||||
|
||||
const (
|
||||
MODEL_HEAD_STATUS_PRE_INIT ModelHeadStatus = 1
|
||||
MODEL_HEAD_STATUS_INIT = 2
|
||||
MODEL_HEAD_STATUS_TRAINING = 3
|
||||
MODEL_HEAD_STATUS_TRAINED = 4
|
||||
MODEL_HEAD_STATUS_READY = 5
|
||||
)
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
|
@ -89,6 +89,7 @@ func (x *Handle) handleGets(context *Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
context.ShowMessage = false
|
||||
handleError(&Error{404, "Endpoint not found"}, context)
|
||||
}
|
||||
|
||||
@ -99,6 +100,7 @@ func (x *Handle) handlePosts(context *Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
context.ShowMessage = false
|
||||
handleError(&Error{404, "Endpoint not found"}, context)
|
||||
}
|
||||
|
||||
@ -109,6 +111,7 @@ func (x *Handle) handleDeletes(context *Context) {
|
||||
return
|
||||
}
|
||||
}
|
||||
context.ShowMessage = false
|
||||
handleError(&Error{404, "Endpoint not found"}, context)
|
||||
}
|
||||
|
||||
@ -127,12 +130,58 @@ func (c *Context) CheckAuthLevel(authLevel int) bool {
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
Token *string
|
||||
User *dbtypes.User
|
||||
Logger *log.Logger
|
||||
Db *sql.DB
|
||||
Writer http.ResponseWriter
|
||||
R *http.Request
|
||||
Token *string
|
||||
User *dbtypes.User
|
||||
Logger *log.Logger
|
||||
Db *sql.DB
|
||||
Writer http.ResponseWriter
|
||||
R *http.Request
|
||||
Tx *sql.Tx
|
||||
ShowMessage bool
|
||||
}
|
||||
|
||||
func (c Context) Prepare(str string) (*sql.Stmt, error) {
|
||||
if c.Tx == nil {
|
||||
return c.Db.Prepare(str)
|
||||
}
|
||||
|
||||
return c.Tx.Prepare(str)
|
||||
}
|
||||
|
||||
var TransactionAlreadyStarted = errors.New("Transaction already started")
|
||||
var TransactionNotStarted = errors.New("Transaction not started")
|
||||
|
||||
func (c *Context) StartTx() error {
|
||||
if c.Tx != nil {
|
||||
return TransactionAlreadyStarted
|
||||
}
|
||||
var err error = nil
|
||||
c.Tx, err = c.Db.Begin()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Context) CommitTx() error {
|
||||
if c.Tx == nil {
|
||||
return TransactionNotStarted
|
||||
}
|
||||
err := c.Tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Tx = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Context) RollbackTx() error {
|
||||
if c.Tx == nil {
|
||||
return TransactionNotStarted
|
||||
}
|
||||
err := c.Tx.Rollback()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.Tx = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Context) ToJSON(dat any) *Error {
|
||||
@ -200,14 +249,14 @@ func (c *Context) GetModelFromId(id_path string) (*BaseModel, *Error) {
|
||||
return nil, c.Error500(err)
|
||||
}
|
||||
|
||||
return model, nil
|
||||
return model, nil
|
||||
}
|
||||
|
||||
func ModelUpdateStatus(c *Context, id string, status int) {
|
||||
_, err := c.Db.Exec("update models set status=$1 where id=$2;", status, id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to update model status", "err", err)
|
||||
c.Logger.Warn("TODO Maybe handle better")
|
||||
c.Logger.Error("Failed to update model status", "err", err)
|
||||
c.Logger.Warn("TODO Maybe handle better")
|
||||
}
|
||||
}
|
||||
|
||||
@ -270,6 +319,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
|
||||
Db: handler.Db,
|
||||
Writer: w,
|
||||
R: r,
|
||||
ShowMessage: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -278,7 +328,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
|
||||
return nil, errors.Join(err, LogoffError)
|
||||
}
|
||||
|
||||
return &Context{token, user, logger, handler.Db, w, r}, nil
|
||||
return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil
|
||||
}
|
||||
|
||||
func contextlessLogoff(w http.ResponseWriter) {
|
||||
@ -457,20 +507,19 @@ func NewHandler(db *sql.DB) *Handle {
|
||||
|
||||
if r.Method == "GET" {
|
||||
x.handleGets(context)
|
||||
return
|
||||
}
|
||||
if r.Method == "POST" {
|
||||
} else if r.Method == "POST" {
|
||||
x.handlePosts(context)
|
||||
return
|
||||
}
|
||||
if r.Method == "DELETE" {
|
||||
} else if r.Method == "DELETE" {
|
||||
x.handleDeletes(context)
|
||||
return
|
||||
}
|
||||
if r.Method == "OPTIONS" {
|
||||
return
|
||||
}
|
||||
panic("TODO handle method: " + r.Method)
|
||||
} else if r.Method == "OPTIONS" {
|
||||
// do nothing
|
||||
} else {
|
||||
panic("TODO handle method: " + r.Method)
|
||||
}
|
||||
|
||||
if context.ShowMessage {
|
||||
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
|
||||
}
|
||||
})
|
||||
|
||||
return x
|
||||
|
@ -189,6 +189,7 @@ type JustId struct { Id string }
|
||||
type Generic struct{ reflect.Type }
|
||||
|
||||
var NotFoundError = errors.New("Not found")
|
||||
var CouldNotInsert = errors.New("Could not insert")
|
||||
|
||||
func generateQuery(t reflect.Type) (query string, nargs int) {
|
||||
nargs = t.NumField()
|
||||
@ -200,6 +201,10 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
|
||||
if !ok {
|
||||
name = field.Name;
|
||||
}
|
||||
|
||||
if name == "__nil__" {
|
||||
continue
|
||||
}
|
||||
query += strings.ToLower(name) + ","
|
||||
}
|
||||
|
||||
@ -214,7 +219,13 @@ func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
|
||||
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db_query.Close()
|
||||
|
||||
rows, err := db_query.Query(args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -251,6 +262,43 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
|
||||
query2 := ""
|
||||
for i := 0; i < nargs; i += 1 {
|
||||
query2 += fmt.Sprintf("$%d,", i)
|
||||
}
|
||||
// Remove last quotation
|
||||
query2 = query2[0 : len(query2)-1]
|
||||
|
||||
val := reflect.ValueOf(store).Elem()
|
||||
scan_args := make([]interface{}, nargs);
|
||||
for i := 0; i < nargs; i++ {
|
||||
valueField := val.Field(i)
|
||||
scan_args[i] = valueField.Interface()
|
||||
}
|
||||
|
||||
rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return "", CouldNotInsert
|
||||
}
|
||||
|
||||
err = rows.Scan(&id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
|
223
views/py/python_model_template_expand.py
Normal file
223
views/py/python_model_template_expand.py
Normal file
@ -0,0 +1,223 @@
|
||||
import tensorflow as tf
|
||||
import random
|
||||
import pandas as pd
|
||||
from tensorflow import keras
|
||||
from tensorflow.data import AUTOTUNE
|
||||
from keras import layers, losses, optimizers
|
||||
import requests
|
||||
import numpy as np
|
||||
|
||||
|
||||
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
||||
requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
|
||||
|
||||
|
||||
DATA_DIR = "{{ .DataDir }}"
|
||||
image_size = ({{ .Size }})
|
||||
|
||||
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
||||
keys = tf.constant(df['Id'].dropna())
|
||||
values = tf.constant(list(map(int, df['Index'].dropna())))
|
||||
|
||||
depth = {{ .Depth }}
|
||||
diff = {{ .StartPoint }}
|
||||
|
||||
table = tf.lookup.StaticHashTable(
|
||||
initializer=tf.lookup.KeyValueTensorInitializer(
|
||||
keys=keys,
|
||||
values=values,
|
||||
),
|
||||
default_value=tf.constant(-1),
|
||||
name="Indexes"
|
||||
)
|
||||
|
||||
DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||
|
||||
|
||||
# based on https://www.tensorflow.org/tutorials/load_data/images
|
||||
def pathToLabel(path):
|
||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||
{{ if eq .Model.Format "png" }}
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
{{ else if eq .Model.Format "jpeg" }}
|
||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||
{{ else }}
|
||||
ERROR
|
||||
{{ end }}
|
||||
|
||||
num = table.lookup(tf.strings.as_string([path]))
|
||||
|
||||
return tf.cond(
|
||||
tf.math.equal(num, tf.constant(-1)),
|
||||
lambda: tf.zeros([depth]),
|
||||
lambda: tf.one_hot(table.lookup(tf.strings.as_string([path])) - diff, depth)[0]
|
||||
)
|
||||
|
||||
|
||||
old_model = keras.models.load_model("{{ .BaseModel }}")
|
||||
|
||||
|
||||
def decode_image(img):
|
||||
{{ if eq .Model.Format "png" }}
|
||||
img = tf.io.decode_png(img, channels={{.ColorMode}})
|
||||
{{ else if eq .Model.Format "jpeg" }}
|
||||
img = tf.io.decode_jpeg(img, channels={{.ColorMode}})
|
||||
{{ else }}
|
||||
ERROR
|
||||
{{ end }}
|
||||
|
||||
return tf.image.resize(img, image_size)
|
||||
|
||||
|
||||
def process_path(path):
|
||||
label = pathToLabel(path)
|
||||
|
||||
img = tf.io.read_file(path)
|
||||
img = decode_image(img)
|
||||
|
||||
return img, label
|
||||
|
||||
|
||||
def configure_for_performance(ds: tf.data.Dataset, size: int, shuffle: bool) -> tf.data.Dataset:
|
||||
# ds = ds.cache()
|
||||
if shuffle:
|
||||
ds = ds.shuffle(buffer_size=size)
|
||||
ds = ds.batch(batch_size)
|
||||
ds = ds.prefetch(AUTOTUNE)
|
||||
return ds
|
||||
|
||||
|
||||
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
|
||||
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
|
||||
ds = configure_for_performance(ds, size, False)
|
||||
return ds
|
||||
|
||||
|
||||
def filterDataset(path):
|
||||
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||
{{ if eq .Model.Format "png" }}
|
||||
path = tf.strings.regex_replace(path, ".png", "")
|
||||
{{ else if eq .Model.Format "jpeg" }}
|
||||
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||
{{ else }}
|
||||
ERROR
|
||||
{{ end }}
|
||||
|
||||
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
|
||||
|
||||
|
||||
seed = random.randint(0, 100000000)
|
||||
|
||||
batch_size = 64
|
||||
|
||||
# Read all the files from the direcotry
|
||||
list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
|
||||
list_ds = list_ds.filter(filterDataset)
|
||||
|
||||
image_count = len(list(list_ds.as_numpy_iterator()))
|
||||
|
||||
list_ds = list_ds.shuffle(image_count, seed=seed)
|
||||
|
||||
val_size = int(image_count * 0.3)
|
||||
|
||||
train_ds = list_ds.skip(val_size)
|
||||
val_ds = list_ds.take(val_size)
|
||||
|
||||
dataset = prepare_dataset(train_ds, image_count)
|
||||
dataset_validation = prepare_dataset(val_ds, val_size)
|
||||
|
||||
track = 0
|
||||
|
||||
|
||||
def addBlock(
|
||||
b_size: int,
|
||||
filter_size: int,
|
||||
kernel_size: int = 3,
|
||||
top: bool = True,
|
||||
pooling_same: bool = False,
|
||||
pool_func=layers.MaxPool2D
|
||||
):
|
||||
global track
|
||||
model = keras.Sequential(
|
||||
name=f"{track}-{b_size}-{filter_size}-{kernel_size}"
|
||||
)
|
||||
track += 1
|
||||
for _ in range(b_size):
|
||||
model.add(layers.Conv2D(
|
||||
filter_size,
|
||||
kernel_size,
|
||||
padding="same"
|
||||
))
|
||||
model.add(layers.ReLU())
|
||||
if top:
|
||||
if pooling_same:
|
||||
model.add(pool_func(padding="same", strides=(1, 1)))
|
||||
else:
|
||||
model.add(pool_func())
|
||||
model.add(layers.BatchNormalization())
|
||||
model.add(layers.LeakyReLU())
|
||||
model.add(layers.Dropout(0.4))
|
||||
return model
|
||||
|
||||
|
||||
# Proccess old data
|
||||
|
||||
new_data = old_model.predict(dataset)
|
||||
labels = np.concatenate([y for _, y in dataset], axis=0)
|
||||
new_data = tf.data.Dataset.from_tensor_slices(
|
||||
(new_data, labels))
|
||||
new_data = configure_for_performance(new_data, batch_size, True)
|
||||
|
||||
new_data_val = old_model.predict(dataset_validation)
|
||||
labels_val = np.concatenate([y for _, y in dataset_validation], axis=0)
|
||||
new_data_val = tf.data.Dataset.from_tensor_slices(
|
||||
(new_data_val, labels_val))
|
||||
new_data_val = configure_for_performance(new_data_val, batch_size, True)
|
||||
|
||||
{{ if .LoadPrev }}
|
||||
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
|
||||
{{ else }}
|
||||
model = keras.Sequential()
|
||||
|
||||
{{- range .Layers }}
|
||||
{{- if eq .LayerType 1}}
|
||||
model.add(layers.Rescaling(1./255))
|
||||
{{- else if eq .LayerType 2 }}
|
||||
model.add(layers.Dense({{ .Shape }}, activation="sigmoid"))
|
||||
{{- else if eq .LayerType 3}}
|
||||
model.add(layers.Flatten())
|
||||
{{- else if eq .LayerType 4}}
|
||||
model.add(addBlock(2, 128, 3, pool_func=layers.AveragePooling2D))
|
||||
{{- else }}
|
||||
ERROR
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
{{ end }}
|
||||
|
||||
model.layers[0]._name = "head"
|
||||
|
||||
model.compile(
|
||||
loss=losses.BinaryCrossentropy(from_logits=False),
|
||||
optimizer=tf.keras.optimizers.Adam(),
|
||||
metrics=['accuracy'])
|
||||
|
||||
his = model.fit(
|
||||
new_data,
|
||||
validation_data=new_data_val,
|
||||
epochs={{.EPOCH_PER_RUN}},
|
||||
callbacks=[
|
||||
NotifyServerCallback(),
|
||||
tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)
|
||||
],
|
||||
use_multiprocessing=True
|
||||
)
|
||||
|
||||
acc = his.history["accuracy"]
|
||||
|
||||
f = open("accuracy.val", "w")
|
||||
f.write(str(acc[-1]))
|
||||
f.close()
|
||||
|
||||
tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
|
||||
model.save("{{ .SaveModelPath }}/model.keras")
|
@ -17,7 +17,7 @@ split_len = {{ .SplitLen }}
|
||||
|
||||
bottom_input = Input(model.input_shape[1:])
|
||||
bottom_output = bottom_input
|
||||
top_input = Input(model.layers[split_len + 1].input_shape[1:])
|
||||
top_input = Input(model.layers[split_len + 1].input_shape[1:], name="head_input")
|
||||
top_output = top_input
|
||||
|
||||
for i, layer in enumerate(model.layers):
|
||||
|
@ -146,14 +146,7 @@
|
||||
<!-- TODO improve message -->
|
||||
<h2 class="text-center">Failed to prepare model</h2>
|
||||
|
||||
<div>TODO button delete</div>
|
||||
|
||||
<!--form hx-delete="/models/delete">
|
||||
<input type="hidden" name="id" value="{{ .Model.Id }}" />
|
||||
<button class="danger">
|
||||
Delete
|
||||
</button>
|
||||
</form-->
|
||||
<DeleteModel model={m} />
|
||||
</div>
|
||||
<!-- PRE TRAINING STATUS -->
|
||||
{:else if m.status == 2}
|
||||
@ -288,7 +281,7 @@
|
||||
{/await}
|
||||
<!-- TODO Add ability to stop training -->
|
||||
</div>
|
||||
{:else if [5, 6, -6].includes(m.status)}
|
||||
{:else if [5, 6, -6, 7, -7].includes(m.status)}
|
||||
<BaseModelInfo model={m} />
|
||||
<RunModel model={m} />
|
||||
{#if m.status == 6}
|
||||
@ -299,6 +292,13 @@
|
||||
{#if m.status == -6}
|
||||
<DeleteZip model={m} on:reload={getModel} expand />
|
||||
{/if}
|
||||
{#if m.status == -7}
|
||||
<form>
|
||||
<!-- TODO add more info about the failure -->
|
||||
Failed to train the model!
|
||||
Try to retrain
|
||||
</form>
|
||||
{/if}
|
||||
{#if m.model_type == 2}
|
||||
<ModelData model={m} on:reload={getModel} />
|
||||
{/if}
|
||||
|
@ -142,15 +142,17 @@
|
||||
</FileUpload>
|
||||
</fieldset>
|
||||
<MessageSimple bind:this={uploadImage} />
|
||||
{#await uploading}
|
||||
<button disabled>
|
||||
Uploading
|
||||
</button>
|
||||
{:then}
|
||||
<button>
|
||||
Add
|
||||
</button>
|
||||
{/await}
|
||||
{#if file}
|
||||
{#await uploading}
|
||||
<button disabled>
|
||||
Uploading
|
||||
</button>
|
||||
{:then}
|
||||
<button>
|
||||
Add
|
||||
</button>
|
||||
{/await}
|
||||
{/if}
|
||||
</form>
|
||||
</div>
|
||||
<div class="content" class:selected={isActive("create-class")}>
|
||||
@ -190,4 +192,6 @@
|
||||
{/if}
|
||||
</div>
|
||||
|
||||
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
|
||||
{#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)}
|
||||
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
|
||||
{/if}
|
||||
|
@ -106,6 +106,15 @@
|
||||
class:selected={isActive(item.name)}
|
||||
>
|
||||
{item.name}
|
||||
{#if model.model_type == 2}
|
||||
{#if item.status == 1}
|
||||
<span class="bi bi-book" style="color: orange;" />
|
||||
{:else if item.status == 2}
|
||||
<span class="bi bi-book" style="color: green;" />
|
||||
{:else if item.status == 3}
|
||||
<span class="bi bi-check" style="color: green;" />
|
||||
{/if}
|
||||
{/if}
|
||||
</button>
|
||||
{/each}
|
||||
</div>
|
||||
@ -170,15 +179,17 @@
|
||||
</FileUpload>
|
||||
</fieldset>
|
||||
<MessageSimple bind:this={uploadImage} />
|
||||
{#await uploading}
|
||||
<button disabled>
|
||||
Uploading
|
||||
</button>
|
||||
{:then}
|
||||
<button>
|
||||
Add
|
||||
</button>
|
||||
{/await}
|
||||
{#if file}
|
||||
{#await uploading}
|
||||
<button disabled>
|
||||
Uploading
|
||||
</button>
|
||||
{:then}
|
||||
<button>
|
||||
Add
|
||||
</button>
|
||||
{/await}
|
||||
{/if}
|
||||
</form>
|
||||
</div>
|
||||
{/if}
|
||||
|
@ -23,6 +23,7 @@
|
||||
let messages: MessageSimple;
|
||||
|
||||
async function submit() {
|
||||
messages.clear();
|
||||
submitted = true;
|
||||
try {
|
||||
await post('models/train', {
|
||||
@ -34,62 +35,97 @@
|
||||
if (e instanceof Response) {
|
||||
messages.display(await e.json());
|
||||
} else {
|
||||
messages.display("Could not start the training of the model");
|
||||
messages.display('Could not start the training of the model');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function submitRetrain() {
|
||||
messages.clear();
|
||||
submitted = true;
|
||||
try {
|
||||
await post('model/train/retrain', { id: model.id });
|
||||
dispatch('reload');
|
||||
} catch (e) {
|
||||
if (e instanceof Response) {
|
||||
messages.display(await e.json());
|
||||
} else {
|
||||
messages.display('Could not start the training of the model');
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<form class:submitted on:submit|preventDefault={submit}>
|
||||
{#if has_data}
|
||||
{#if number_of_invalid_images > 0}
|
||||
<p class="danger">
|
||||
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
|
||||
These images will be delete when the model trains.
|
||||
</p>
|
||||
{#if model.status == 2}
|
||||
<form class:submitted on:submit|preventDefault={submit}>
|
||||
{#if has_data}
|
||||
{#if number_of_invalid_images > 0}
|
||||
<p class="danger">
|
||||
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
|
||||
These images will be delete when the model trains.
|
||||
</p>
|
||||
{/if}
|
||||
<MessageSimple bind:this={messages} />
|
||||
<!-- TODO expading mode -->
|
||||
<fieldset>
|
||||
<legend> Model Type </legend>
|
||||
<div class="input-radial">
|
||||
<input
|
||||
id="model_type_simple"
|
||||
value="simple"
|
||||
name="model_type"
|
||||
type="radio"
|
||||
bind:group={data.model_type}
|
||||
/>
|
||||
<label for="model_type_simple">Simple</label><br />
|
||||
<input
|
||||
id="model_type_expandable"
|
||||
value="expandable"
|
||||
name="model_type"
|
||||
bind:group={data.model_type}
|
||||
type="radio"
|
||||
/>
|
||||
<label for="model_type_expandable">Expandable</label>
|
||||
</div>
|
||||
</fieldset>
|
||||
<!-- TODO allow more models to be created -->
|
||||
<fieldset>
|
||||
<label for="number_of_models">Number of Models</label>
|
||||
<input
|
||||
id="number_of_models"
|
||||
type="number"
|
||||
name="number_of_models"
|
||||
bind:value={data.number_of_models}
|
||||
/>
|
||||
</fieldset>
|
||||
<!-- TODO to Change the acc -->
|
||||
<fieldset>
|
||||
<label for="accuracy">Target accuracy</label>
|
||||
<input id="accuracy" type="number" name="accuracy" bind:value={data.accuracy} />
|
||||
</fieldset>
|
||||
<!-- TODO allow to chose the base of the model -->
|
||||
<!-- TODO allow to change the shape of the model -->
|
||||
<button> Train </button>
|
||||
{:else}
|
||||
<h2>To train the model please provide data to the model first</h2>
|
||||
{/if}
|
||||
<MessageSimple bind:this={messages} />
|
||||
<!-- TODO expading mode -->
|
||||
<fieldset>
|
||||
<legend> Model Type </legend>
|
||||
<div class="input-radial">
|
||||
<input
|
||||
id="model_type_simple"
|
||||
value="simple"
|
||||
name="model_type"
|
||||
type="radio"
|
||||
bind:group={data.model_type}
|
||||
/>
|
||||
<label for="model_type_simple">Simple</label><br />
|
||||
<input
|
||||
id="model_type_expandable"
|
||||
value="expandable"
|
||||
name="model_type"
|
||||
bind:group={data.model_type}
|
||||
type="radio"
|
||||
/>
|
||||
<label for="model_type_expandable">Expandable</label>
|
||||
</div>
|
||||
</fieldset>
|
||||
<!-- TODO allow more models to be created -->
|
||||
<fieldset>
|
||||
<label for="number_of_models">Number of Models</label>
|
||||
<input
|
||||
id="number_of_models"
|
||||
type="number"
|
||||
name="number_of_models"
|
||||
bind:value={data.number_of_models}
|
||||
/>
|
||||
</fieldset>
|
||||
<!-- TODO to Change the acc -->
|
||||
<fieldset>
|
||||
<label for="accuracy">Target accuracy</label>
|
||||
<input id="accuracy" type="number" name="accuracy" bind:value={data.accuracy} />
|
||||
</fieldset>
|
||||
<!-- TODO allow to chose the base of the model -->
|
||||
<!-- TODO allow to change the shape of the model -->
|
||||
<button> Train </button>
|
||||
{:else}
|
||||
<h2>To train the model please provide data to the model first</h2>
|
||||
{/if}
|
||||
</form>
|
||||
</form>
|
||||
{:else}
|
||||
<form class:submitted on:submit|preventDefault={submitRetrain}>
|
||||
{#if has_data}
|
||||
<h2>
|
||||
This model has new classes and can be expanded
|
||||
</h2>
|
||||
{#if number_of_invalid_images > 0}
|
||||
<p class="danger">
|
||||
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
|
||||
These images will be delete when the model trains.
|
||||
</p>
|
||||
{/if}
|
||||
<MessageSimple bind:this={messages} />
|
||||
<button> Retrain </button>
|
||||
{:else}
|
||||
<h2>To train the model please provide data to the model first</h2>
|
||||
{/if}
|
||||
</form>
|
||||
{/if}
|
||||
|
Loading…
Reference in New Issue
Block a user