|
|
|
@ -281,16 +281,17 @@ func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_pr
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
|
|
|
|
|
func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
|
|
|
|
|
l, db := c.GetLogger(), c.GetDb()
|
|
|
|
|
|
|
|
|
|
var co struct {
|
|
|
|
|
Count int `db:"count(*)"`
|
|
|
|
|
}
|
|
|
|
|
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
|
|
|
|
err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
c.Logger.Info("test here", "count", co)
|
|
|
|
|
l.Info("test here", "count", co)
|
|
|
|
|
count_re = co.Count
|
|
|
|
|
count := co.Count
|
|
|
|
|
|
|
|
|
@ -304,7 +305,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
|
|
|
|
return generateCvsExpandExp(c, run_path, model_id, offset, true)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
|
|
|
|
data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -338,7 +339,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
|
|
|
|
// This is to load some extra data so that the model has more things to train on
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
|
|
|
|
|
data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -361,10 +362,11 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
|
|
|
|
func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
|
|
|
|
accuracy = 0
|
|
|
|
|
|
|
|
|
|
c.Logger.Warn("About to retrain model")
|
|
|
|
|
l := c.GetLogger()
|
|
|
|
|
l.Warn("About to retrain model")
|
|
|
|
|
|
|
|
|
|
// Get untrained models heads
|
|
|
|
|
|
|
|
|
@ -375,7 +377,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// status = 2 (INIT) 3 (TRAINING)
|
|
|
|
|
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
|
|
|
|
heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
} else if len(heads) == 0 {
|
|
|
|
@ -389,13 +391,13 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
|
|
|
|
|
exp := heads[0]
|
|
|
|
|
|
|
|
|
|
c.Logger.Info("Got exp head", "head", exp)
|
|
|
|
|
l.Info("Got exp head", "head", exp)
|
|
|
|
|
|
|
|
|
|
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
|
|
|
|
if err = UpdateStatus(c.GetDb(), "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
|
|
|
|
layers, err := c.GetDb().Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -447,7 +449,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
LayerNum: i,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
c.Logger.Info("Got layers", "layers", got)
|
|
|
|
|
l.Info("Got layers", "layers", got)
|
|
|
|
|
|
|
|
|
|
// Generate run folder
|
|
|
|
|
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
|
|
|
|
@ -462,7 +464,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.Logger.Info("Generated cvs", "classCount", classCount)
|
|
|
|
|
l.Info("Generated cvs", "classCount", classCount)
|
|
|
|
|
|
|
|
|
|
// TODO update the run script
|
|
|
|
|
|
|
|
|
@ -473,7 +475,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
}
|
|
|
|
|
defer f.Close()
|
|
|
|
|
|
|
|
|
|
c.Logger.Info("About to run python!")
|
|
|
|
|
l.Info("About to run python!")
|
|
|
|
|
|
|
|
|
|
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
|
|
|
|
|
if err != nil {
|
|
|
|
@ -498,7 +500,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
|
|
|
|
|
"Depth": classCount,
|
|
|
|
|
"StartPoint": 0,
|
|
|
|
|
"Host": (*c.Handle).Config.Hostname,
|
|
|
|
|
"Host": c.GetHost(),
|
|
|
|
|
}); err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
@ -506,11 +508,11 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
// Run the command
|
|
|
|
|
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
|
|
|
|
|
l.Warn("Python failed to run", "err", err, "out", string(out))
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.Logger.Info("Python finished running")
|
|
|
|
|
l.Info("Python finished running")
|
|
|
|
|
|
|
|
|
|
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
|
|
|
|
return
|
|
|
|
@ -533,7 +535,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
os.RemoveAll(run_path)
|
|
|
|
|
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
|
|
|
|
l.Info("Model finished training!", "accuracy", accuracy)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1555,10 +1557,10 @@ func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func ResetClasses(c *Context, model *BaseModel) {
|
|
|
|
|
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
|
|
|
|
func ResetClasses(c BasePack, model *BaseModel) {
|
|
|
|
|
_, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.Logger.Error("Error while reseting the classes", "error", err)
|
|
|
|
|
c.GetLogger().Error("Error while reseting the classes", "error", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1620,44 +1622,6 @@ func trainExpandable(c *Context, model *BaseModel) {
|
|
|
|
|
ModelUpdateStatus(c, model.Id, READY)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func trainRetrain(c *Context, model *BaseModel, defId string) {
|
|
|
|
|
var err error
|
|
|
|
|
|
|
|
|
|
failed := func() {
|
|
|
|
|
ResetClasses(c, model)
|
|
|
|
|
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
|
|
|
|
c.Logger.Error("Failed to retrain", "err", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This is something I have to check
|
|
|
|
|
acc, err := trainDefinitionExpandExp(c, model, defId, false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.Logger.Error("Failed to retrain the model", "err", err)
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
c.Logger.Info("Retrained model", "accuracy", acc)
|
|
|
|
|
|
|
|
|
|
// TODO check accuracy
|
|
|
|
|
|
|
|
|
|
err = UpdateStatus(c, "models", model.Id, READY)
|
|
|
|
|
if err != nil {
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c.Logger.Info("model updaded")
|
|
|
|
|
|
|
|
|
|
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.Logger.Error("Error while updating the classes", "error", err)
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func RunTaskTrain(b BasePack, task Task) (err error) {
|
|
|
|
|
l := b.GetLogger()
|
|
|
|
|
|
|
|
|
@ -1718,6 +1682,62 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func RunTaskRetrain(b BasePack, task Task) (err error) {
|
|
|
|
|
model, err := GetBaseModel(b.GetDb(), task.ModelId)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
} else if model.Status != READY_RETRAIN {
|
|
|
|
|
return errors.New("Model in invalid status for re-training")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
l := b.GetLogger()
|
|
|
|
|
db := b.GetDb()
|
|
|
|
|
|
|
|
|
|
failed := func() {
|
|
|
|
|
ResetClasses(b, model)
|
|
|
|
|
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
|
|
|
|
|
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining")
|
|
|
|
|
l.Error("Failed to retrain", "err", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining")
|
|
|
|
|
|
|
|
|
|
defId, err := GetDbVar[string](db, "md.id", "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
|
|
|
|
|
if err != nil {
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// This is something I have to check
|
|
|
|
|
acc, err := trainDefinitionExpandExp(b, model, *defId, false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
l.Info("Retrained model", "accuracy", acc)
|
|
|
|
|
|
|
|
|
|
// TODO check accuracy
|
|
|
|
|
err = UpdateStatus(db, "models", model.Id, READY)
|
|
|
|
|
if err != nil {
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
l.Info("Model updaded")
|
|
|
|
|
|
|
|
|
|
_, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
l.Error("Error while updating the classes", "error", err)
|
|
|
|
|
failed()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func handleTrain(handle *Handle) {
|
|
|
|
|
|
|
|
|
|
type TrainReq struct {
|
|
|
|
@ -1899,17 +1919,29 @@ func handleTrain(handle *Handle) {
|
|
|
|
|
return failed()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
go trainRetrain(c, model, def.Id)
|
|
|
|
|
|
|
|
|
|
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
|
|
|
|
if err != nil {
|
|
|
|
|
fmt.Println("Failed to update model status")
|
|
|
|
|
fmt.Println(err)
|
|
|
|
|
// TODO improve this response
|
|
|
|
|
return c.Error500(err)
|
|
|
|
|
return c.E500M("Failed to update model status", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return c.SendJSON(model.Id)
|
|
|
|
|
newTask := struct {
|
|
|
|
|
UserId string `db:"user_id"`
|
|
|
|
|
ModelId string `db:"model_id"`
|
|
|
|
|
TaskType TaskType `db:"task_type"`
|
|
|
|
|
Status int `db:"status"`
|
|
|
|
|
}{
|
|
|
|
|
UserId: c.User.Id,
|
|
|
|
|
ModelId: model.Id,
|
|
|
|
|
TaskType: TASK_TYPE_RETRAINING,
|
|
|
|
|
Status: 1,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
id, err := InsertReturnId(c, &newTask, "tasks", "id")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return c.E500M("Failed to create task", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return c.SendJSON(JustId{Id: id})
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
|
|
|
|