193 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			193 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package task_runner
 | |
| 
 | |
| import (
 | |
| 	"database/sql"
 | |
| 	"fmt"
 | |
| 	"math"
 | |
| 	"os"
 | |
| 	"runtime/debug"
 | |
| 	"time"
 | |
| 
 | |
| 	"github.com/charmbracelet/log"
 | |
| 
 | |
| 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 | |
| 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
 | |
| 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
 | |
| 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
 | |
| 	. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
 | |
| )
 | |
| 
 | |
| /**
 | |
| * Actually runs the code
 | |
|  */
 | |
| func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
 | |
| 	logger := log.NewWithOptions(os.Stdout, log.Options{
 | |
| 		ReportCaller:    true,
 | |
| 		ReportTimestamp: true,
 | |
| 		TimeFormat:      time.Kitchen,
 | |
| 		Prefix:          fmt.Sprintf("Runner %d", index),
 | |
| 	})
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			logger.Error("Recovered in runner", "processor id", index, "due to", r, "stack", string(debug.Stack()))
 | |
| 			back_channel <- -index
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	logger.Info("Started up")
 | |
| 
 | |
| 	var err error
 | |
| 
 | |
| 	base := BasePackStruct{
 | |
| 		Db:     db,
 | |
| 		Logger: logger,
 | |
| 		Host:   config.Hostname,
 | |
| 	}
 | |
| 
 | |
| 	for task := range task_channel {
 | |
| 		logger.Info("Got task", "task", task)
 | |
| 		task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task")
 | |
| 
 | |
| 		if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
 | |
| 			logger.Info("Classification Task")
 | |
| 			if err = ClassifyTask(base, task); err != nil {
 | |
| 				logger.Error("Classification task failed", "error", err)
 | |
| 			}
 | |
| 
 | |
| 			back_channel <- index
 | |
| 			continue
 | |
| 		} else if task.TaskType == int(TASK_TYPE_TRAINING) {
 | |
| 			logger.Info("Training Task")
 | |
| 			if err = RunTaskTrain(base, task); err != nil {
 | |
| 				logger.Error("Failed to tain the model", "error", err)
 | |
| 			}
 | |
| 
 | |
| 			back_channel <- index
 | |
| 			continue
 | |
| 		}
 | |
| 
 | |
| 		logger.Error("Do not know how to route task", "task", task)
 | |
| 		task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task")
 | |
| 		back_channel <- index
 | |
| 	}
 | |
| }
 | |
| 
 | |
| /**
 | |
| * Tells the orcchestator to look at the task list from time to time
 | |
|  */
 | |
| func attentionSeeker(config Config, back_channel chan int) {
 | |
| 	logger := log.NewWithOptions(os.Stdout, log.Options{
 | |
| 		ReportCaller:    true,
 | |
| 		ReportTimestamp: true,
 | |
| 		TimeFormat:      time.Kitchen,
 | |
| 		Prefix:          "Runner Orchestrator Logger [Attention]",
 | |
| 	})
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			logger.Error("Attencion seeker dies", "due to", r)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	logger.Info("Started up")
 | |
| 
 | |
| 	t, err := time.ParseDuration(config.GpuWorker.Pulling)
 | |
| 	if err != nil {
 | |
| 		logger.Error("Failed to load", "error", err)
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	for true {
 | |
| 		back_channel <- 0
 | |
| 
 | |
| 		time.Sleep(t)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| /**
 | |
| * Manages what worker should to Work
 | |
|  */
 | |
| func RunnerOrchestrator(db *sql.DB, config Config) {
 | |
| 	logger := log.NewWithOptions(os.Stdout, log.Options{
 | |
| 		ReportCaller:    true,
 | |
| 		ReportTimestamp: true,
 | |
| 		TimeFormat:      time.Kitchen,
 | |
| 		Prefix:          "Runner Orchestrator Logger",
 | |
| 	})
 | |
| 
 | |
| 	gpu_workers := config.GpuWorker.NumberOfWorkers
 | |
| 
 | |
| 	logger.Info("Starting runners")
 | |
| 
 | |
| 	task_runners := make([]chan Task, gpu_workers)
 | |
| 	task_runners_used := make([]bool, gpu_workers)
 | |
| 	// One more to accomudate the Attention Seeker channel
 | |
| 	back_channel := make(chan int, gpu_workers+1)
 | |
| 
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			logger.Error("Recovered in Orchestrator restarting", "due to", r)
 | |
| 			for x := range task_runners {
 | |
| 				close(task_runners[x])
 | |
| 			}
 | |
| 			close(back_channel)
 | |
| 			go RunnerOrchestrator(db, config)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	go attentionSeeker(config, back_channel)
 | |
| 
 | |
| 	// Start the runners
 | |
| 	for i := 0; i < gpu_workers; i++ {
 | |
| 		task_runners[i] = make(chan Task, 10)
 | |
| 		task_runners_used[i] = false
 | |
| 		go runner(config, db, task_runners[i], i+1, back_channel)
 | |
| 	}
 | |
| 
 | |
| 	var task_to_dispatch *Task = nil
 | |
| 
 | |
| 	for i := range back_channel {
 | |
| 
 | |
| 		if i > 0 {
 | |
| 			logger.Info("Runner freed", "runner", i)
 | |
| 			task_runners_used[i-1] = false
 | |
| 		} else if i < 0 {
 | |
| 			logger.Error("Runner died! Restarting!", "runner", i)
 | |
| 			i = int(math.Abs(float64(i)) - 1)
 | |
| 			task_runners_used[i] = false
 | |
| 			go runner(config, db, task_runners[i], i+1, back_channel)
 | |
| 		}
 | |
| 
 | |
| 		if task_to_dispatch == nil {
 | |
| 			var task Task
 | |
| 			err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO)
 | |
| 			if err != NotFoundError && err != nil {
 | |
| 				log.Error("Failed to get tasks from db")
 | |
| 				continue
 | |
| 			}
 | |
| 			if err == NotFoundError {
 | |
| 				task_to_dispatch = nil
 | |
| 			} else {
 | |
| 				task_to_dispatch = &task
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if task_to_dispatch != nil {
 | |
| 			for i := 0; i < len(task_runners_used); i += 1 {
 | |
| 				if !task_runners_used[i] {
 | |
| 					task_runners[i] <- *task_to_dispatch
 | |
| 					task_runners_used[i] = true
 | |
| 					task_to_dispatch = nil
 | |
| 					break
 | |
| 				}
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func StartRunners(db *sql.DB, config Config) {
 | |
| 	go RunnerOrchestrator(db, config)
 | |
| }
 |