fyp/logic/models/train/remote_train.go

80 lines
2.2 KiB
Go
Raw Permalink Normal View History

2024-05-06 01:10:58 +01:00
package models_train
import (
"errors"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
"github.com/goccy/go-json"
)
// PrepareTraining validates that the task's model is ready to be trained,
// generates the model definitions from the task's extra info and attaches the
// task to the runner identified by runner_id in handler.DataMap["runners"].
//
// On any failure the task status is set to TASK_FAILED_RUNNING and a non-nil
// error is returned. The task is only handed to the runner when every
// preceding step succeeded.
func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (err error) {
	l := b.GetLogger()

	if task.ModelId == nil {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
		return errors.New("task has no model id")
	}

	model, err := GetBaseModel(b.GetDb(), *task.ModelId)
	if err != nil {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
		l.Error("Failed to get model information", "err", err)
		return err
	}

	if model.Status != TRAINING {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training")
		return errors.New("Model not in the right status")
	}

	// TODO do this when the runner says it's OK
	//task.UpdateStatusLog(b, TASK_RUNNING, "Training model")

	// TODO move this to the runner part as well
	var dat struct {
		NumberOfModels int
		Accuracy       int
	}
	if err = json.Unmarshal([]byte(task.ExtraTaskInfo), &dat); err != nil {
		// Without the extra info the definitions would be generated from zero
		// values, so stop here rather than continue with bogus parameters.
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model extra information")
		l.Error("Failed to get model extra information", "err", err)
		return err
	}

	if model.ModelType == 2 {
		// Expandable models are not supported by remote training yet. Report
		// this as a task failure instead of panicking and crashing the caller.
		// TODO: call generateExpandableDefinitions(b, model, dat.Accuracy,
		// dat.NumberOfModels) once remote training supports it.
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Remote training of expandable models is not supported yet")
		return errors.New("remote training of expandable models is not supported yet")
	}

	// A fresh local keeps whatever concrete type generateDefinitions returns.
	// It also avoids shadowing the builtin `error`.
	if genErr := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels); genErr != nil {
		l.Error("Failed to generate definitions", "err", genErr)
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
		return errors.New("Failed to generate definitions")
	}

	// Use comma-ok assertions so a missing runner map or an unknown runner id
	// becomes an error instead of a panic.
	runners, ok := handler.DataMap["runners"].(map[string]interface{})
	if !ok {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to find runner")
		return errors.New("no runners registered")
	}
	runner, ok := runners[runner_id].(map[string]interface{})
	if !ok {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to find runner")
		return errors.New("runner not found: " + runner_id)
	}

	// Maps are reference types, so this write is already visible through
	// handler.DataMap; no need to store the maps back.
	runner["task"] = &task
	return nil
}
// CleanUpFailed marks the model associated with task as FAILED_TRAINING.
//
// It is best-effort: the function has no error result, so any problem
// (missing task/model id, model lookup failure, status update failure) is
// logged and cleanup stops.
func CleanUpFailed(b BasePack, task *Task) {
	db := b.GetDb()
	l := b.GetLogger()

	// Guard against panicking while cleaning up after an earlier failure.
	if task == nil || task.ModelId == nil {
		l.Error("Cannot clean up failed task without a model id")
		return
	}

	model, err := GetBaseModel(db, *task.ModelId)
	if err != nil {
		l.Error("Failed to get model", "err", err)
		return
	}

	// Log the error as a key/value pair so structured loggers keep it intact.
	if err = model.UpdateStatus(db, FAILED_TRAINING); err != nil {
		l.Error("Failed to update model status", "err", err)
	}
}