79 lines
2.2 KiB
Go
79 lines
2.2 KiB
Go
package models_train
|
|
|
|
import (
|
|
"errors"
|
|
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
|
"github.com/goccy/go-json"
|
|
)
|
|
|
|
// PrepareTraining validates the task's model, generates its training
// definitions and attaches the task to the runner identified by runner_id.
//
// The model referenced by task.ModelId must be in the TRAINING status.
// task.ExtraTaskInfo is decoded as JSON into NumberOfModels and Accuracy,
// which are forwarded to the definition generator. On success the task is
// stored under the "task" key of handler.DataMap["runners"][runner_id].
//
// Every failure path marks the task TASK_FAILED_RUNNING and returns a
// non-nil error; on success nil is returned.
func PrepareTraining(handler *Handle, b BasePack, task Task, runner_id string) (err error) {
	l := b.GetLogger()

	if task.ModelId == nil {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
		l.Error("Task has no model id")
		return errors.New("task has no model id")
	}

	model, err := GetBaseModel(b.GetDb(), *task.ModelId)
	if err != nil {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model information")
		l.Error("Failed to get model information", "err", err)
		return err
	}

	if model.Status != TRAINING {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Model not in the correct status for training")
		return errors.New("Model not in the right status")
	}

	// TODO do this when the runner says it's OK
	//task.UpdateStatusLog(b, TASK_RUNNING, "Training model")

	// TODO move this to the runner part as well
	var dat struct {
		NumberOfModels int
		Accuracy       int
	}

	if err = json.Unmarshal([]byte(task.ExtraTaskInfo), &dat); err != nil {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to get model extra information")
		l.Error("Failed to get model extra information", "err", err)
		// Stop here: continuing would generate definitions from zero values.
		return err
	}

	if model.ModelType == 2 {
		// The expandable training path is not implemented yet. Fail the task
		// instead of panicking, which would take down the whole process.
		// TODO: call generateExpandableDefinitions(b, model, dat.Accuracy, dat.NumberOfModels)
		// once that path is ready.
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Expandable models are not supported yet")
		return errors.New("expandable models are not supported yet")
	}

	// A fresh local (not the named result) keeps whatever concrete error type
	// generateDefinitions returns from being boxed into a non-nil interface.
	if genErr := generateDefinitions(b, model, dat.Accuracy, dat.NumberOfModels); genErr != nil {
		l.Error("Failed to generate definitions", "err", genErr)
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed generate model")
		return errors.New("Failed to generate definitions")
	}

	// Checked assertions: a missing or malformed runner entry must fail the
	// task rather than panic.
	runners, ok := handler.DataMap["runners"].(map[string]interface{})
	if !ok {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to assign task to runner")
		l.Error("Runners map is missing or has an unexpected type")
		return errors.New("runners map is missing or has an unexpected type")
	}

	runner, ok := runners[runner_id].(map[string]interface{})
	if !ok {
		task.UpdateStatusLog(b, TASK_FAILED_RUNNING, "Failed to assign task to runner")
		l.Error("Runner not found", "runner_id", runner_id)
		return errors.New("runner not found")
	}

	// Maps are reference types, so this write is visible through
	// handler.DataMap["runners"] without re-storing either map.
	runner["task"] = &task

	return nil
}
|
|
|
|
func CleanUpFailed(b BasePack, task *Task) {
|
|
db := b.GetDb()
|
|
l := b.GetLogger()
|
|
model, err := GetBaseModel(db, *task.ModelId)
|
|
if err != nil {
|
|
l.Error("Failed to get model", "err", err)
|
|
} else {
|
|
err = model.UpdateStatus(db, FAILED_TRAINING)
|
|
if err != nil {
|
|
l.Error("Failed to get status", err)
|
|
}
|
|
}
|
|
}
|