added the ability to expand the models
This commit is contained in:
parent
274d7d22aa
commit
de0b430467
@ -60,6 +60,7 @@ func HandleList(handle *Handle) {
|
|||||||
|
|
||||||
max_len := min(11, len(rows))
|
max_len := min(11, len(rows))
|
||||||
|
|
||||||
|
c.ShowMessage = false;
|
||||||
return c.SendJSON(ReturnType{
|
return c.SendJSON(ReturnType{
|
||||||
ImageList: rows[0:max_len],
|
ImageList: rows[0:max_len],
|
||||||
Page: page,
|
Page: page,
|
||||||
|
@ -156,7 +156,7 @@ func processZipFileExpand(c *Context, model *BaseModel) {
|
|||||||
|
|
||||||
failed := func(msg string) {
|
failed := func(msg string) {
|
||||||
c.Logger.Error(msg, "err", err)
|
c.Logger.Error(msg, "err", err)
|
||||||
ModelUpdateStatus(c, model.Id, READY_FAILED)
|
ModelUpdateStatus(c, model.Id, READY_ALTERATION_FAILED)
|
||||||
}
|
}
|
||||||
|
|
||||||
reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip"))
|
reader, err := zip.OpenReader(path.Join("savedData", model.Id, "expand_data.zip"))
|
||||||
@ -202,8 +202,19 @@ func processZipFileExpand(c *Context, model *BaseModel) {
|
|||||||
|
|
||||||
ids := map[string]string{}
|
ids := map[string]string{}
|
||||||
|
|
||||||
|
var baseOrder struct {
|
||||||
|
Order int `db:"class_order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = GetDBOnce(c, &baseOrder, "model_classes where model_id=$1 order by class_order desc;", model.Id)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to get the last class_order")
|
||||||
|
}
|
||||||
|
|
||||||
|
base := baseOrder.Order + 1
|
||||||
|
|
||||||
for i, name := range training {
|
for i, name := range training {
|
||||||
id, err := model_classes.CreateClass(c.Db, model.Id, i, name)
|
id, err := model_classes.CreateClass(c.Db, model.Id, base + i, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
failed(fmt.Sprintf("Failed to create class '%s' on db\n", name))
|
failed(fmt.Sprintf("Failed to create class '%s' on db\n", name))
|
||||||
return
|
return
|
||||||
@ -416,7 +427,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
|
|
||||||
delete_path := "base_data.zip"
|
delete_path := "base_data.zip"
|
||||||
|
|
||||||
if model.Status == READY_FAILED {
|
if model.Status == READY_ALTERATION_FAILED {
|
||||||
delete_path = "expand_data.zip"
|
delete_path = "expand_data.zip"
|
||||||
} else if model.Status != FAILED_PREPARING_ZIP_FILE {
|
} else if model.Status != FAILED_PREPARING_ZIP_FILE {
|
||||||
return c.JsonBadRequest("Model not in the correct status")
|
return c.JsonBadRequest("Model not in the correct status")
|
||||||
@ -427,7 +438,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.Status != READY_FAILED {
|
if model.Status != READY_ALTERATION_FAILED {
|
||||||
err = os.RemoveAll(path.Join("savedData", model.Id, "data"))
|
err = os.RemoveAll(path.Join("savedData", model.Id, "data"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
@ -436,7 +447,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding")
|
c.Logger.Warn("Handle failed to remove the savedData when deleteing the zip file while expanding")
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.Status != READY_FAILED {
|
if model.Status != READY_ALTERATION_FAILED {
|
||||||
_, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id)
|
_, err = handle.Db.Exec("delete from model_classes where model_id=$1;", model.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
@ -448,7 +459,7 @@ func handleDataUpload(handle *Handle) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.Status != READY_FAILED {
|
if model.Status != READY_ALTERATION_FAILED {
|
||||||
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
|
ModelUpdateStatus(c, model.Id, CONFIRM_PRE_TRAINING)
|
||||||
} else {
|
} else {
|
||||||
ModelUpdateStatus(c, model.Id, READY)
|
ModelUpdateStatus(c, model.Id, READY)
|
||||||
|
@ -29,7 +29,7 @@ func deleteModelJSON(c *Context, id string) *Error {
|
|||||||
|
|
||||||
func handleDelete(handle *Handle) {
|
func handleDelete(handle *Handle) {
|
||||||
handle.Delete("/models/delete", func(c *Context) *Error {
|
handle.Delete("/models/delete", func(c *Context) *Error {
|
||||||
if c.CheckAuthLevel(1) {
|
if !c.CheckAuthLevel(1) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var dat struct {
|
var dat struct {
|
||||||
@ -66,6 +66,10 @@ func handleDelete(handle *Handle) {
|
|||||||
|
|
||||||
case READY:
|
case READY:
|
||||||
fallthrough
|
fallthrough
|
||||||
|
case READY_RETRAIN_FAILED:
|
||||||
|
fallthrough
|
||||||
|
case READY_ALTERATION_FAILED:
|
||||||
|
fallthrough
|
||||||
case CONFIRM_PRE_TRAINING:
|
case CONFIRM_PRE_TRAINING:
|
||||||
if dat.Name == nil {
|
if dat.Name == nil {
|
||||||
return c.JsonBadRequest("Provided name does not match the model name")
|
return c.JsonBadRequest("Provided name does not match the model name")
|
||||||
|
@ -41,6 +41,7 @@ func handleEdit(handle *Handle) {
|
|||||||
NumberOfInvalidImages int `json:"number_of_invalid_images"`
|
NumberOfInvalidImages int `json:"number_of_invalid_images"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ShowMessage = false;
|
||||||
return c.SendJSON(ReturnType{
|
return c.SendJSON(ReturnType{
|
||||||
Classes: cls,
|
Classes: cls,
|
||||||
HasData: has_data,
|
HasData: has_data,
|
||||||
@ -180,6 +181,7 @@ func handleEdit(handle *Handle) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ShowMessage = false;
|
||||||
return c.SendJSON(defsToReturn)
|
return c.SendJSON(defsToReturn)
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -188,10 +190,6 @@ func handleEdit(handle *Handle) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if !c.CheckAuthLevel(1) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := GetIdFromUrl(c, "id")
|
id, err := GetIdFromUrl(c, "id")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return c.JsonBadRequest("Model not found")
|
return c.JsonBadRequest("Model not found")
|
||||||
@ -216,6 +214,7 @@ func handleEdit(handle *Handle) {
|
|||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ShowMessage = false
|
||||||
return c.SendJSON(model)
|
return c.SendJSON(model)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,8 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("test", "count", len(heads))
|
||||||
|
|
||||||
var vmax float32 = 0.0
|
var vmax float32 = 0.0
|
||||||
|
|
||||||
for _, element := range heads {
|
for _, element := range heads {
|
||||||
@ -95,12 +97,14 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
|||||||
results := head_model.Exec([]tf.Output{
|
results := head_model.Exec([]tf.Output{
|
||||||
head_model.Op("StatefulPartitionedCall", 0),
|
head_model.Op("StatefulPartitionedCall", 0),
|
||||||
}, map[tf.Output]*tf.Tensor{
|
}, map[tf.Output]*tf.Tensor{
|
||||||
head_model.Op("serving_default_input_2", 0): base_results[0],
|
head_model.Op("serving_default_head_input", 0): base_results[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
var predictions = results[0].Value().([][]float32)[0]
|
var predictions = results[0].Value().([][]float32)[0]
|
||||||
|
|
||||||
|
|
||||||
for i, v := range predictions {
|
for i, v := range predictions {
|
||||||
|
c.Logger.Info("predictions", "class", i, "preds", v)
|
||||||
if v > vmax {
|
if v > vmax {
|
||||||
order = element.Range_start + i
|
order = element.Range_start + i
|
||||||
vmax = v
|
vmax = v
|
||||||
@ -111,7 +115,7 @@ func runModelExp(c *Context, model *BaseModel, def_id string, inputImage *tf.Ten
|
|||||||
// TODO runthe head model
|
// TODO runthe head model
|
||||||
confidence = vmax
|
confidence = vmax
|
||||||
|
|
||||||
c.Logger.Info("Got", "heads", len(heads))
|
c.Logger.Info("Got", "heads", len(heads), "order", order, "vmax", vmax)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -155,7 +159,7 @@ func handleRun(handle *Handle) {
|
|||||||
return c.Error500(err)
|
return c.Error500(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if model.Status != READY {
|
if model.Status != READY && model.Status != READY_RETRAIN && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION && model.Status != READY_ALTERATION_FAILED {
|
||||||
return c.JsonBadRequest("Model not ready to run images")
|
return c.JsonBadRequest("Model not ready to run images")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,13 +36,16 @@ func getDir() string {
|
|||||||
return dir
|
return dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function creates a new model_definition
|
||||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||||
id = ""
|
id = ""
|
||||||
|
|
||||||
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
rows, err := db.Query("insert into model_definition (model_id, target_accuracy) values ($1, $2) returning id;", model_id, target_accuracy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return id, errors.New("Something wrong!")
|
return id, errors.New("Something wrong!")
|
||||||
}
|
}
|
||||||
@ -72,17 +75,14 @@ func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type
|
|||||||
|
|
||||||
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
|
func generateCvs(c *Context, run_path string, model_id string) (count int, err error) {
|
||||||
|
|
||||||
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1;", model_id)
|
var co struct {
|
||||||
|
Count int `db:"count(*)"`
|
||||||
|
}
|
||||||
|
err = GetDBOnce(c, &co, "model_classes where model_id=$1;", model_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer classes.Close()
|
count = co.Count
|
||||||
if !classes.Next() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err = classes.Scan(&count); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
|
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2;", model_id, model_classes.DATA_POINT_MODE_TRAINING)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -121,19 +121,14 @@ func setModelClassStatus(c *Context, status ModelClassStatus, filter string, arg
|
|||||||
|
|
||||||
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
|
func generateCvsExp(c *Context, run_path string, model_id string, doPanic bool) (count int, err error) {
|
||||||
|
|
||||||
classes, err := c.Db.Query("select count(*) from model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
var co struct {
|
||||||
|
Count int `db:"count(*)"`
|
||||||
|
}
|
||||||
|
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer classes.Close()
|
count = co.Count
|
||||||
|
|
||||||
if !classes.Next() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = classes.Scan(&count); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
|
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
|
||||||
@ -214,7 +209,6 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer removeAll(run_path, err)
|
|
||||||
|
|
||||||
classCount, err := generateCvs(c, run_path, model.Id)
|
classCount, err := generateCvs(c, run_path, model.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -283,55 +277,125 @@ func trainDefinition(c *Context, model *BaseModel, definition_id string, load_pr
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
os.RemoveAll(run_path)
|
||||||
|
|
||||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeAll(path string, err error) {
|
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
|
||||||
if err != nil {
|
|
||||||
os.RemoveAll(path)
|
var co struct {
|
||||||
|
Count int `db:"count(*)"`
|
||||||
}
|
}
|
||||||
|
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Logger.Info("test here", "count", co)
|
||||||
|
count_re = co.Count
|
||||||
|
count := co.Count
|
||||||
|
|
||||||
|
if count == 0 {
|
||||||
|
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINING, "model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TO_TRAIN)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
} else if doPanic {
|
||||||
|
return 0, errors.New("No model classes available")
|
||||||
|
}
|
||||||
|
return generateCvsExpandExp(c, run_path, model_id, offset, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer data.Close()
|
||||||
|
|
||||||
|
f, err := os.Create(path.Join(run_path, "train.csv"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
f.Write([]byte("Id,Index\n"))
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for data.Next() {
|
||||||
|
var id string
|
||||||
|
var class_order int
|
||||||
|
var file_path string
|
||||||
|
if err = data.Scan(&id, &class_order, &file_path); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if file_path == "id://" {
|
||||||
|
f.Write([]byte(id + "," + strconv.Itoa(class_order-offset) + "\n"))
|
||||||
|
} else {
|
||||||
|
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||||
|
}
|
||||||
|
count += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// This is to load some extra data so that the model has more things to train on
|
||||||
|
//
|
||||||
|
|
||||||
|
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer data_other.Close()
|
||||||
|
|
||||||
|
for data_other.Next() {
|
||||||
|
var id string
|
||||||
|
var class_order int
|
||||||
|
var file_path string
|
||||||
|
if err = data_other.Scan(&id, &class_order, &file_path); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if file_path == "id://" {
|
||||||
|
f.Write([]byte(id + "," + strconv.Itoa(-1) + "\n"))
|
||||||
|
} else {
|
||||||
|
return count, errors.New("TODO generateCvs to file_path " + file_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
|
|
||||||
c.Logger.Warn("About to start training definition")
|
c.Logger.Warn("About to retrain model")
|
||||||
|
|
||||||
// Get untrained models heads
|
// Get untrained models heads
|
||||||
|
|
||||||
// Status = 2 (INIT)
|
type ExpHead struct {
|
||||||
rows, err := c.Db.Query("select id, range_start, range_end from exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
Id string
|
||||||
|
Start int `db:"range_start"`
|
||||||
|
End int `db:"range_end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// status = 2 (INIT) 3 (TRAINING)
|
||||||
|
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
} else if len(heads) == 0 {
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
type ExpHead struct {
|
|
||||||
id string
|
|
||||||
start int
|
|
||||||
end int
|
|
||||||
}
|
|
||||||
|
|
||||||
exp := ExpHead{}
|
|
||||||
|
|
||||||
if rows.Next() {
|
|
||||||
if err = rows.Scan(&exp.id, &exp.start, &exp.end); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Error("Failed to get the exp head of the model")
|
log.Error("Failed to get the exp head of the model")
|
||||||
err = errors.New("Failed to get the exp head of the model")
|
|
||||||
return
|
return
|
||||||
}
|
} else if len(heads) != 1 {
|
||||||
|
|
||||||
if rows.Next() {
|
|
||||||
log.Error("This training function can only train one model at the time")
|
log.Error("This training function can only train one model at the time")
|
||||||
err = errors.New("This training function can only train one model at the time")
|
err = errors.New("This training function can only train one model at the time")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
UpdateStatus(c, "exp_model_head", exp.id, MODEL_DEFINITION_STATUS_TRAINING)
|
exp := heads[0]
|
||||||
|
|
||||||
|
c.Logger.Info("Got exp head", "head", exp)
|
||||||
|
|
||||||
|
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -348,8 +412,178 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
|
|
||||||
got := []layerrow{}
|
got := []layerrow{}
|
||||||
|
|
||||||
remove_top_count := 1
|
i := 1
|
||||||
|
var last *layerrow = nil
|
||||||
|
got_2 := false
|
||||||
|
|
||||||
|
var first *layerrow = nil
|
||||||
|
|
||||||
|
for layers.Next() {
|
||||||
|
var row = layerrow{}
|
||||||
|
if err = layers.Scan(&row.LayerType, &row.Shape, &row.ExpType); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep track of the first layer so we can keep the size of the image
|
||||||
|
if first == nil {
|
||||||
|
first = &row
|
||||||
|
}
|
||||||
|
|
||||||
|
row.LayerNum = i
|
||||||
|
row.Shape = shapeToSize(row.Shape)
|
||||||
|
if row.ExpType == 2 {
|
||||||
|
if !got_2 {
|
||||||
|
got = append(got, *last)
|
||||||
|
got_2 = true
|
||||||
|
}
|
||||||
|
got = append(got, row)
|
||||||
|
}
|
||||||
|
last = &row
|
||||||
|
i += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
got = append(got, layerrow{
|
||||||
|
LayerType: LAYER_DENSE,
|
||||||
|
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
|
||||||
|
ExpType: 2,
|
||||||
|
LayerNum: i,
|
||||||
|
})
|
||||||
|
|
||||||
|
c.Logger.Info("Got layers", "layers", got)
|
||||||
|
|
||||||
|
// Generate run folder
|
||||||
|
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
|
||||||
|
|
||||||
|
err = os.MkdirAll(run_path, os.ModePerm)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
classCount, err := generateCvsExpandExp(c, run_path, model.Id, exp.Start, false)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("Generated cvs", "classCount", classCount)
|
||||||
|
|
||||||
|
// TODO update the run script
|
||||||
|
|
||||||
|
// Create python script
|
||||||
|
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
c.Logger.Info("About to run python!")
|
||||||
|
|
||||||
|
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy result around
|
||||||
|
result_path := path.Join("savedData", model.Id, "defs", definition_id)
|
||||||
|
|
||||||
|
if err = tmpl.Execute(f, AnyMap{
|
||||||
|
"Layers": got,
|
||||||
|
"Size": first.Shape,
|
||||||
|
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||||
|
"HeadId": exp.Id,
|
||||||
|
"RunPath": run_path,
|
||||||
|
"ColorMode": model.ImageMode,
|
||||||
|
"Model": model,
|
||||||
|
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
||||||
|
"LoadPrev": load_prev,
|
||||||
|
"BaseModel": path.Join(getDir(), result_path, "base", "model.keras"),
|
||||||
|
"LastModelRunPath": path.Join(getDir(), result_path, "head", exp.Id, "model.keras"),
|
||||||
|
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
|
||||||
|
"Depth": classCount,
|
||||||
|
"StartPoint": 0,
|
||||||
|
}); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the command
|
||||||
|
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("Python finished running")
|
||||||
|
|
||||||
|
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accuracy_file, err := os.Open(path.Join(run_path, "accuracy.val"))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer accuracy_file.Close()
|
||||||
|
|
||||||
|
accuracy_file_bytes, err := io.ReadAll(accuracy_file)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
accuracy, err = strconv.ParseFloat(string(accuracy_file_bytes), 64)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
os.RemoveAll(run_path)
|
||||||
|
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||||
|
accuracy = 0
|
||||||
|
|
||||||
|
c.Logger.Warn("About to start training definition")
|
||||||
|
|
||||||
|
// Get untrained models heads
|
||||||
|
|
||||||
|
type ExpHead struct {
|
||||||
|
Id string
|
||||||
|
Start int `db:"range_start"`
|
||||||
|
End int `db:"range_end"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// status = 2 (INIT) 3 (TRAINING)
|
||||||
|
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
} else if len(heads) == 0 {
|
||||||
|
log.Error("Failed to get the exp head of the model")
|
||||||
|
return
|
||||||
|
} else if len(heads) != 1 {
|
||||||
|
log.Error("This training function can only train one model at the time")
|
||||||
|
err = errors.New("This training function can only train one model at the time")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exp := heads[0]
|
||||||
|
|
||||||
|
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer layers.Close()
|
||||||
|
|
||||||
|
type layerrow struct {
|
||||||
|
LayerType int
|
||||||
|
Shape string
|
||||||
|
ExpType int
|
||||||
|
LayerNum int
|
||||||
|
}
|
||||||
|
|
||||||
|
got := []layerrow{}
|
||||||
i := 1
|
i := 1
|
||||||
|
|
||||||
for layers.Next() {
|
for layers.Next() {
|
||||||
@ -358,9 +592,6 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
row.LayerNum = i
|
row.LayerNum = i
|
||||||
if row.ExpType == 2 {
|
|
||||||
remove_top_count += 1
|
|
||||||
}
|
|
||||||
row.Shape = shapeToSize(row.Shape)
|
row.Shape = shapeToSize(row.Shape)
|
||||||
got = append(got, row)
|
got = append(got, row)
|
||||||
i += 1
|
i += 1
|
||||||
@ -368,19 +599,18 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
|
|
||||||
got = append(got, layerrow{
|
got = append(got, layerrow{
|
||||||
LayerType: LAYER_DENSE,
|
LayerType: LAYER_DENSE,
|
||||||
Shape: fmt.Sprintf("%d", exp.end-exp.start+1),
|
Shape: fmt.Sprintf("%d", exp.End-exp.Start+1),
|
||||||
ExpType: 2,
|
ExpType: 2,
|
||||||
LayerNum: i,
|
LayerNum: i,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Generate run folder
|
// Generate run folder
|
||||||
run_path := path.Join("/tmp", model.Id, "defs", definition_id)
|
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id)
|
||||||
|
|
||||||
err = os.MkdirAll(run_path, os.ModePerm)
|
err = os.MkdirAll(run_path, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer removeAll(run_path, err)
|
|
||||||
|
|
||||||
classCount, err := generateCvsExp(c, run_path, model.Id, false)
|
classCount, err := generateCvsExp(c, run_path, model.Id, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -408,16 +638,14 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
"Layers": got,
|
"Layers": got,
|
||||||
"Size": got[0].Shape,
|
"Size": got[0].Shape,
|
||||||
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
"DataDir": path.Join(getDir(), "savedData", model.Id, "data"),
|
||||||
"HeadId": exp.id,
|
"HeadId": exp.Id,
|
||||||
"RunPath": run_path,
|
"RunPath": run_path,
|
||||||
"ColorMode": model.ImageMode,
|
"ColorMode": model.ImageMode,
|
||||||
"Model": model,
|
"Model": model,
|
||||||
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
"EPOCH_PER_RUN": EPOCH_PER_RUN,
|
||||||
"DefId": definition_id,
|
|
||||||
"LoadPrev": load_prev,
|
"LoadPrev": load_prev,
|
||||||
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
"LastModelRunPath": path.Join(getDir(), result_path, "model.keras"),
|
||||||
"SaveModelPath": path.Join(getDir(), result_path),
|
"SaveModelPath": path.Join(getDir(), result_path),
|
||||||
"RemoveTopCount": remove_top_count,
|
|
||||||
"Depth": classCount,
|
"Depth": classCount,
|
||||||
"StartPoint": 0,
|
"StartPoint": 0,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@ -453,6 +681,7 @@ func trainDefinitionExp(c *Context, model *BaseModel, definition_id string, load
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
os.RemoveAll(run_path)
|
||||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -762,6 +991,12 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2;", MODEL_HEAD_STATUS_READY, def.Id)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to train definition!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
finished = true
|
finished = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -823,17 +1058,17 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the class status to trained
|
|
||||||
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
|
||||||
if err != nil {
|
|
||||||
failed("Failed to set class status")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var dat JustId
|
var dat JustId
|
||||||
|
|
||||||
err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
|
err = GetDBOnce(c, &dat, "model_definition where model_id=$1 and status=$2 order by accuracy desc limit 1;", model.Id, MODEL_DEFINITION_STATUS_TRANIED)
|
||||||
if err == NotFoundError {
|
if err == NotFoundError {
|
||||||
|
// Set the class status to trained
|
||||||
|
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
failed("All definitions failed to train! And Failed to set class status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
failed("All definitions failed to train!")
|
failed("All definitions failed to train!")
|
||||||
return
|
return
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@ -863,10 +1098,23 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err = splitModel(c, model); err != nil {
|
if err = splitModel(c, model); err != nil {
|
||||||
|
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to split the model! And Failed to set class status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
failed("Failed to split the model")
|
failed("Failed to split the model")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the class status to trained
|
||||||
|
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to set class status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// There should only be one def availabale
|
// There should only be one def availabale
|
||||||
def := JustId{}
|
def := JustId{}
|
||||||
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
||||||
@ -884,25 +1132,22 @@ func trainModelExp(c *Context, model *BaseModel) {
|
|||||||
func splitModel(c *Context, model *BaseModel) (err error) {
|
func splitModel(c *Context, model *BaseModel) (err error) {
|
||||||
|
|
||||||
def := JustId{}
|
def := JustId{}
|
||||||
|
|
||||||
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
if err = GetDBOnce(c, &def, "model_definition where model_id=$1", model.Id); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
head := JustId{}
|
head := JustId{}
|
||||||
|
|
||||||
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
|
if err = GetDBOnce(c, &head, "exp_model_head where def_id=$1", def.Id); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate run folder
|
// Generate run folder
|
||||||
run_path := path.Join("/tmp", model.Id, "defs", def.Id)
|
run_path := path.Join("/tmp", model.Id+"-defs-"+def.Id+"-split")
|
||||||
|
|
||||||
err = os.MkdirAll(run_path, os.ModePerm)
|
err = os.MkdirAll(run_path, os.ModePerm)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer removeAll(run_path, err)
|
|
||||||
|
|
||||||
// Create python script
|
// Create python script
|
||||||
f, err := os.Create(path.Join(run_path, "run.py"))
|
f, err := os.Create(path.Join(run_path, "run.py"))
|
||||||
@ -970,8 +1215,8 @@ func splitModel(c *Context, model *BaseModel) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
os.RemoveAll(run_path)
|
||||||
c.Logger.Info("Python finished running")
|
c.Logger.Info("Python finished running")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1141,6 +1386,7 @@ func generateDefinitions(c *Context, model *BaseModel, target_accuracy int, numb
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
|
func CreateExpModelHead(c *Context, def_id string, range_start int, range_end int, status ModelDefinitionStatus) (id string, err error) {
|
||||||
|
|
||||||
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
rows, err := c.Db.Query("insert into exp_model_head (def_id, range_start, range_end, status) values ($1, $2, $3, $4) returning id", def_id, range_start, range_end, status)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1306,6 +1552,244 @@ func generateExpandableDefinitions(c *Context, model *BaseModel, target_accuracy
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ResetClasses(c *Context, model *BaseModel) {
|
||||||
|
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Error while reseting the classes", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func trainExpandable(c *Context, model *BaseModel) {
|
||||||
|
var err error = nil
|
||||||
|
|
||||||
|
failed := func(msg string) {
|
||||||
|
c.Logger.Error(msg, "err", err)
|
||||||
|
ModelUpdateStatus(c, model.Id, FAILED_TRAINING)
|
||||||
|
ResetClasses(c, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
var definitions TrainModelRowUsables
|
||||||
|
|
||||||
|
definitions, err = GetDbMultitple[TrainModelRowUsable](c, "model_definition where status=$1 and model_id=$2", MODEL_DEFINITION_STATUS_READY, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to get definitions")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(definitions) != 1 {
|
||||||
|
failed("There should only be one definition available!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
firstRound := true
|
||||||
|
def := definitions[0]
|
||||||
|
epoch := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
acc, err := trainDefinitionExp(c, model, def.Id, !firstRound)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to train definition!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
epoch += EPOCH_PER_RUN
|
||||||
|
|
||||||
|
if float64(acc*100) >= float64(def.Acuracy) {
|
||||||
|
c.Logger.Info("Found a definition that reaches target_accuracy!")
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update exp_model_head set status=$1 where def_id=$2 and status=$3;", MODEL_HEAD_STATUS_READY, def.Id, MODEL_HEAD_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to train definition!")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else if def.Epoch > MAX_EPOCH {
|
||||||
|
failed(fmt.Sprintf("Failed to train definition! Accuracy less %f < %d\n", acc*100, def.TargetAccuracy))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the class status to trained
|
||||||
|
err = setModelClassStatus(c, MODEL_CLASS_STATUS_TRAINED, "model_id=$1 and status=$2;", model.Id, MODEL_CLASS_STATUS_TRAINING)
|
||||||
|
if err != nil {
|
||||||
|
failed("Failed to set class status")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelUpdateStatus(c, model.Id, READY)
|
||||||
|
}
|
||||||
|
|
||||||
|
func trainRetrain(c *Context, model *BaseModel, defId string) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
failed := func() {
|
||||||
|
ResetClasses(c, model)
|
||||||
|
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||||
|
c.Logger.Error("Failed to retrain", "err", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is something I have to check
|
||||||
|
acc, err := trainDefinitionExpandExp(c, model, defId, false)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Failed to retrain the model", "err", err)
|
||||||
|
failed()
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
c.Logger.Info("Retrained model", "accuracy", acc)
|
||||||
|
|
||||||
|
// TODO check accuracy
|
||||||
|
|
||||||
|
err = UpdateStatus(c, "models", model.Id, READY)
|
||||||
|
if err != nil {
|
||||||
|
failed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("model updaded")
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
c.Logger.Error("Error while updating the classes", "error", err)
|
||||||
|
failed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleRetrain(c *Context) *Error {
|
||||||
|
var err error = nil
|
||||||
|
if !c.CheckAuthLevel(1) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var dat JustId
|
||||||
|
|
||||||
|
if err_ := c.ToJSON(&dat); err_ != nil {
|
||||||
|
return err_
|
||||||
|
}
|
||||||
|
|
||||||
|
if dat.Id == "" {
|
||||||
|
return c.JsonBadRequest("Please provide a id")
|
||||||
|
}
|
||||||
|
|
||||||
|
model, err := GetBaseModel(c.Db, dat.Id)
|
||||||
|
if err == ModelNotFoundError {
|
||||||
|
return c.JsonBadRequest("Model not found")
|
||||||
|
} else if err != nil {
|
||||||
|
return c.Error500(err)
|
||||||
|
} else if model.Status != READY && model.Status != READY_RETRAIN_FAILED && model.Status != READY_ALTERATION_FAILED {
|
||||||
|
return c.JsonBadRequest("Model in invalid status for re-training")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Logger.Info("Expanding definitions for models", "id", model.Id)
|
||||||
|
|
||||||
|
classesUpdated := false
|
||||||
|
|
||||||
|
failed := func() *Error {
|
||||||
|
if classesUpdated {
|
||||||
|
ResetClasses(c, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||||
|
c.Logger.Error("Failed to retrain", "err", err)
|
||||||
|
// TODO improve this response
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var def struct {
|
||||||
|
Id string
|
||||||
|
TargetAccuracy int `db:"target_accuracy"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = GetDBOnce(c, &def, "model_definition where model_id=$1;", model.Id)
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
type C struct {
|
||||||
|
Id string
|
||||||
|
ClassOrder int `db:"class_order"`
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.StartTx()
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
classes, err := GetDbMultitple[C](
|
||||||
|
c,
|
||||||
|
"model_classes where model_id=$1 and status=$2 order by class_order asc",
|
||||||
|
model.Id,
|
||||||
|
MODEL_CLASS_STATUS_TO_TRAIN,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(classes) == 0 {
|
||||||
|
c.Logger.Error("No classes are available!")
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
//Update the classes
|
||||||
|
{
|
||||||
|
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||||
|
err = err2
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.CommitTx()
|
||||||
|
if err != nil {
|
||||||
|
_err := c.RollbackTx()
|
||||||
|
if _err != nil {
|
||||||
|
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||||
|
}
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
classesUpdated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = CreateExpModelHead(c, def.Id, classes[0].ClassOrder, classes[len(classes)-1].ClassOrder, MODEL_DEFINITION_STATUS_INIT)
|
||||||
|
if err != nil {
|
||||||
|
return failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
go trainRetrain(c, model, def.Id)
|
||||||
|
|
||||||
|
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("Failed to update model status")
|
||||||
|
fmt.Println(err)
|
||||||
|
// TODO improve this response
|
||||||
|
return c.Error500(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.SendJSON(model.Id)
|
||||||
|
}
|
||||||
|
|
||||||
func handleTrain(handle *Handle) {
|
func handleTrain(handle *Handle) {
|
||||||
handle.Post("/models/train", func(c *Context) *Error {
|
handle.Post("/models/train", func(c *Context) *Error {
|
||||||
if !c.CheckAuthLevel(1) {
|
if !c.CheckAuthLevel(1) {
|
||||||
@ -1374,6 +1858,8 @@ func handleTrain(handle *Handle) {
|
|||||||
return c.SendJSON(model.Id)
|
return c.SendJSON(model.Id)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
handle.Post("/model/train/retrain", handleRetrain)
|
||||||
|
|
||||||
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
||||||
f := c.R.URL.Query()
|
f := c.R.URL.Query()
|
||||||
|
|
||||||
|
@ -29,7 +29,10 @@ const (
|
|||||||
TRAINING = 4
|
TRAINING = 4
|
||||||
READY = 5
|
READY = 5
|
||||||
READY_ALTERATION = 6
|
READY_ALTERATION = 6
|
||||||
READY_FAILED = -6
|
READY_ALTERATION_FAILED = -6
|
||||||
|
|
||||||
|
READY_RETRAIN = 7
|
||||||
|
READY_RETRAIN_FAILED = -7
|
||||||
)
|
)
|
||||||
|
|
||||||
type ModelDefinitionStatus int
|
type ModelDefinitionStatus int
|
||||||
@ -62,6 +65,16 @@ const (
|
|||||||
MODEL_CLASS_STATUS_TRAINED = 3
|
MODEL_CLASS_STATUS_TRAINED = 3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ModelHeadStatus int
|
||||||
|
|
||||||
|
const (
|
||||||
|
MODEL_HEAD_STATUS_PRE_INIT ModelHeadStatus = 1
|
||||||
|
MODEL_HEAD_STATUS_INIT = 2
|
||||||
|
MODEL_HEAD_STATUS_TRAINING = 3
|
||||||
|
MODEL_HEAD_STATUS_TRAINED = 4
|
||||||
|
MODEL_HEAD_STATUS_READY = 5
|
||||||
|
)
|
||||||
|
|
||||||
var ModelNotFoundError = errors.New("Model not found error")
|
var ModelNotFoundError = errors.New("Model not found error")
|
||||||
|
|
||||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||||
|
@ -89,6 +89,7 @@ func (x *Handle) handleGets(context *Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
context.ShowMessage = false
|
||||||
handleError(&Error{404, "Endpoint not found"}, context)
|
handleError(&Error{404, "Endpoint not found"}, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,6 +100,7 @@ func (x *Handle) handlePosts(context *Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
context.ShowMessage = false
|
||||||
handleError(&Error{404, "Endpoint not found"}, context)
|
handleError(&Error{404, "Endpoint not found"}, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -109,6 +111,7 @@ func (x *Handle) handleDeletes(context *Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
context.ShowMessage = false
|
||||||
handleError(&Error{404, "Endpoint not found"}, context)
|
handleError(&Error{404, "Endpoint not found"}, context)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,6 +136,52 @@ type Context struct {
|
|||||||
Db *sql.DB
|
Db *sql.DB
|
||||||
Writer http.ResponseWriter
|
Writer http.ResponseWriter
|
||||||
R *http.Request
|
R *http.Request
|
||||||
|
Tx *sql.Tx
|
||||||
|
ShowMessage bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Context) Prepare(str string) (*sql.Stmt, error) {
|
||||||
|
if c.Tx == nil {
|
||||||
|
return c.Db.Prepare(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.Tx.Prepare(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
var TransactionAlreadyStarted = errors.New("Transaction already started")
|
||||||
|
var TransactionNotStarted = errors.New("Transaction not started")
|
||||||
|
|
||||||
|
func (c *Context) StartTx() error {
|
||||||
|
if c.Tx != nil {
|
||||||
|
return TransactionAlreadyStarted
|
||||||
|
}
|
||||||
|
var err error = nil
|
||||||
|
c.Tx, err = c.Db.Begin()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) CommitTx() error {
|
||||||
|
if c.Tx == nil {
|
||||||
|
return TransactionNotStarted
|
||||||
|
}
|
||||||
|
err := c.Tx.Commit()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.Tx = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Context) RollbackTx() error {
|
||||||
|
if c.Tx == nil {
|
||||||
|
return TransactionNotStarted
|
||||||
|
}
|
||||||
|
err := c.Tx.Rollback()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.Tx = nil
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Context) ToJSON(dat any) *Error {
|
func (c Context) ToJSON(dat any) *Error {
|
||||||
@ -270,6 +319,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
|
|||||||
Db: handler.Db,
|
Db: handler.Db,
|
||||||
Writer: w,
|
Writer: w,
|
||||||
R: r,
|
R: r,
|
||||||
|
ShowMessage: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -278,7 +328,7 @@ func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseW
|
|||||||
return nil, errors.Join(err, LogoffError)
|
return nil, errors.Join(err, LogoffError)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Context{token, user, logger, handler.Db, w, r}, nil
|
return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func contextlessLogoff(w http.ResponseWriter) {
|
func contextlessLogoff(w http.ResponseWriter) {
|
||||||
@ -457,20 +507,19 @@ func NewHandler(db *sql.DB) *Handle {
|
|||||||
|
|
||||||
if r.Method == "GET" {
|
if r.Method == "GET" {
|
||||||
x.handleGets(context)
|
x.handleGets(context)
|
||||||
return
|
} else if r.Method == "POST" {
|
||||||
}
|
|
||||||
if r.Method == "POST" {
|
|
||||||
x.handlePosts(context)
|
x.handlePosts(context)
|
||||||
return
|
} else if r.Method == "DELETE" {
|
||||||
}
|
|
||||||
if r.Method == "DELETE" {
|
|
||||||
x.handleDeletes(context)
|
x.handleDeletes(context)
|
||||||
return
|
} else if r.Method == "OPTIONS" {
|
||||||
}
|
// do nothing
|
||||||
if r.Method == "OPTIONS" {
|
} else {
|
||||||
return
|
|
||||||
}
|
|
||||||
panic("TODO handle method: " + r.Method)
|
panic("TODO handle method: " + r.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
if context.ShowMessage {
|
||||||
|
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
@ -189,6 +189,7 @@ type JustId struct { Id string }
|
|||||||
type Generic struct{ reflect.Type }
|
type Generic struct{ reflect.Type }
|
||||||
|
|
||||||
var NotFoundError = errors.New("Not found")
|
var NotFoundError = errors.New("Not found")
|
||||||
|
var CouldNotInsert = errors.New("Could not insert")
|
||||||
|
|
||||||
func generateQuery(t reflect.Type) (query string, nargs int) {
|
func generateQuery(t reflect.Type) (query string, nargs int) {
|
||||||
nargs = t.NumField()
|
nargs = t.NumField()
|
||||||
@ -200,6 +201,10 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
name = field.Name;
|
name = field.Name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if name == "__nil__" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
query += strings.ToLower(name) + ","
|
query += strings.ToLower(name) + ","
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,7 +219,13 @@ func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([
|
|||||||
|
|
||||||
query, nargs := generateQuery(t)
|
query, nargs := generateQuery(t)
|
||||||
|
|
||||||
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer db_query.Close()
|
||||||
|
|
||||||
|
rows, err := db_query.Query(args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -251,6 +262,43 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
|
||||||
|
t := reflect.TypeOf(store).Elem()
|
||||||
|
|
||||||
|
query, nargs := generateQuery(t)
|
||||||
|
|
||||||
|
query2 := ""
|
||||||
|
for i := 0; i < nargs; i += 1 {
|
||||||
|
query2 += fmt.Sprintf("$%d,", i)
|
||||||
|
}
|
||||||
|
// Remove last quotation
|
||||||
|
query2 = query2[0 : len(query2)-1]
|
||||||
|
|
||||||
|
val := reflect.ValueOf(store).Elem()
|
||||||
|
scan_args := make([]interface{}, nargs);
|
||||||
|
for i := 0; i < nargs; i++ {
|
||||||
|
valueField := val.Field(i)
|
||||||
|
scan_args[i] = valueField.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
return "", CouldNotInsert
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
|
||||||
t := reflect.TypeOf(store).Elem()
|
t := reflect.TypeOf(store).Elem()
|
||||||
|
|
||||||
|
223
views/py/python_model_template_expand.py
Normal file
223
views/py/python_model_template_expand.py
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
import random
|
||||||
|
import pandas as pd
|
||||||
|
from tensorflow import keras
|
||||||
|
from tensorflow.data import AUTOTUNE
|
||||||
|
from keras import layers, losses, optimizers
|
||||||
|
import requests
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class NotifyServerCallback(tf.keras.callbacks.Callback):
|
||||||
|
def on_epoch_end(self, epoch, log, *args, **kwargs):
|
||||||
|
requests.get(f'http://localhost:8000/api/model/head/epoch/update?epoch={epoch + 1}&accuracy={log["accuracy"]}&head_id={{.HeadId}}')
|
||||||
|
|
||||||
|
|
||||||
|
DATA_DIR = "{{ .DataDir }}"
|
||||||
|
image_size = ({{ .Size }})
|
||||||
|
|
||||||
|
df = pd.read_csv("{{ .RunPath }}/train.csv", dtype=str)
|
||||||
|
keys = tf.constant(df['Id'].dropna())
|
||||||
|
values = tf.constant(list(map(int, df['Index'].dropna())))
|
||||||
|
|
||||||
|
depth = {{ .Depth }}
|
||||||
|
diff = {{ .StartPoint }}
|
||||||
|
|
||||||
|
table = tf.lookup.StaticHashTable(
|
||||||
|
initializer=tf.lookup.KeyValueTensorInitializer(
|
||||||
|
keys=keys,
|
||||||
|
values=values,
|
||||||
|
),
|
||||||
|
default_value=tf.constant(-1),
|
||||||
|
name="Indexes"
|
||||||
|
)
|
||||||
|
|
||||||
|
DATA_DIR_PREPARE = DATA_DIR + "/"
|
||||||
|
|
||||||
|
|
||||||
|
# based on https://www.tensorflow.org/tutorials/load_data/images
|
||||||
|
def pathToLabel(path):
|
||||||
|
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||||
|
{{ if eq .Model.Format "png" }}
|
||||||
|
path = tf.strings.regex_replace(path, ".png", "")
|
||||||
|
{{ else if eq .Model.Format "jpeg" }}
|
||||||
|
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||||
|
{{ else }}
|
||||||
|
ERROR
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
|
num = table.lookup(tf.strings.as_string([path]))
|
||||||
|
|
||||||
|
return tf.cond(
|
||||||
|
tf.math.equal(num, tf.constant(-1)),
|
||||||
|
lambda: tf.zeros([depth]),
|
||||||
|
lambda: tf.one_hot(table.lookup(tf.strings.as_string([path])) - diff, depth)[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
old_model = keras.models.load_model("{{ .BaseModel }}")
|
||||||
|
|
||||||
|
|
||||||
|
def decode_image(img):
|
||||||
|
{{ if eq .Model.Format "png" }}
|
||||||
|
img = tf.io.decode_png(img, channels={{.ColorMode}})
|
||||||
|
{{ else if eq .Model.Format "jpeg" }}
|
||||||
|
img = tf.io.decode_jpeg(img, channels={{.ColorMode}})
|
||||||
|
{{ else }}
|
||||||
|
ERROR
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
|
return tf.image.resize(img, image_size)
|
||||||
|
|
||||||
|
|
||||||
|
def process_path(path):
|
||||||
|
label = pathToLabel(path)
|
||||||
|
|
||||||
|
img = tf.io.read_file(path)
|
||||||
|
img = decode_image(img)
|
||||||
|
|
||||||
|
return img, label
|
||||||
|
|
||||||
|
|
||||||
|
def configure_for_performance(ds: tf.data.Dataset, size: int, shuffle: bool) -> tf.data.Dataset:
|
||||||
|
# ds = ds.cache()
|
||||||
|
if shuffle:
|
||||||
|
ds = ds.shuffle(buffer_size=size)
|
||||||
|
ds = ds.batch(batch_size)
|
||||||
|
ds = ds.prefetch(AUTOTUNE)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_dataset(ds: tf.data.Dataset, size: int) -> tf.data.Dataset:
|
||||||
|
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
|
||||||
|
ds = configure_for_performance(ds, size, False)
|
||||||
|
return ds
|
||||||
|
|
||||||
|
|
||||||
|
def filterDataset(path):
|
||||||
|
path = tf.strings.regex_replace(path, DATA_DIR_PREPARE, "")
|
||||||
|
{{ if eq .Model.Format "png" }}
|
||||||
|
path = tf.strings.regex_replace(path, ".png", "")
|
||||||
|
{{ else if eq .Model.Format "jpeg" }}
|
||||||
|
path = tf.strings.regex_replace(path, ".jpeg", "")
|
||||||
|
{{ else }}
|
||||||
|
ERROR
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
|
return tf.reshape(table.lookup(tf.strings.as_string([path])), []) != -1
|
||||||
|
|
||||||
|
|
||||||
|
seed = random.randint(0, 100000000)
|
||||||
|
|
||||||
|
batch_size = 64
|
||||||
|
|
||||||
|
# Read all the files from the direcotry
|
||||||
|
list_ds = tf.data.Dataset.list_files(str(f'{DATA_DIR}/*'), shuffle=False)
|
||||||
|
list_ds = list_ds.filter(filterDataset)
|
||||||
|
|
||||||
|
image_count = len(list(list_ds.as_numpy_iterator()))
|
||||||
|
|
||||||
|
list_ds = list_ds.shuffle(image_count, seed=seed)
|
||||||
|
|
||||||
|
val_size = int(image_count * 0.3)
|
||||||
|
|
||||||
|
train_ds = list_ds.skip(val_size)
|
||||||
|
val_ds = list_ds.take(val_size)
|
||||||
|
|
||||||
|
dataset = prepare_dataset(train_ds, image_count)
|
||||||
|
dataset_validation = prepare_dataset(val_ds, val_size)
|
||||||
|
|
||||||
|
track = 0
|
||||||
|
|
||||||
|
|
||||||
|
def addBlock(
|
||||||
|
b_size: int,
|
||||||
|
filter_size: int,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
top: bool = True,
|
||||||
|
pooling_same: bool = False,
|
||||||
|
pool_func=layers.MaxPool2D
|
||||||
|
):
|
||||||
|
global track
|
||||||
|
model = keras.Sequential(
|
||||||
|
name=f"{track}-{b_size}-{filter_size}-{kernel_size}"
|
||||||
|
)
|
||||||
|
track += 1
|
||||||
|
for _ in range(b_size):
|
||||||
|
model.add(layers.Conv2D(
|
||||||
|
filter_size,
|
||||||
|
kernel_size,
|
||||||
|
padding="same"
|
||||||
|
))
|
||||||
|
model.add(layers.ReLU())
|
||||||
|
if top:
|
||||||
|
if pooling_same:
|
||||||
|
model.add(pool_func(padding="same", strides=(1, 1)))
|
||||||
|
else:
|
||||||
|
model.add(pool_func())
|
||||||
|
model.add(layers.BatchNormalization())
|
||||||
|
model.add(layers.LeakyReLU())
|
||||||
|
model.add(layers.Dropout(0.4))
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
# Proccess old data
|
||||||
|
|
||||||
|
new_data = old_model.predict(dataset)
|
||||||
|
labels = np.concatenate([y for _, y in dataset], axis=0)
|
||||||
|
new_data = tf.data.Dataset.from_tensor_slices(
|
||||||
|
(new_data, labels))
|
||||||
|
new_data = configure_for_performance(new_data, batch_size, True)
|
||||||
|
|
||||||
|
new_data_val = old_model.predict(dataset_validation)
|
||||||
|
labels_val = np.concatenate([y for _, y in dataset_validation], axis=0)
|
||||||
|
new_data_val = tf.data.Dataset.from_tensor_slices(
|
||||||
|
(new_data_val, labels_val))
|
||||||
|
new_data_val = configure_for_performance(new_data_val, batch_size, True)
|
||||||
|
|
||||||
|
{{ if .LoadPrev }}
|
||||||
|
model = tf.keras.saving.load_model('{{.LastModelRunPath}}')
|
||||||
|
{{ else }}
|
||||||
|
model = keras.Sequential()
|
||||||
|
|
||||||
|
{{- range .Layers }}
|
||||||
|
{{- if eq .LayerType 1}}
|
||||||
|
model.add(layers.Rescaling(1./255))
|
||||||
|
{{- else if eq .LayerType 2 }}
|
||||||
|
model.add(layers.Dense({{ .Shape }}, activation="sigmoid"))
|
||||||
|
{{- else if eq .LayerType 3}}
|
||||||
|
model.add(layers.Flatten())
|
||||||
|
{{- else if eq .LayerType 4}}
|
||||||
|
model.add(addBlock(2, 128, 3, pool_func=layers.AveragePooling2D))
|
||||||
|
{{- else }}
|
||||||
|
ERROR
|
||||||
|
{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
{{ end }}
|
||||||
|
|
||||||
|
model.layers[0]._name = "head"
|
||||||
|
|
||||||
|
model.compile(
|
||||||
|
loss=losses.BinaryCrossentropy(from_logits=False),
|
||||||
|
optimizer=tf.keras.optimizers.Adam(),
|
||||||
|
metrics=['accuracy'])
|
||||||
|
|
||||||
|
his = model.fit(
|
||||||
|
new_data,
|
||||||
|
validation_data=new_data_val,
|
||||||
|
epochs={{.EPOCH_PER_RUN}},
|
||||||
|
callbacks=[
|
||||||
|
NotifyServerCallback(),
|
||||||
|
tf.keras.callbacks.EarlyStopping("loss", mode="min", patience=5)
|
||||||
|
],
|
||||||
|
use_multiprocessing=True
|
||||||
|
)
|
||||||
|
|
||||||
|
acc = his.history["accuracy"]
|
||||||
|
|
||||||
|
f = open("accuracy.val", "w")
|
||||||
|
f.write(str(acc[-1]))
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
tf.saved_model.save(model, "{{ .SaveModelPath }}/model")
|
||||||
|
model.save("{{ .SaveModelPath }}/model.keras")
|
@ -17,7 +17,7 @@ split_len = {{ .SplitLen }}
|
|||||||
|
|
||||||
bottom_input = Input(model.input_shape[1:])
|
bottom_input = Input(model.input_shape[1:])
|
||||||
bottom_output = bottom_input
|
bottom_output = bottom_input
|
||||||
top_input = Input(model.layers[split_len + 1].input_shape[1:])
|
top_input = Input(model.layers[split_len + 1].input_shape[1:], name="head_input")
|
||||||
top_output = top_input
|
top_output = top_input
|
||||||
|
|
||||||
for i, layer in enumerate(model.layers):
|
for i, layer in enumerate(model.layers):
|
||||||
|
@ -146,14 +146,7 @@
|
|||||||
<!-- TODO improve message -->
|
<!-- TODO improve message -->
|
||||||
<h2 class="text-center">Failed to prepare model</h2>
|
<h2 class="text-center">Failed to prepare model</h2>
|
||||||
|
|
||||||
<div>TODO button delete</div>
|
<DeleteModel model={m} />
|
||||||
|
|
||||||
<!--form hx-delete="/models/delete">
|
|
||||||
<input type="hidden" name="id" value="{{ .Model.Id }}" />
|
|
||||||
<button class="danger">
|
|
||||||
Delete
|
|
||||||
</button>
|
|
||||||
</form-->
|
|
||||||
</div>
|
</div>
|
||||||
<!-- PRE TRAINING STATUS -->
|
<!-- PRE TRAINING STATUS -->
|
||||||
{:else if m.status == 2}
|
{:else if m.status == 2}
|
||||||
@ -288,7 +281,7 @@
|
|||||||
{/await}
|
{/await}
|
||||||
<!-- TODO Add ability to stop training -->
|
<!-- TODO Add ability to stop training -->
|
||||||
</div>
|
</div>
|
||||||
{:else if [5, 6, -6].includes(m.status)}
|
{:else if [5, 6, -6, 7, -7].includes(m.status)}
|
||||||
<BaseModelInfo model={m} />
|
<BaseModelInfo model={m} />
|
||||||
<RunModel model={m} />
|
<RunModel model={m} />
|
||||||
{#if m.status == 6}
|
{#if m.status == 6}
|
||||||
@ -299,6 +292,13 @@
|
|||||||
{#if m.status == -6}
|
{#if m.status == -6}
|
||||||
<DeleteZip model={m} on:reload={getModel} expand />
|
<DeleteZip model={m} on:reload={getModel} expand />
|
||||||
{/if}
|
{/if}
|
||||||
|
{#if m.status == -7}
|
||||||
|
<form>
|
||||||
|
<!-- TODO add more info about the failure -->
|
||||||
|
Failed to train the model!
|
||||||
|
Try to retrain
|
||||||
|
</form>
|
||||||
|
{/if}
|
||||||
{#if m.model_type == 2}
|
{#if m.model_type == 2}
|
||||||
<ModelData model={m} on:reload={getModel} />
|
<ModelData model={m} on:reload={getModel} />
|
||||||
{/if}
|
{/if}
|
||||||
|
@ -142,6 +142,7 @@
|
|||||||
</FileUpload>
|
</FileUpload>
|
||||||
</fieldset>
|
</fieldset>
|
||||||
<MessageSimple bind:this={uploadImage} />
|
<MessageSimple bind:this={uploadImage} />
|
||||||
|
{#if file}
|
||||||
{#await uploading}
|
{#await uploading}
|
||||||
<button disabled>
|
<button disabled>
|
||||||
Uploading
|
Uploading
|
||||||
@ -151,6 +152,7 @@
|
|||||||
Add
|
Add
|
||||||
</button>
|
</button>
|
||||||
{/await}
|
{/await}
|
||||||
|
{/if}
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
<div class="content" class:selected={isActive("create-class")}>
|
<div class="content" class:selected={isActive("create-class")}>
|
||||||
@ -190,4 +192,6 @@
|
|||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
|
{#if classes.some((item) => item.status == 1) && ![-6, 6].includes(model.status)}
|
||||||
|
<TrainModel number_of_invalid_images={numberOfInvalidImages} {model} {has_data} on:reload={() => dispatch('reload')} />
|
||||||
|
{/if}
|
||||||
|
@ -106,6 +106,15 @@
|
|||||||
class:selected={isActive(item.name)}
|
class:selected={isActive(item.name)}
|
||||||
>
|
>
|
||||||
{item.name}
|
{item.name}
|
||||||
|
{#if model.model_type == 2}
|
||||||
|
{#if item.status == 1}
|
||||||
|
<span class="bi bi-book" style="color: orange;" />
|
||||||
|
{:else if item.status == 2}
|
||||||
|
<span class="bi bi-book" style="color: green;" />
|
||||||
|
{:else if item.status == 3}
|
||||||
|
<span class="bi bi-check" style="color: green;" />
|
||||||
|
{/if}
|
||||||
|
{/if}
|
||||||
</button>
|
</button>
|
||||||
{/each}
|
{/each}
|
||||||
</div>
|
</div>
|
||||||
@ -170,6 +179,7 @@
|
|||||||
</FileUpload>
|
</FileUpload>
|
||||||
</fieldset>
|
</fieldset>
|
||||||
<MessageSimple bind:this={uploadImage} />
|
<MessageSimple bind:this={uploadImage} />
|
||||||
|
{#if file}
|
||||||
{#await uploading}
|
{#await uploading}
|
||||||
<button disabled>
|
<button disabled>
|
||||||
Uploading
|
Uploading
|
||||||
@ -179,6 +189,7 @@
|
|||||||
Add
|
Add
|
||||||
</button>
|
</button>
|
||||||
{/await}
|
{/await}
|
||||||
|
{/if}
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
{/if}
|
{/if}
|
||||||
|
@ -23,6 +23,7 @@
|
|||||||
let messages: MessageSimple;
|
let messages: MessageSimple;
|
||||||
|
|
||||||
async function submit() {
|
async function submit() {
|
||||||
|
messages.clear();
|
||||||
submitted = true;
|
submitted = true;
|
||||||
try {
|
try {
|
||||||
await post('models/train', {
|
await post('models/train', {
|
||||||
@ -34,13 +35,29 @@
|
|||||||
if (e instanceof Response) {
|
if (e instanceof Response) {
|
||||||
messages.display(await e.json());
|
messages.display(await e.json());
|
||||||
} else {
|
} else {
|
||||||
messages.display("Could not start the training of the model");
|
messages.display('Could not start the training of the model');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function submitRetrain() {
|
||||||
|
messages.clear();
|
||||||
|
submitted = true;
|
||||||
|
try {
|
||||||
|
await post('model/train/retrain', { id: model.id });
|
||||||
|
dispatch('reload');
|
||||||
|
} catch (e) {
|
||||||
|
if (e instanceof Response) {
|
||||||
|
messages.display(await e.json());
|
||||||
|
} else {
|
||||||
|
messages.display('Could not start the training of the model');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<form class:submitted on:submit|preventDefault={submit}>
|
{#if model.status == 2}
|
||||||
|
<form class:submitted on:submit|preventDefault={submit}>
|
||||||
{#if has_data}
|
{#if has_data}
|
||||||
{#if number_of_invalid_images > 0}
|
{#if number_of_invalid_images > 0}
|
||||||
<p class="danger">
|
<p class="danger">
|
||||||
@ -92,4 +109,23 @@
|
|||||||
{:else}
|
{:else}
|
||||||
<h2>To train the model please provide data to the model first</h2>
|
<h2>To train the model please provide data to the model first</h2>
|
||||||
{/if}
|
{/if}
|
||||||
</form>
|
</form>
|
||||||
|
{:else}
|
||||||
|
<form class:submitted on:submit|preventDefault={submitRetrain}>
|
||||||
|
{#if has_data}
|
||||||
|
<h2>
|
||||||
|
This model has new classes and can be expanded
|
||||||
|
</h2>
|
||||||
|
{#if number_of_invalid_images > 0}
|
||||||
|
<p class="danger">
|
||||||
|
There are images {number_of_invalid_images} that were loaded that do not have the correct format.DeleteZip
|
||||||
|
These images will be delete when the model trains.
|
||||||
|
</p>
|
||||||
|
{/if}
|
||||||
|
<MessageSimple bind:this={messages} />
|
||||||
|
<button> Retrain </button>
|
||||||
|
{:else}
|
||||||
|
<h2>To train the model please provide data to the model first</h2>
|
||||||
|
{/if}
|
||||||
|
</form>
|
||||||
|
{/if}
|
||||||
|
Loading…
Reference in New Issue
Block a user