package task_runner

import (
	"fmt"
	"math"
	"os"
	"runtime/debug"
	"sync"
	"time"

	"github.com/charmbracelet/log"

	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/users"
	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)

/**
 * Actually runs the tasks on a local worker
 */
func runner(config Config, db db.Db, task_channel chan Task, index int, back_channel chan int) {
	logger := log.NewWithOptions(os.Stdout, log.Options{
		ReportCaller:    true,
		ReportTimestamp: true,
		TimeFormat:      time.Kitchen,
		Prefix:          fmt.Sprintf("Runner %d", index),
	})

	defer func() {
		if r := recover(); r != nil {
			logger.Error("Recovered in runner", "processor id", index, "due to", r, "stack", string(debug.Stack()))
			// A negative index signals the orchestrator that this runner died
			back_channel <- -index
		}
	}()

	logger.Info("Started up")

	base := BasePackStruct{
		Db:     db,
		Logger: logger,
		Host:   config.Hostname,
	}

	for task := range task_channel {
		logger.Info("Got task", "task", task)
		task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task")

		switch task.TaskType {
		case int(TASK_TYPE_CLASSIFICATION):
			logger.Info("Classification Task")
			if err := ClassifyTask(base, task); err != nil {
				logger.Error("Classification task failed", "error", err)
			}
		case int(TASK_TYPE_TRAINING):
			logger.Info("Training Task")
			if err := RunTaskTrain(base, task); err != nil {
				logger.Error("Failed to train the model", "error", err)
			}
		case int(TASK_TYPE_RETRAINING):
			logger.Info("Retraining Task")
			if err := RunTaskRetrain(base, task); err != nil {
				logger.Error("Failed to retrain the model", "error", err)
			}
		case int(TASK_TYPE_DELETE_USER):
			logger.Warn("User deleting Task")
			if err := DeleteUser(base, task); err != nil {
				logger.Error("Failed to delete the user", "error", err)
			}
		default:
			logger.Error("Do not know how to route task", "task", task)
			task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task")
		}

		// Tell the orchestrator this runner is free again
		back_channel <- index
	}
}

/**
 * Handles a task that has been assigned to a remote runner
 */
func handleRemoteTask(handler *Handle, base BasePack, runner_id string, task Task) {
	logger := log.NewWithOptions(os.Stdout, log.Options{
		ReportCaller:    true,
		ReportTimestamp: true,
		TimeFormat:      time.Kitchen,
		Prefix:          fmt.Sprintf("Runner pre %s", runner_id),
	})

	defer func() {
		if r := recover(); r != nil {
			logger.Error("Failed to set up the task for the remote runner", "due to", r, "stack", string(debug.Stack()))
			// TODO maybe create better failed task
			task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to setup task for runner")
		}
	}()

	err := task.UpdateStatus(base, TASK_PICKED_UP, "Runner picked up task")
	if err != nil {
		logger.Error("Failed to mark task as PICKED UP")
		return
	}

	mutex := handler.DataMap["runners_mutex"].(*sync.Mutex)
	mutex.Lock()
	defer mutex.Unlock()

	switch task.TaskType {
	case int(TASK_TYPE_TRAINING):
		if err := PrepareTraining(handler, base, task, runner_id); err != nil {
			logger.Error("Failed to prepare for training", "err", err)
		}
	case int(TASK_TYPE_CLASSIFICATION):
		runners := handler.DataMap["runners"].(map[string]interface{})
		runner := runners[runner_id].(map[string]interface{})
		runner["task"] = &task
		runners[runner_id] = runner
		handler.DataMap["runners"] = runners
	default:
		logger.Error("Not sure what to do, panicking", "taskType", task.TaskType)
		panic("not sure what to do")
	}
}
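
// Shared state used to coordinate with remote runners. The sketch below is
// inferred from the type assertions in this file; the actual registration of a
// remote runner happens elsewhere and the exact shape is an assumption:
//
//	handler.DataMap["runners_mutex"] = &sync.Mutex{}
//	handler.DataMap["runners"] = map[string]interface{}{
//		"<runner_id>": map[string]interface{}{
//			"runner_info": &Runner{}, // metadata about the remote runner, including its UserId
//			"task":        &Task{},   // the task currently assigned to it, nil when idle
//		},
//	}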
/**
 * Tells the orchestrator to look at the task list from time to time
 */
func attentionSeeker(config Config, back_channel chan int) {
	logger := log.NewWithOptions(os.Stdout, log.Options{
		ReportCaller:    true,
		ReportTimestamp: true,
		TimeFormat:      time.Kitchen,
		Prefix:          "Runner Orchestrator Logger [Attention]",
	})

	defer func() {
		if r := recover(); r != nil {
			logger.Error("Attention seeker died", "due to", r)
		}
	}()

	logger.Info("Started up")

	t, err := time.ParseDuration(config.GpuWorker.Pulling)
	if err != nil {
		logger.Error("Failed to parse the pulling interval", "error", err)
		return
	}

	for {
		back_channel <- 0
		time.Sleep(t)
	}
}

/**
 * Manages which worker should do the work
 */
func RunnerOrchestrator(db db.Db, config Config, handler *Handle) {
	logger := log.NewWithOptions(os.Stdout, log.Options{
		ReportCaller:    true,
		ReportTimestamp: true,
		TimeFormat:      time.Kitchen,
		Prefix:          "Runner Orchestrator Logger",
	})

	// Set up the shared state used to coordinate with remote runners
	handler.DataMap["runners"] = map[string]interface{}{}
	handler.DataMap["runners_mutex"] = &sync.Mutex{}

	base := BasePackStruct{
		Db:     db,
		Logger: logger,
		Host:   config.Hostname,
	}

	gpu_workers := config.GpuWorker.NumberOfWorkers

	logger.Info("Starting runners")

	task_runners := make([]chan Task, gpu_workers)
	task_runners_used := make([]bool, gpu_workers)

	// One extra slot to accommodate the attention seeker
	back_channel := make(chan int, gpu_workers+1)

	defer func() {
		if r := recover(); r != nil {
			logger.Error("Recovered in Orchestrator, restarting", "due to", r)
			for x := range task_runners {
				close(task_runners[x])
			}
			close(back_channel)
			go RunnerOrchestrator(db, config, handler)
		}
	}()

	go attentionSeeker(config, back_channel)

	// Start the local runners
	for i := 0; i < gpu_workers; i++ {
		task_runners[i] = make(chan Task, 10)
		task_runners_used[i] = false
		go runner(config, db, task_runners[i], i+1, back_channel)
	}

	var task_to_dispatch *Task = nil
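
	// back_channel protocol (matches what runner and attentionSeeker send):
	//   i > 0  -> runner i finished its task and is free again
	//   i < 0  -> runner -i panicked and has to be restarted
	//   i == 0 -> the attention seeker asking for the task list to be re-checked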
	for i := range back_channel {
		if i > 0 {
			logger.Info("Runner freed", "runner", i)
			task_runners_used[i-1] = false
		} else if i < 0 {
			logger.Error("Runner died! Restarting!", "runner", i)
			i = int(math.Abs(float64(i)) - 1)
			task_runners_used[i] = false
			go runner(config, db, task_runners[i], i+1, back_channel)
		}

		if task_to_dispatch == nil {
			var task TaskT
			err := GetDBOnce(db, &task, "tasks as t "+
				// Join the task's dependencies
				"left join tasks_dependencies as td on t.id=td.main_id "+
				// Get the task that the dependency resolves to
				"left join tasks as t2 on t2.id=td.dependent_id "+
				"where t.status=1 "+
				"group by t.id having count(td.id) filter (where t2.status in (0,1,2,3)) = 0;")
			if err != NotFoundError && err != nil {
				logger.Error("Failed to get tasks from db", "err", err)
				continue
			}
			if err == NotFoundError {
				task_to_dispatch = nil
			} else {
				temp := Task(task)
				task_to_dispatch = &temp
			}
		}

		if task_to_dispatch != nil {
			// Only let CPU-bound tasks (user deletion) be handled by the local runners
			if task_to_dispatch.TaskType == int(TASK_TYPE_DELETE_USER) {
				for i := 0; i < len(task_runners_used); i += 1 {
					if !task_runners_used[i] {
						task_runners[i] <- *task_to_dispatch
						task_runners_used[i] = true
						task_to_dispatch = nil
						break
					}
				}
				continue
			}

			mutex := handler.DataMap["runners_mutex"].(*sync.Mutex)
			mutex.Lock()
			remote_runners := handler.DataMap["runners"].(map[string]interface{})

			for k, v := range remote_runners {
				runner_data := v.(map[string]interface{})
				runner_info := runner_data["runner_info"].(*Runner)

				// Skip runners that are already busy
				if runner_data["task"] != nil {
					continue
				}

				// Dispatch only to a remote runner owned by the same user as the task
				if runner_info.UserId == task_to_dispatch.UserId {
					go handleRemoteTask(handler, base, k, *task_to_dispatch)
					task_to_dispatch = nil
					break
				}
			}

			mutex.Unlock()
		}
	}
}

/**
 * Starts the orchestrator in its own goroutine
 */
func StartRunners(db db.Db, config Config, handler *Handle) {
	go RunnerOrchestrator(db, config, handler)
}
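
// Usage sketch (illustrative only; db, config and handler are assumed to come
// from the application's own setup code, which is not part of this package):
//
//	StartRunners(db, config, handler)
//
// StartRunners returns immediately: the orchestrator, the attention seeker and
// the local runners all run in their own goroutines.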