move model retrain to runners closes #94
This commit is contained in:
parent
f182b205f8
commit
06642dcb1e
17
config.toml
17
config.toml
@ -1,14 +1,17 @@
|
||||
PORT=5002
|
||||
PORT = 5002
|
||||
|
||||
HOSTNAME="https://testing.andr3h3nriqu3s.com"
|
||||
HOSTNAME = "https://testing.andr3h3nriqu3s.com"
|
||||
|
||||
NUMBER_OF_WORKERS=20
|
||||
NUMBER_OF_WORKERS = 20
|
||||
|
||||
SUPRESS_CUDA=1
|
||||
SUPRESS_CUDA = 1
|
||||
|
||||
[ServiceUser]
|
||||
USER="service"
|
||||
USER = "service"
|
||||
|
||||
[Worker]
|
||||
PULLING_TIME="500ms"
|
||||
NUMBER_OF_WORKERS=1
|
||||
PULLING_TIME = "500ms"
|
||||
NUMBER_OF_WORKERS = 1
|
||||
|
||||
[DB]
|
||||
MAX_CONNECTIONS = 600
|
||||
|
@ -1,11 +1,11 @@
|
||||
version: '3.1'
|
||||
version: "3.1"
|
||||
|
||||
services:
|
||||
db:
|
||||
image: docker.andr3h3nriqu3s.com/services/postgres
|
||||
command: -c 'max_connections=400'
|
||||
command: -c 'max_connections=600'
|
||||
restart: always
|
||||
environment:
|
||||
POSTGRES_PASSWORD: verysafepassword
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5432:5432"
|
||||
|
@ -281,16 +281,17 @@ func trainDefinition(c BasePack, model *BaseModel, definition_id string, load_pr
|
||||
return
|
||||
}
|
||||
|
||||
func generateCvsExpandExp(c *Context, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
|
||||
func generateCvsExpandExp(c BasePack, run_path string, model_id string, offset int, doPanic bool) (count_re int, err error) {
|
||||
l, db := c.GetLogger(), c.GetDb()
|
||||
|
||||
var co struct {
|
||||
Count int `db:"count(*)"`
|
||||
}
|
||||
err = GetDBOnce(c, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||
err = GetDBOnce(db, &co, "model_classes where model_id=$1 and status=$2;", model_id, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.Logger.Info("test here", "count", co)
|
||||
l.Info("test here", "count", co)
|
||||
count_re = co.Count
|
||||
count := co.Count
|
||||
|
||||
@ -304,7 +305,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
||||
return generateCvsExpandExp(c, run_path, model_id, offset, true)
|
||||
}
|
||||
|
||||
data, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
data, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINING)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -338,7 +339,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
||||
// This is to load some extra data so that the model has more things to train on
|
||||
//
|
||||
|
||||
data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
|
||||
data_other, err := db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -361,10 +362,11 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i
|
||||
return
|
||||
}
|
||||
|
||||
func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
func trainDefinitionExpandExp(c BasePack, model *BaseModel, definition_id string, load_prev bool) (accuracy float64, err error) {
|
||||
accuracy = 0
|
||||
|
||||
c.Logger.Warn("About to retrain model")
|
||||
l := c.GetLogger()
|
||||
l.Warn("About to retrain model")
|
||||
|
||||
// Get untrained models heads
|
||||
|
||||
@ -375,7 +377,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
}
|
||||
|
||||
// status = 2 (INIT) 3 (TRAINING)
|
||||
heads, err := GetDbMultitple[ExpHead](c, "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||
heads, err := GetDbMultitple[ExpHead](c.GetDb(), "exp_model_head where def_id=$1 and (status = 2 or status = 3)", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
} else if len(heads) == 0 {
|
||||
@ -389,13 +391,13 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
|
||||
exp := heads[0]
|
||||
|
||||
c.Logger.Info("Got exp head", "head", exp)
|
||||
l.Info("Got exp head", "head", exp)
|
||||
|
||||
if err = UpdateStatus(c, "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||
if err = UpdateStatus(c.GetDb(), "exp_model_head", exp.Id, MODEL_DEFINITION_STATUS_TRAINING); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
layers, err := c.Db.Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
layers, err := c.GetDb().Query("select layer_type, shape, exp_type from model_definition_layer where def_id=$1 order by layer_order asc;", definition_id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -447,7 +449,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
LayerNum: i,
|
||||
})
|
||||
|
||||
c.Logger.Info("Got layers", "layers", got)
|
||||
l.Info("Got layers", "layers", got)
|
||||
|
||||
// Generate run folder
|
||||
run_path := path.Join("/tmp", model.Id+"-defs-"+definition_id+"-retrain")
|
||||
@ -462,7 +464,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Generated cvs", "classCount", classCount)
|
||||
l.Info("Generated cvs", "classCount", classCount)
|
||||
|
||||
// TODO update the run script
|
||||
|
||||
@ -473,7 +475,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
c.Logger.Info("About to run python!")
|
||||
l.Info("About to run python!")
|
||||
|
||||
tmpl, err := template.New("python_model_template_expand.py").ParseFiles("views/py/python_model_template_expand.py")
|
||||
if err != nil {
|
||||
@ -498,7 +500,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
"SaveModelPath": path.Join(getDir(), result_path, "head", exp.Id),
|
||||
"Depth": classCount,
|
||||
"StartPoint": 0,
|
||||
"Host": (*c.Handle).Config.Hostname,
|
||||
"Host": c.GetHost(),
|
||||
}); err != nil {
|
||||
return
|
||||
}
|
||||
@ -506,11 +508,11 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
// Run the command
|
||||
out, err := exec.Command("bash", "-c", fmt.Sprintf("cd %s && python run.py", run_path)).CombinedOutput()
|
||||
if err != nil {
|
||||
c.Logger.Warn("Python failed to run", "err", err, "out", string(out))
|
||||
l.Warn("Python failed to run", "err", err, "out", string(out))
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("Python finished running")
|
||||
l.Info("Python finished running")
|
||||
|
||||
if err = os.MkdirAll(result_path, os.ModePerm); err != nil {
|
||||
return
|
||||
@ -533,7 +535,7 @@ func trainDefinitionExpandExp(c *Context, model *BaseModel, definition_id string
|
||||
}
|
||||
|
||||
os.RemoveAll(run_path)
|
||||
c.Logger.Info("Model finished training!", "accuracy", accuracy)
|
||||
l.Info("Model finished training!", "accuracy", accuracy)
|
||||
return
|
||||
}
|
||||
|
||||
@ -1555,10 +1557,10 @@ func generateExpandableDefinitions(c BasePack, model *BaseModel, target_accuracy
|
||||
return nil
|
||||
}
|
||||
|
||||
func ResetClasses(c *Context, model *BaseModel) {
|
||||
_, err := c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||
func ResetClasses(c BasePack, model *BaseModel) {
|
||||
_, err := c.GetDb().Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TO_TRAIN, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Error while reseting the classes", "error", err)
|
||||
c.GetLogger().Error("Error while reseting the classes", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1620,44 +1622,6 @@ func trainExpandable(c *Context, model *BaseModel) {
|
||||
ModelUpdateStatus(c, model.Id, READY)
|
||||
}
|
||||
|
||||
func trainRetrain(c *Context, model *BaseModel, defId string) {
|
||||
var err error
|
||||
|
||||
failed := func() {
|
||||
ResetClasses(c, model)
|
||||
ModelUpdateStatus(c, model.Id, READY_RETRAIN_FAILED)
|
||||
c.Logger.Error("Failed to retrain", "err", err)
|
||||
return
|
||||
}
|
||||
|
||||
// This is something I have to check
|
||||
acc, err := trainDefinitionExpandExp(c, model, defId, false)
|
||||
if err != nil {
|
||||
c.Logger.Error("Failed to retrain the model", "err", err)
|
||||
failed()
|
||||
return
|
||||
|
||||
}
|
||||
c.Logger.Info("Retrained model", "accuracy", acc)
|
||||
|
||||
// TODO check accuracy
|
||||
|
||||
err = UpdateStatus(c, "models", model.Id, READY)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
c.Logger.Info("model updaded")
|
||||
|
||||
_, err = c.Db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
c.Logger.Error("Error while updating the classes", "error", err)
|
||||
failed()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func RunTaskTrain(b BasePack, task Task) (err error) {
|
||||
l := b.GetLogger()
|
||||
|
||||
@ -1718,6 +1682,62 @@ func RunTaskTrain(b BasePack, task Task) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func RunTaskRetrain(b BasePack, task Task) (err error) {
|
||||
model, err := GetBaseModel(b.GetDb(), task.ModelId)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if model.Status != READY_RETRAIN {
|
||||
return errors.New("Model in invalid status for re-training")
|
||||
}
|
||||
|
||||
l := b.GetLogger()
|
||||
db := b.GetDb()
|
||||
|
||||
failed := func() {
|
||||
ResetClasses(b, model)
|
||||
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
|
||||
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model failed retraining")
|
||||
l.Error("Failed to retrain", "err", err)
|
||||
}
|
||||
|
||||
task.UpdateStatusLog(b, TASK_RUNNING, "Model retraining")
|
||||
|
||||
defId, err := GetDbVar[string](db, "md.id", "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
// This is something I have to check
|
||||
acc, err := trainDefinitionExpandExp(b, model, *defId, false)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
l.Info("Retrained model", "accuracy", acc)
|
||||
|
||||
// TODO check accuracy
|
||||
err = UpdateStatus(db, "models", model.Id, READY)
|
||||
if err != nil {
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
l.Info("Model updaded")
|
||||
|
||||
_, err = db.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINED, MODEL_CLASS_STATUS_TRAINING, model.Id)
|
||||
if err != nil {
|
||||
l.Error("Error while updating the classes", "error", err)
|
||||
failed()
|
||||
return
|
||||
}
|
||||
|
||||
task.UpdateStatusLog(b, TASK_DONE, "Model finished retraining")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func handleTrain(handle *Handle) {
|
||||
|
||||
type TrainReq struct {
|
||||
@ -1899,17 +1919,29 @@ func handleTrain(handle *Handle) {
|
||||
return failed()
|
||||
}
|
||||
|
||||
go trainRetrain(c, model, def.Id)
|
||||
|
||||
_, err = c.Db.Exec("update models set status=$1 where id=$2;", READY_RETRAIN, model.Id)
|
||||
if err != nil {
|
||||
fmt.Println("Failed to update model status")
|
||||
fmt.Println(err)
|
||||
// TODO improve this response
|
||||
return c.Error500(err)
|
||||
return c.E500M("Failed to update model status", err)
|
||||
}
|
||||
|
||||
return c.SendJSON(model.Id)
|
||||
newTask := struct {
|
||||
UserId string `db:"user_id"`
|
||||
ModelId string `db:"model_id"`
|
||||
TaskType TaskType `db:"task_type"`
|
||||
Status int `db:"status"`
|
||||
}{
|
||||
UserId: c.User.Id,
|
||||
ModelId: model.Id,
|
||||
TaskType: TASK_TYPE_RETRAINING,
|
||||
Status: 1,
|
||||
}
|
||||
|
||||
id, err := InsertReturnId(c, &newTask, "tasks", "id")
|
||||
if err != nil {
|
||||
return c.E500M("Failed to create task", err)
|
||||
}
|
||||
|
||||
return c.SendJSON(JustId{Id: id})
|
||||
})
|
||||
|
||||
handle.Get("/model/epoch/update", func(c *Context) *Error {
|
||||
|
@ -63,6 +63,14 @@ func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_c
|
||||
logger.Error("Failed to tain the model", "error", err)
|
||||
}
|
||||
|
||||
back_channel <- index
|
||||
continue
|
||||
} else if task.TaskType == int(TASK_TYPE_RETRAINING) {
|
||||
logger.Info("Retraining Task")
|
||||
if err = RunTaskRetrain(base, task); err != nil {
|
||||
logger.Error("Failed to tain the model", "error", err)
|
||||
}
|
||||
|
||||
back_channel <- index
|
||||
continue
|
||||
}
|
||||
|
@ -38,6 +38,7 @@ type TaskType int
|
||||
// Task type identifiers. The values start at 1 and increase by one per entry.
const (
|
||||
// TASK_TYPE_CLASSIFICATION has the value 1.
TASK_TYPE_CLASSIFICATION TaskType = 1 + iota
|
||||
// TASK_TYPE_TRAINING has the value 2.
TASK_TYPE_TRAINING
|
||||
// TASK_TYPE_RETRAINING has the value 3.
TASK_TYPE_RETRAINING
|
||||
)
|
||||
|
||||
func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) {
|
||||
|
@ -21,6 +21,10 @@ type ServiceUser struct {
|
||||
UserId string `toml:"__user__id__"`
|
||||
}
|
||||
|
||||
// DbInfo groups the database-related settings of the configuration.
type DbInfo struct {
|
||||
// MaxConnections is decoded from the TOML key "max_connections".
// NOTE(review): presumably the cap on open database connections — confirm against the caller.
MaxConnections int `toml:"max_connections"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Hostname string
|
||||
Port int
|
||||
@ -30,6 +34,8 @@ type Config struct {
|
||||
GpuWorker WorkerConfig `toml:"Worker"`
|
||||
|
||||
ServiceUser ServiceUser `toml:"ServiceUser"`
|
||||
|
||||
DbInfo DbInfo `toml:"DB"`
|
||||
}
|
||||
|
||||
func LoadConfig() Config {
|
||||
@ -52,6 +58,9 @@ func LoadConfig() Config {
|
||||
User: "Service",
|
||||
UserId: "",
|
||||
},
|
||||
DbInfo: DbInfo{
|
||||
MaxConnections: 200,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -84,13 +93,13 @@ func (c *Config) GenerateToken(db *sql.DB) {
|
||||
Email string
|
||||
Salt string
|
||||
Password string
|
||||
UserType UserType `db:"user_type"`
|
||||
UserType UserType `db:"user_type"`
|
||||
}{
|
||||
c.ServiceUser.User,
|
||||
c.ServiceUser.User,
|
||||
"",
|
||||
"",
|
||||
User_Admin,
|
||||
User_Admin,
|
||||
}
|
||||
id, err := InsertReturnId(db, &newUser, "users", "id")
|
||||
if err != nil {
|
||||
|
4
main.go
4
main.go
@ -37,7 +37,9 @@ func main() {
|
||||
|
||||
config := LoadConfig()
|
||||
log.Info("Config loaded!", "config", config)
|
||||
config.GenerateToken(db)
|
||||
config.GenerateToken(db)
|
||||
|
||||
db.SetMaxOpenConns(config.DbInfo.MaxConnections)
|
||||
|
||||
StartRunners(db, config)
|
||||
|
||||
|
@ -33,23 +33,35 @@
|
||||
$effect(() => {
|
||||
if (data && ctx) {
|
||||
if (chart) {
|
||||
console.log('update');
|
||||
chart.data = {
|
||||
labels: data.map((a) => a.name),
|
||||
datasets: [
|
||||
{
|
||||
label: 'Training',
|
||||
data: data.map((a) => a.training)
|
||||
},
|
||||
{
|
||||
label: 'Testing',
|
||||
data: data.map((a) => a.testing)
|
||||
chart.destroy();
|
||||
chart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
labels: data.map((a) => a.name),
|
||||
datasets: [
|
||||
{
|
||||
label: 'Training',
|
||||
data: data.map((a) => a.training)
|
||||
},
|
||||
{
|
||||
label: 'Testing',
|
||||
data: data.map((a) => a.testing)
|
||||
}
|
||||
]
|
||||
},
|
||||
options: {
|
||||
animation: false,
|
||||
scales: {
|
||||
x: {
|
||||
stacked: true
|
||||
},
|
||||
y: {
|
||||
stacked: true
|
||||
}
|
||||
}
|
||||
]
|
||||
};
|
||||
chart.update('resize');
|
||||
}
|
||||
});
|
||||
} else {
|
||||
console.log('create');
|
||||
chart = new Chart(ctx, {
|
||||
type: 'bar',
|
||||
data: {
|
||||
|
@ -85,6 +85,8 @@
|
||||
Image Run
|
||||
{:else if task.type == 2}
|
||||
Model training
|
||||
{:else if task.type == 3}
|
||||
Model Re-training
|
||||
{:else}
|
||||
{task.type}
|
||||
{/if}
|
||||
@ -98,7 +100,7 @@
|
||||
width="30px"
|
||||
style="object-fit: contain;"
|
||||
/>
|
||||
{:else if [2].includes(task.type)}{:else}
|
||||
{:else if [2, 3].includes(task.type)}{:else}
|
||||
TODO Show more information {task.status}
|
||||
{/if}
|
||||
</td>
|
||||
@ -117,7 +119,7 @@
|
||||
{:else}
|
||||
-
|
||||
{/if}
|
||||
{:else if [2].includes(task.type)}{:else}
|
||||
{:else if [2, 3].includes(task.type)}{:else}
|
||||
TODO Handle {task.type}
|
||||
{/if}
|
||||
</td>
|
||||
|
Loading…
Reference in New Issue
Block a user