added the ability to expand the models

This commit is contained in:
2024-04-08 14:17:13 +01:00
parent 274d7d22aa
commit de0b430467
15 changed files with 1086 additions and 197 deletions

View File

@@ -60,6 +60,7 @@ func HandleList(handle *Handle) {
max_len := min(11, len(rows))
c.ShowMessage = false;
return c.SendJSON(ReturnType{
ImageList: rows[0:max_len],
Page: page,

View File

@@ -156,7 +156,7 @@ func processZipFileExpand(c *Context, model *BaseModel) {
failed := func(msg string) {
c.Logger.Error(msg, "err", err)
ModelUpdateStatus(c, model.Id, READY_FAILED)
ModelUpdateStatus(c, model.Id, READY_ALTERATION_FAILED)
}
reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip"))
@@ -202,8 +202,19 @@ func processZipFileExpand(c *Context, model *BaseModel) {
ids := map[string]string{}
var baseOrder struct {
Order int `db:"class_order"`
}
err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id)
if err != nil {
failed("Failed to get the last class_order")
}
base := baseOrder.Order + 1
for i, name := range training {
id, err := model_classes.CreateClass(c.Db, model.Id, i, name)
id, err := model_classes.CreateClass(c.Db, model.Id, base + i, name)
if err != nil {
failed(fmt.Sprintf("Failed to create class '%s' on db\n", name))
return
@@ -416,7 +427,7 @@ func handleDataUpload(handle *Handle) {
delete_path := "base_data.zip"
if model.Status == READY_FAILED {
if model.Status == READY_ALTERATION_FAILED {
delete_path = "expand_data.zip"
} else if model.Status != FAILED_PREPARING_ZIP_FILE {
return c.JsonBadRequest("Model not in the correct status")
@@ -427,7 +438,7 @@ func handleDataUpload(handle *Handle) {
return c.Error500(err)
}
if model.Status != READY_FAILED {
if model.Status != READY_ALTERATION_FAILED {
err = os.RemoveAll(path.Join("savedData", model.Id, "data"))
if err != nil {
return c.Error500(err)
@@ -436,7 +447,7 @@ func handleDataUpload(handle *Handle) {
c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding")
}
if model.Status != READY_FAILED {
if model.Status != READY_ALTERATION_FAILED {
_, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id)
if err != nil {
return c.Error500(err)
@@ -448,7 +459,7 @@ func handleDataUpload(handle *Handle) {
}
}
if model.Status != READY_FAILED {
if model.Status != READY_ALTERATION_FAILED {
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
} else {
ModelUpdateStatus(c, model.Id, READY)

View File

@@ -29,7 +29,7 @@ func deleteModelJSON(c *Context, id string) *Error {
func handleDelete(handle *Handle) {
handle.Delete("/models/delete", func(c *Context) *Error {
if c.CheckAuthLevel(1) {
if !c.CheckAuthLevel(1) {
return nil
}
var dat struct {
@@ -66,6 +66,10 @@ func handleDelete(handle *Handle) {
case READY:
fallthrough
case READY_RETRAIN_FAILED:
fallthrough
case READY_ALTERATION_FAILED:
fallthrough
case CONFIRM_PRE_TRAINING:
if dat.Name == nil {
return c.JsonBadRequest("Provided name does not match the model name")

View File

@@ -41,6 +41,7 @@ func handleEdit(handle *Handle) {
NumberOfInvalidImages int `json:"number_of_invalid_images"`
}
c.ShowMessage = false;
return c.SendJSON(ReturnType{
Classes: cls,
HasData: has_data,
@@ -179,7 +180,8 @@ func handleEdit(handle *Handle) {
Layers: lay,
}
}
c.ShowMessage = false;
return c.SendJSON(defsToReturn)
})
@@ -188,10 +190,6 @@ func handleEdit(handle *Handle) {
return nil
}
if !c.CheckAuthLevel(1) {
return nil
}
id, err := GetIdFromUrl(c, "id")
if err != nil {
return c.JsonBadRequest("Model not found")
@@ -215,7 +213,8 @@ func handleEdit(handle *Handle) {
} else if err != nil {
return c.Error500(err)
}
c.ShowMessage = false
return c.SendJSON(model)
})
}

View File

@@ -87,6 +87,8 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
return
}
c.Logger.Info("test", "count", len(heads))
var vmax float32 = 0.0
for _, element := range heads {
@@ -95,12 +97,14 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
results := head_model.Exec([]tf.Output{
head_model.Op("StatefulPartitionedCall", 0),
}, map[tf.Output]*tf.Tensor{
head_model.Op("serving_default_input_2", 0): base_results[0],
head_model.Op("serving_default_head_input", 0): base_results[0],
})
var predictions = results[0].Value().([][]float32)[0]
for i, v := range predictions {
c.Logger.Info("predictions", "class", i, "preds", v)
if v > vmax {
order = element.Range_start + i
vmax = v
@@ -111,7 +115,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
// TODO runthe head model
confidence = vmax
c.Logger.Info("Got", "heads", len(heads))
c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", vmax)
return
}
@@ -155,7 +159,7 @@ func handleRun(handle *Handle) {
return c.Error500(err)
}
if model.Status != READY {
if model.Status != READY && model.Status != READY_RETRAIN && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model not ready to run images")
}

View File

@@ -36,13 +36,16 @@ func getDir() string {
return dir
}
// This function creates a new model_definition
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
id = ""
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return id, errors.New("Something wrong!")
}
@@ -72,17 +75,14 @@ func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id)
var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1;", model_id)
if err != nil {
return
}
defer classes.Close()
if !classes.Next() {
return
}
if err = classes.Scan(&count); err != nil {
return
}
count = co.Count
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
if err != nil {
@@ -121,19 +121,14 @@ func setModelClassStatus(c *Context, status ModelClassStatus, filter string, arg
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
defer classes.Close()
if !classes.Next() {
return
}
if err = classes.Scan(&count); err != nil {
return
}
count = co.Count
if count == 0 {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
@@ -214,7 +209,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
if err != nil {
return
}
defer removeAll(run_path, err)
classCount, err := generateCvs(c, run_path, model.Id)
if err != nil {
@@ -283,55 +277,125 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
func removeAll(path string, err error) {
if err != nil {
os.RemoveAll(path)
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
var co struct {
Count int `db:"count(*)"`
}
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
c.Logger.Info("test here", "count", co)
count_re = co.Count
count := co.Count
if count == 0 {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
if err != nil {
return
} else if doPanic {
return 0, errors.New("No model classes available")
}
return generateCvsExpandExp(c, run_path, model_id, offset, true)
}
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
return
}
defer data.Close()
f, err := os.Create(path.Join(run_path, "train.csv"))
if err != nil {
return
}
defer f.Close()
f.Write([]byte("Id,Index\n"))
count = 0
for data.Next() {
var id string
var class_order int
var file_path string
if err = data.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(class_order-offset) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
count += 1
}
//
// This is to load some extra data so that the model has more things to train on
//
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count)
if err != nil {
return
}
defer data_other.Close()
for data_other.Next() {
var id string
var class_order int
var file_path string
if err = data_other.Scan(&id, &class_order, &file_path); err != nil {
return
}
if file_path == "id://" {
f.Write([]byte(id + "," + strconv.Itoa(-1) + "\n"))
} else {
return count, errors.New("TODO generateCvs to file_path " + file_path)
}
}
return
}
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0
c.Logger.Warn("About to start training definition")
c.Logger.Warn("About to retrain model")
// Get untrained models heads
// Status = 2 (INIT)
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
type ExpHead struct {
Id string
Start int `db:"range_start"`
End int `db:"range_end"`
}
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil {
return
}
defer rows.Close()
type ExpHead struct {
id string
start int
end int
}
exp := ExpHead{}
if rows.Next() {
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
return
}
} else {
} else if len(heads) == 0 {
log.Error("Failed to get the exp head of the model")
err = errors.New("Failed to get the exp head of the model")
return
}
if rows.Next() {
} else if len(heads) != 1 {
log.Error("This training function can only train one model at the time")
err = errors.New("This training function can only train one model at the time")
return
}
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
exp := heads[0]
c.Logger.Info("Got exp head", "head", exp)
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return
}
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
@@ -348,8 +412,178 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got := []layerrow{}
remove_top_count := 1
i := 1
var last *layerrow = nil
got_2 := false
var first *layerrow = nil
for layers.Next() {
var row = layerrow{}
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
return
}
// Keep track of the first layer so we can keep the size of the image
if first == nil {
first = &row
}
row.LayerNum = i
row.Shape = shapeToSize(row.Shape)
if row.ExpType == 2 {
if !got_2 {
got = append(got, *last)
got_2 = true
}
got = append(got, row)
}
last = &row
i += 1
}
got = append(got, layerrow{
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2,
LayerNum: i,
})
c.Logger.Info("Got layers", "layers", got)
// Generate run folder
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
classCount, err := generateCvsExpandExp(c, run_path, model.Id, exp.Start, false)
if err != nil {
return
}
c.Logger.Info("Generated cvs", "classCount", classCount)
// TODO update the run script
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
if err != nil {
return
}
defer f.Close()
c.Logger.Info("About to run python!")
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
if err != nil {
return
}
// Copy result around
result_path := path.Join("savedData", model.Id, "defs", definition_id)
if err = tmpl.Execute(f, AnyMap{
"Layers": got,
"Size": first.Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.Id,
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN,
"LoadPrev": load_prev,
"BaseModel": path.Join(getDir(), result_path, "base", "model.keras"),
"LastModelRunPath": path.Join(getDir(), result_path, "head", exp.Id, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
"Depth": classCount,
"StartPoint": 0,
}); err != nil {
return
}
// Run the command
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
if err != nil {
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
return
}
c.Logger.Info("Python finished running")
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
return
}
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
if err != nil {
return
}
defer accuracy_file.Close()
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
if err != nil {
return
}
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
if err != nil {
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
accuracy = 0
c.Logger.Warn("About to start training definition")
// Get untrained models heads
type ExpHead struct {
Id string
Start int `db:"range_start"`
End int `db:"range_end"`
}
// status = 2 (INIT) 3 (TRAINING)
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
if err != nil {
return
} else if len(heads) == 0 {
log.Error("Failed to get the exp head of the model")
return
} else if len(heads) != 1 {
log.Error("This training function can only train one model at the time")
err = errors.New("This training function can only train one model at the time")
return
}
exp := heads[0]
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
return
}
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
if err != nil {
return
}
defer layers.Close()
type layerrow struct {
LayerType int
Shape string
ExpType int
LayerNum int
}
got := []layerrow{}
i := 1
for layers.Next() {
@@ -358,9 +592,6 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return
}
row.LayerNum = i
if row.ExpType == 2 {
remove_top_count += 1
}
row.Shape = shapeToSize(row.Shape)
got = append(got, row)
i += 1
@@ -368,19 +599,18 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
got = append(got, layerrow{
LayerType: LAYER_DENSE,
Shape: fmt.Sprintf("%d", exp.end-exp.start+1),
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
ExpType: 2,
LayerNum: i,
})
// Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
defer removeAll(run_path, err)
classCount, err := generateCvsExp(c, run_path, model.Id, false)
if err != nil {
@@ -408,16 +638,14 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
"Layers": got,
"Size": got[0].Shape,
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
"HeadId": exp.id,
"HeadId": exp.Id,
"RunPath": run_path,
"ColorMode": model.ImageMode,
"Model": model,
"EPOCH_PER_RUN": EPOCH_PER_RUN,
"DefId": definition_id,
"LoadPrev": load_prev,
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
"SaveModelPath": path.Join(getDir(), result_path),
"RemoveTopCount": remove_top_count,
"Depth": classCount,
"StartPoint": 0,
}); err != nil {
@@ -453,6 +681,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
return
}
os.RemoveAll(run_path)
c.Logger.Info("Model finished training!", "accuracy", accuracy)
return
}
@@ -762,6 +991,12 @@ func trainModelExp(c *Context, model *BaseModel) {
return
}
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2;", MODEL_HEAD_STATUS_READY, def.Id)
if err != nil {
failed("Failed to train definition!")
return
}
finished = true
break
}
@@ -823,17 +1058,17 @@ func trainModelExp(c *Context, model *BaseModel) {
}
}
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
var dat JustId
err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
if err == NotFoundError {
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("All definitions failed to train! And Failed to set class status")
return
}
failed("All definitions failed to train!")
return
} else if err != nil {
@@ -863,10 +1098,23 @@ func trainModelExp(c *Context, model *BaseModel) {
}
if err = splitModel(c, model); err != nil {
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to split the model! And Failed to set class status")
return
}
failed("Failed to split the model")
return
}
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
// There should only be one def availabale
def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
@@ -884,25 +1132,22 @@ func trainModelExp(c *Context, model *BaseModel) {
func splitModel(c *Context, model *BaseModel) (err error) {
def := JustId{}
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
return
}
head := JustId{}
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
return
}
// Generate run folder
run_path := path.Join("/tmp", model.Id, "defs", def.Id)
run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-split")
err = os.MkdirAll(run_path, os.ModePerm)
if err != nil {
return
}
defer removeAll(run_path, err)
// Create python script
f, err := os.Create(path.Join(run_path, "run.py"))
@@ -970,8 +1215,8 @@ func splitModel(c *Context, model *BaseModel) (err error) {
return
}
os.RemoveAll(run_path)
c.Logger.Info("Python finished running")
return
}
@@ -1141,6 +1386,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
}
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
if err != nil {
@@ -1306,6 +1552,244 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy
return nil
}
func ResetClasses(c *Context, model *BaseModel) {
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
if err != nil {
c.Logger.Error("Error while reseting the classes", "error", err)
}
}
func trainExpandable(c *Context, model *BaseModel) {
var err error = nil
failed := func(msg string) {
c.Logger.Error(msg, "err", err)
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
ResetClasses(c, model)
}
var definitions TrainModelRowUsables
definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
if err != nil {
failed("Failed to get definitions")
return
}
if len(definitions) != 1 {
failed("There should only be one definition available!")
return
}
firstRound := true
def := definitions[0]
epoch := 0
for {
acc, err := trainDefinitionExp(c, model, def.Id, !firstRound)
if err != nil {
failed("Failed to train definition!")
return
}
epoch += EPOCH_PER_RUN
if float64(acc*100) >= float64(def.Acuracy) {
c.Logger.Info("Found a definition that reaches target_accuracy!")
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
if err != nil {
failed("Failed to train definition!")
return
}
break
} else if def.Epoch > MAX_EPOCH {
failed(fmt.Sprintf("Failed to train definition! Accuracy less %f < %d\n", acc*100, def.TargetAccuracy))
return
}
}
// Set the class status to trained
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
if err != nil {
failed("Failed to set class status")
return
}
ModelUpdateStatus(c, model.Id, READY)
}
func trainRetrain(c *Context, model *BaseModel, defId string) {
var err error
failed := func() {
ResetClasses(c, model)
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
return
}
// This is something I have to check
acc, err := trainDefinitionExpandExp(c, model, defId, false)
if err != nil {
c.Logger.Error("Failed to retrain the model", "err", err)
failed()
return
}
c.Logger.Info("Retrained model", "accuracy", acc)
// TODO check accuracy
err = UpdateStatus(c, "models", model.Id, READY)
if err != nil {
failed()
return
}
c.Logger.Info("model updaded")
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
if err != nil {
c.Logger.Error("Error while updating the classes", "error", err)
failed()
return
}
}
func handleRetrain(c *Context) *Error {
var err error = nil
if !c.CheckAuthLevel(1) {
return nil
}
var dat JustId
if err_ := c.ToJSON(&dat); err_ != nil {
return err_
}
if dat.Id == "" {
return c.JsonBadRequest("Please provide a id")
}
model, err := GetBaseModel(c.Db, dat.Id)
if err == ModelNotFoundError {
return c.JsonBadRequest("Model not found")
} else if err != nil {
return c.Error500(err)
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
return c.JsonBadRequest("Model in invalid status for re-training")
}
c.Logger.Info("Expanding definitions for models", "id", model.Id)
classesUpdated := false
failed := func() *Error {
if classesUpdated {
ResetClasses(c, model)
}
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
c.Logger.Error("Failed to retrain", "err", err)
// TODO improve this response
return c.Error500(err)
}
var def struct {
Id string
TargetAccuracy int `db:"target_accuracy"`
}
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
if err != nil {
return failed()
}
type C struct {
Id string
ClassOrder int `db:"class_order"`
}
err = c.StartTx()
if err != nil {
return failed()
}
classes, err := GetDbMultitple[C](
c,
"model_classes where model_id=$1 and status=$2 order by class_order asc",
model.Id,
MODEL_CLASS_STATUS_TO_TRAIN,
)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
if len(classes) == 0 {
c.Logger.Error("No classes are available!")
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
err = c.CommitTx()
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
classesUpdated = true
}
_, err = CreateExpModelHead(c, def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT)
if err != nil {
return failed()
}
go trainRetrain(c, model, def.Id)
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
if err != nil {
fmt.Println("Failed to update model status")
fmt.Println(err)
// TODO improve this response
return c.Error500(err)
}
return c.SendJSON(model.Id)
}
func handleTrain(handle *Handle) {
handle.Post("/models/train", func(c *Context) *Error {
if !c.CheckAuthLevel(1) {
@@ -1374,6 +1858,8 @@ func handleTrain(handle *Handle) {
return c.SendJSON(model.Id)
})
handle.Post("/model/train/retrain", handleRetrain)
handle.Get("/model/epoch/update", func(c *Context) *Error {
f := c.R.URL.Query()

View File

@@ -23,13 +23,16 @@ const (
FAILED_PREPARING_ZIP_FILE = -2
FAILED_PREPARING = -1
PREPARING = 1
CONFIRM_PRE_TRAINING = 2
PREPARING_ZIP_FILE = 3
TRAINING = 4
READY = 5
READY_ALTERATION = 6
READY_FAILED = -6
PREPARING = 1
CONFIRM_PRE_TRAINING = 2
PREPARING_ZIP_FILE = 3
TRAINING = 4
READY = 5
READY_ALTERATION = 6
READY_ALTERATION_FAILED = -6
READY_RETRAIN = 7
READY_RETRAIN_FAILED = -7
)
type ModelDefinitionStatus int
@@ -62,6 +65,16 @@ const (
MODEL_CLASS_STATUS_TRAINED = 3
)
type ModelHeadStatus int
const (
MODEL_HEAD_STATUS_PRE_INIT ModelHeadStatus = 1
MODEL_HEAD_STATUS_INIT = 2
MODEL_HEAD_STATUS_TRAINING = 3
MODEL_HEAD_STATUS_TRAINED = 4
MODEL_HEAD_STATUS_READY = 5
)
var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {

View File

@@ -89,6 +89,7 @@ func (x *Handle) handleGets(context *Context) {
return
}
}
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context)
}
@@ -99,6 +100,7 @@ func (x *Handle) handlePosts(context *Context) {
return
}
}
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context)
}
@@ -109,6 +111,7 @@ func (x *Handle) handleDeletes(context *Context) {
return
}
}
context.ShowMessage = false
handleError(&Error{404, "Endpoint not found"}, context)
}
@@ -127,12 +130,58 @@ func (c *Context) CheckAuthLevel(authLevel int) bool {
}
type Context struct {
Token *string
User *dbtypes.User
Logger *log.Logger
Db *sql.DB
Writer http.ResponseWriter
R *http.Request
Token *string
User *dbtypes.User
Logger *log.Logger
Db *sql.DB
Writer http.ResponseWriter
R *http.Request
Tx *sql.Tx
ShowMessage bool
}
func (c Context) Prepare(str string) (*sql.Stmt, error) {
if c.Tx == nil {
return c.Db.Prepare(str)
}
return c.Tx.Prepare(str)
}
var TransactionAlreadyStarted = errors.New("Transaction already started")
var TransactionNotStarted = errors.New("Transaction not started")
func (c *Context) StartTx() error {
if c.Tx != nil {
return TransactionAlreadyStarted
}
var err error = nil
c.Tx, err = c.Db.Begin()
return err
}
func (c *Context) CommitTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Commit()
if err != nil {
return err
}
c.Tx = nil
return nil
}
func (c *Context) RollbackTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Rollback()
if err != nil {
return err
}
c.Tx = nil
return nil
}
func (c Context) ToJSON(dat any) *Error {
@@ -200,14 +249,14 @@ func (c *Context) GetModelFromId(id_path string) (*BaseModel, *Error) {
return nil, c.Error500(err)
}
return model, nil
return model, nil
}
func ModelUpdateStatus(c *Context, id string, status int) {
_, err := c.Db.Exec("update models set status=$1 where id=$2;", status, id)
if err != nil {
c.Logger.Error("Failed to update model status", "err", err)
c.Logger.Warn("TODO Maybe handle better")
c.Logger.Error("Failed to update model status", "err", err)
c.Logger.Warn("TODO Maybe handle better")
}
}
@@ -270,6 +319,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
Db: handler.Db,
Writer: w,
R: r,
ShowMessage: true,
}, nil
}
@@ -278,7 +328,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
return nil, errors.Join(err, LogoffError)
}
return &Context{token, user, logger, handler.Db, w, r}, nil
return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil
}
func contextlessLogoff(w http.ResponseWriter) {
@@ -457,20 +507,19 @@ func NewHandler(db *sql.DB) *Handle {
if r.Method == "GET" {
x.handleGets(context)
return
}
if r.Method == "POST" {
} else if r.Method == "POST" {
x.handlePosts(context)
return
}
if r.Method == "DELETE" {
} else if r.Method == "DELETE" {
x.handleDeletes(context)
return
}
if r.Method == "OPTIONS" {
return
}
panic("TODO handle method: " + r.Method)
} else if r.Method == "OPTIONS" {
// do nothing
} else {
panic("TODO handle method: " + r.Method)
}
if context.ShowMessage {
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
}
})
return x

View File

@@ -189,6 +189,7 @@ type JustId struct { Id string }
type Generic struct{ reflect.Type }
var NotFoundError = errors.New("Not found")
var CouldNotInsert = errors.New("Could not insert")
func generateQuery(t reflect.Type) (query string, nargs int) {
nargs = t.NumField()
@@ -200,6 +201,10 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
if !ok {
name = field.Name;
}
if name == "__nil__" {
continue
}
query += strings.ToLower(name) + ","
}
@@ -214,7 +219,13 @@ func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([
query, nargs := generateQuery(t)
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
if err != nil {
return nil, err
}
@@ -251,6 +262,43 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
return nil
}
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
t := reflect.TypeOf(store).Elem()
query, nargs := generateQuery(t)
query2 := ""
for i := 0; i < nargs; i += 1 {
query2 += fmt.Sprintf("$%d,", i)
}
// Remove last quotation
query2 = query2[0 : len(query2)-1]
val := reflect.ValueOf(store).Elem()
scan_args := make([]interface{}, nargs);
for i := 0; i < nargs; i++ {
valueField := val.Field(i)
scan_args[i] = valueField.Interface()
}
rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...)
if err != nil {
return
}
defer rows.Close()
if !rows.Next() {
return "", CouldNotInsert
}
err = rows.Scan(&id)
if err != nil {
return
}
return
}
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem()