package models_train import ( "errors" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" "github.com/charmbracelet/log" "github.com/goccy/go-json" ) func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (err error) { l := b.GetLogger() model, err := GetBaseModel(b.GetDb(), *task.ModelId) if err != nil { task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information") l.Error("Failed to get model information", "err", err) return err } if model.Status != TRAINING { task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training") return errors.New("Model not in the right status") } // TODO do this when the runner says it's OK //task.UpdateStatusLog(b, TASK_RUNNING, "Training model") // TODO move this to the runner part as well var dat struct { NumberOfModels int Accuracy int } err = json.Unmarshal([]byte(task.ExtraTaskInfo), &dat) if err != nil { task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model extra information") } if model.ModelType == 2 { full_error := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) if full_error != nil { l.Error("Failed to generate defintions", "err", full_error) task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model") return errors.New("Failed to generate definitions") } } else { error := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels) if error != nil { task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model") return errors.New("Failed to generate definitions") } } runners := handler.DataMap["runners"].(map[string]interface{}) runner := runners[runner_id].(map[string]interface{}) runner["task"] = &task runners[runner_id] = runner handler.DataMap["runners"] = runners return } func CleanUpFailed(b BasePack, task *Task) { db := b.GetDb() l := b.GetLogger() model, err := GetBaseModel(db, *task.ModelId) if err != nil { l.Error("Failed to get model", "err", err) } else { err = model.UpdateStatus(db, FAILED_TRAINING) if err != nil { l.Error("Failed to get status", err) } } // Set the class status to trained err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING) if err != nil { l.Error("Failed to set class status") return } } func CleanUpFailedRetrain(b BasePack, task *Task) { db := b.GetDb() l := b.GetLogger() model, err := GetBaseModel(db, *task.ModelId) if err != nil { l.Error("Failed to get model", "err", err) } else { err = model.UpdateStatus(db, FAILED_TRAINING) if err != nil { l.Error("Failed to get status", err) } } ResetClasses(b, model) ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED) var defData struct { Id string `db:"md.id"` TargetAcuuracy float64 `db:"md.target_accuracy"` } err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", task.ModelId) if err != nil { log.Error("failed to get def data", err) return } _, err_ := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id) if err_ != nil { panic(err_) } }