make models only train one at the time closes #48
This commit is contained in:
@@ -3,13 +3,16 @@ package task_runner
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
@@ -17,7 +20,7 @@ import (
|
||||
/**
|
||||
* Actually runs the code
|
||||
*/
|
||||
func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
|
||||
func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
@@ -27,36 +30,46 @@ func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Recovered in file processor", "processor id", index, "due to", r)
|
||||
logger.Error("Recovered in runner", "processor id", index, "due to", r, "stack", string(debug.Stack()))
|
||||
back_channel <- -index
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("Started up")
|
||||
|
||||
var err error
|
||||
var err error
|
||||
|
||||
base := BasePackStruct{
|
||||
Db: db,
|
||||
Logger: logger,
|
||||
}
|
||||
base := BasePackStruct{
|
||||
Db: db,
|
||||
Logger: logger,
|
||||
Host: config.Hostname,
|
||||
}
|
||||
|
||||
for task := range task_channel {
|
||||
logger.Info("Got task", "task", task)
|
||||
task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task")
|
||||
|
||||
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
|
||||
logger.Info("Classification Task")
|
||||
if err = ClassifyTask(base, task); err != nil {
|
||||
logger.Error("Classification task failed", "error", "err")
|
||||
}
|
||||
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
|
||||
logger.Info("Classification Task")
|
||||
if err = ClassifyTask(base, task); err != nil {
|
||||
logger.Error("Classification task failed", "error", err)
|
||||
}
|
||||
|
||||
back_channel <- index
|
||||
continue
|
||||
}
|
||||
back_channel <- index
|
||||
continue
|
||||
} else if task.TaskType == int(TASK_TYPE_TRAINING) {
|
||||
logger.Info("Training Task")
|
||||
if err = RunTaskTrain(base, task); err != nil {
|
||||
logger.Error("Failed to tain the model", "error", err)
|
||||
}
|
||||
|
||||
back_channel <- index
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Error("Do not know how to route task", "task", task)
|
||||
back_channel <- index
|
||||
logger.Error("Do not know how to route task", "task", task)
|
||||
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task")
|
||||
back_channel <- index
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,18 +84,24 @@ func attentionSeeker(config Config, back_channel chan int) {
|
||||
Prefix: "Runner Orchestrator Logger [Attention]",
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Attencion seeker dies", "due to", r)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("Started up")
|
||||
|
||||
t, err := time.ParseDuration(config.GpuWorker.Pulling)
|
||||
if err != nil {
|
||||
logger.Error("Failed to load", "error", err)
|
||||
return
|
||||
}
|
||||
t, err := time.ParseDuration(config.GpuWorker.Pulling)
|
||||
if err != nil {
|
||||
logger.Error("Failed to load", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for true {
|
||||
back_channel <- 0
|
||||
|
||||
time.Sleep(t)
|
||||
time.Sleep(t)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,16 +125,27 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
// One more to accomudate the Attention Seeker channel
|
||||
back_channel := make(chan int, gpu_workers+1)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Recovered in Orchestrator restarting", "due to", r)
|
||||
for x := range task_runners {
|
||||
close(task_runners[x])
|
||||
}
|
||||
close(back_channel)
|
||||
go RunnerOrchestrator(db, config)
|
||||
}
|
||||
}()
|
||||
|
||||
go attentionSeeker(config, back_channel)
|
||||
|
||||
// Start the runners
|
||||
for i := 0; i < gpu_workers; i++ {
|
||||
task_runners[i] = make(chan Task, 10)
|
||||
task_runners_used[i] = false
|
||||
go runner(db, task_runners[i], i+1, back_channel)
|
||||
go runner(config, db, task_runners[i], i+1, back_channel)
|
||||
}
|
||||
|
||||
var task_to_dispatch *Task = nil
|
||||
var task_to_dispatch *Task = nil
|
||||
|
||||
for i := range back_channel {
|
||||
|
||||
@@ -124,34 +154,35 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
task_runners_used[i-1] = false
|
||||
} else if i < 0 {
|
||||
logger.Error("Runner died! Restarting!", "runner", i)
|
||||
task_runners_used[i-1] = false
|
||||
go runner(db, task_runners[i-1], i, back_channel)
|
||||
i = int(math.Abs(float64(i)) - 1)
|
||||
task_runners_used[i] = false
|
||||
go runner(config, db, task_runners[i], i+1, back_channel)
|
||||
}
|
||||
|
||||
if task_to_dispatch == nil {
|
||||
var task Task
|
||||
err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO)
|
||||
if err != NotFoundError && err != nil{
|
||||
log.Error("Failed to get tasks from db")
|
||||
continue
|
||||
}
|
||||
if err == NotFoundError {
|
||||
task_to_dispatch = nil
|
||||
} else {
|
||||
task_to_dispatch = &task
|
||||
}
|
||||
}
|
||||
if task_to_dispatch == nil {
|
||||
var task Task
|
||||
err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO)
|
||||
if err != NotFoundError && err != nil {
|
||||
log.Error("Failed to get tasks from db")
|
||||
continue
|
||||
}
|
||||
if err == NotFoundError {
|
||||
task_to_dispatch = nil
|
||||
} else {
|
||||
task_to_dispatch = &task
|
||||
}
|
||||
}
|
||||
|
||||
if task_to_dispatch != nil {
|
||||
for i := 0; i < len(task_runners_used); i += 1 {
|
||||
if !task_runners_used[i] {
|
||||
task_runners[i] <- *task_to_dispatch
|
||||
task_runners_used[i] = true
|
||||
task_to_dispatch = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if task_to_dispatch != nil {
|
||||
for i := 0; i < len(task_runners_used); i += 1 {
|
||||
if !task_runners_used[i] {
|
||||
task_runners[i] <- *task_to_dispatch
|
||||
task_runners_used[i] = true
|
||||
task_to_dispatch = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type Task struct {
|
||||
UserConfirmed int `db:"user_confirmed" json:"user_confirmed"`
|
||||
Compacted int `db:"compacted" json:"compacted"`
|
||||
TaskType int `db:"task_type" json:"type"`
|
||||
ExtraTaskInfo string `db:"extra_task_info" json:"extra_task_info"`
|
||||
Result string `db:"result" json:"result"`
|
||||
CreatedOn time.Time `db:"created_on" json:"created"`
|
||||
}
|
||||
@@ -35,7 +36,8 @@ const (
|
||||
type TaskType int
|
||||
|
||||
const (
|
||||
TASK_TYPE_CLASSIFICATION TaskType = 1
|
||||
TASK_TYPE_CLASSIFICATION TaskType = 1 + iota
|
||||
TASK_TYPE_TRAINING
|
||||
)
|
||||
|
||||
func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) {
|
||||
|
||||
Reference in New Issue
Block a user