fyp/logic/models/train/remote_train.go
2024-05-09 00:46:42 +01:00

119 lines
3.3 KiB
Go

package models_train
import (
"errors"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/charmbracelet/log"
"github.com/goccy/go-json"
)
// PrepareTraining validates that the task's model is in the TRAINING state,
// generates the model definitions described by the task's ExtraTaskInfo
// (JSON with NumberOfModels and Accuracy), and then attaches the task to the
// runner identified by runner_id in handler.DataMap["runners"].
//
// Every failure before the runner assignment marks the task as
// TASK_FAILED_RUNNING and returns a non-nil error. If the runner cannot be
// found, an error is logged and returned but the task status is left as-is.
// The task is NOT moved to TASK_RUNNING here (see the TODOs below).
func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (err error) {
	l := b.GetLogger()

	model, err := GetBaseModel(b.GetDb(), *task.ModelId)
	if err != nil {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
		l.Error("Failed to get model information", "err", err)
		return err
	}

	if model.Status != TRAINING {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training")
		return errors.New("Model not in the right status")
	}

	// TODO do this when the runner says it's OK
	//task.UpdateStatusLog(b, TASK_RUNNING, "Training model")

	// TODO move this to the runner part as well
	var dat struct {
		NumberOfModels int
		Accuracy       int
	}
	if err = json.Unmarshal([]byte(task.ExtraTaskInfo), &dat); err != nil {
		// Without valid parameters the definitions would be generated from
		// zero values, so stop instead of continuing with a failed task.
		l.Error("Failed to parse task extra information", "err", err)
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model extra information")
		return err
	}

	// ModelType 2 uses the expandable definition generator; every other type
	// uses the plain one.
	if model.ModelType == 2 {
		if genErr := generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels); genErr != nil {
			l.Error("Failed to generate definitions", "err", genErr)
			task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
			return errors.New("Failed to generate definitions")
		}
	} else {
		if genErr := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels); genErr != nil {
			l.Error("Failed to generate definitions", "err", genErr)
			task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
			return errors.New("Failed to generate definitions")
		}
	}

	// Checked assertions: a missing or malformed runner entry must not panic.
	runners, ok := handler.DataMap["runners"].(map[string]interface{})
	if !ok {
		l.Error("Runner registry is missing or has an unexpected type")
		return errors.New("Runners not available")
	}
	runner, ok := runners[runner_id].(map[string]interface{})
	if !ok {
		l.Error("Runner not found", "runner_id", runner_id)
		return errors.New("Runner not found")
	}

	// runner is a map (reference type), so this write is visible through
	// handler.DataMap without re-storing runner or runners.
	runner["task"] = &task
	return nil
}
// CleanUpFailed rolls back the state left behind by a failed training task.
// It marks the task's model as FAILED_TRAINING and moves every class of that
// model that is in CLASS_STATUS_TRAINING back to CLASS_STATUS_TO_TRAIN.
// Errors are logged, not returned.
func CleanUpFailed(b BasePack, task *Task) {
	db := b.GetDb()
	l := b.GetLogger()

	model, err := GetBaseModel(db, *task.ModelId)
	if err != nil {
		// The model value is not usable on error, so there is nothing to
		// roll back against; stop instead of dereferencing it below.
		l.Error("Failed to get model", "err", err)
		return
	}

	// A failed status update should not prevent the class rollback below.
	if err = model.UpdateStatus(db, FAILED_TRAINING); err != nil {
		l.Error("Failed to update model status", "err", err)
	}

	// Return the classes that were being trained to the "to train" state.
	err = SetModelClassStatus(b, CLASS_STATUS_TO_TRAIN, "model_id=$1 and status=$2;", model.Id, CLASS_STATUS_TRAINING)
	if err != nil {
		l.Error("Failed to set class status", "err", err)
	}
}
// CleanUpFailedRetrain rolls back the state left behind by a failed retrain
// task. It updates the model status, resets its classes, marks the model
// READY_RETRAIN_FAILED, and deletes the exp_model_head rows with status 2 or 3
// that belong to the model's definition.
// Errors are logged, not returned; cleanup is best-effort.
func CleanUpFailedRetrain(b BasePack, task *Task) {
	db := b.GetDb()
	l := b.GetLogger()

	model, err := GetBaseModel(db, *task.ModelId)
	if err != nil {
		// The model value is not usable on error; stop instead of
		// dereferencing it in ResetClasses/ModelUpdateStatus below.
		l.Error("Failed to get model", "err", err)
		return
	}

	// NOTE(review): this status is overwritten below by READY_RETRAIN_FAILED;
	// confirm whether this intermediate update is still needed.
	if err = model.UpdateStatus(db, FAILED_TRAINING); err != nil {
		l.Error("Failed to update model status", "err", err)
	}

	ResetClasses(b, model)
	ModelUpdateStatus(b, model.Id, READY_RETRAIN_FAILED)

	var defData struct {
		Id             string  `db:"md.id"`
		TargetAcuuracy float64 `db:"md.target_accuracy"`
	}
	err = GetDBOnce(db, &defData, "models as m inner join model_definition as md on m.id = md.model_id where m.id=$1;", *task.ModelId)
	if err != nil {
		log.Error("failed to get def data", "err", err)
		return
	}

	// Drop the heads created by the failed retrain. The meaning of statuses
	// 2 and 3 is defined elsewhere — presumably in-progress/failed heads; confirm.
	if _, execErr := db.Exec("delete from exp_model_head where def_id=$1 and status in (2,3)", defData.Id); execErr != nil {
		// Cleanup runs after a failure already happened; crashing here would
		// only lose more state, so log instead of panicking.
		l.Error("failed to delete exp_model_head rows", "err", execErr, "def_id", defData.Id)
	}
}