119 lines
3.3 KiB
Go
119 lines
3.3 KiB
Go
|
package models_train
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
|
||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||
|
"github.com/charmbracelet/log"
|
||
|
"github.com/goccy/go-json"
|
||
|
)
|
||
|
|
||
|
func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (err error) {
|
||
|
l := b.GetLogger()
|
||
|
|
||
|
model, err := GetBaseModel(b.GetDb(), *task.ModelId)
|
||
|
if err != nil {
|
||
|
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
|
||
|
l.Error("Failed to get model information", "err", err)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if model.Status != TRAINING {
|
||
|
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training")
|
||
|
return errors.New("Model not in the right status")
|
||
|
}
|
||
|
|
||
|
// TODO do this when the runner says it's OK
|
||
|
//task.UpdateStatusLog(b, TASK_RUNNING, "Training model")
|
||
|
|
||
|
// TODO move this to the runner part as well
|
||
|
var dat struct {
|
||
|
NumberOfModels int
|
||
|
Accuracy int
|
||
|
}
|
||
|
|
||
|
err = json.Unmarshal([]byte(task.ExtraTaskInfo), &dat)
|
||
|
if err != nil {
|
||
|
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model extra information")
|
||
|
}
|
||
|
|
||
|
if model.ModelType == 2 {
|
||
|
full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
|
||
|
if full_error != nil {
|
||
|
l.Error("Failed to generate defintions", "err", full_error)
|
||
|
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
|
||
|
return errors.New("Failed to generate definitions")
|
||
|
}
|
||
|
} else {
|
||
|
error := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
|
||
|
if error != nil {
|
||
|
task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
|
||
|
return errors.New("Failed to generate definitions")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
runners := handler.DataMap["runners"].(map[string]interface{})
|
||
|
runner := runners[runner_id].(map[string]interface{})
|
||
|
runner["task"] = &task
|
||
|
runners[runner_id] = runner
|
||
|
handler.DataMap["runners"] = runners
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func CleanUpFailed(b BasePack, task *Task) {
|
||
|
db := b.GetDb()
|
||
|
l := b.GetLogger()
|
||
|
model, err := GetBaseModel(db, *task.ModelId)
|
||
|
if err != nil {
|
||
|
l.Error("Failed to get model", "err", err)
|
||
|
} else {
|
||
|
err = model.UpdateStatus(db, FAILED_TRAINING)
|
||
|
if err != nil {
|
||
|
l.Error("Failed to get status", err)
|
||
|
}
|
||
|
}
|
||
|
// Set the class status to trained
|
||
|
err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
|
||
|
if err != nil {
|
||
|
l.Error("Failed to set class status")
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func CleanUpFailedRetrain(b BasePack, task *Task) {
|
||
|
db := b.GetDb()
|
||
|
l := b.GetLogger()
|
||
|
|
||
|
model, err := GetBaseModel(db, *task.ModelId)
|
||
|
if err != nil {
|
||
|
l.Error("Failed to get model", "err", err)
|
||
|
} else {
|
||
|
err = model.UpdateStatus(db, FAILED_TRAINING)
|
||
|
if err != nil {
|
||
|
l.Error("Failed to get status", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
ResetClasses(b, model)
|
||
|
ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)
|
||
|
|
||
|
var defData struct {
|
||
|
Id string `db:"md.id"`
|
||
|
TargetAcuuracy float64 `db:"md.target_accuracy"`
|
||
|
}
|
||
|
|
||
|
err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId)
|
||
|
if err != nil {
|
||
|
log.Error("failed to get def data", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
_, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id)
|
||
|
if err_ != nil {
|
||
|
panic(err_)
|
||
|
}
|
||
|
}
|