make models only train one at the time closes #48

This commit is contained in:
2024-04-15 23:04:53 +01:00
parent 2318bad5d8
commit f4e70d7a73
11 changed files with 607 additions and 475 deletions

View File

@@ -19,11 +19,17 @@ import (
type BasePack interface {
GetDb() *sql.DB
GetLogger() *log.Logger
GetHost() string
}
type BasePackStruct struct {
Db *sql.DB
Logger *log.Logger
Host string
}
func (b BasePackStruct) GetHost() string {
return b.Host
}
func (b BasePackStruct) GetDb() *sql.DB {

View File

@@ -5,7 +5,6 @@ import (
"errors"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
type ModelClass struct {
@@ -15,8 +14,8 @@ type ModelClass struct {
Status int `json:"status"`
}
func ListClasses(c *Context, model_id string) (cls []*ModelClass, err error) {
return GetDbMultitple[ModelClass](c, "model_classes where model_id=$1", model_id)
func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) {
return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id)
}
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {

File diff suppressed because it is too large Load Diff

View File

@@ -3,13 +3,16 @@ package task_runner
import (
"database/sql"
"fmt"
"math"
"os"
"runtime/debug"
"time"
"github.com/charmbracelet/log"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
)
@@ -17,7 +20,7 @@ import (
/**
* Actually runs the code
*/
func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true,
ReportTimestamp: true,
@@ -27,36 +30,46 @@ func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int
defer func() {
if r := recover(); r != nil {
logger.Error("Recovered in file processor", "processor id", index, "due to", r)
logger.Error("Recovered in runner", "processor id", index, "due to", r, "stack", string(debug.Stack()))
back_channel <- -index
}
}()
logger.Info("Started up")
var err error
var err error
base := BasePackStruct{
Db: db,
Logger: logger,
}
base := BasePackStruct{
Db: db,
Logger: logger,
Host: config.Hostname,
}
for task := range task_channel {
logger.Info("Got task", "task", task)
task.UpdateStatusLog(base, TASK_PICKED_UP, "Runner picked up task")
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
logger.Info("Classification Task")
if err = ClassifyTask(base, task); err != nil {
logger.Error("Classification task failed", "error", "err")
}
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
logger.Info("Classification Task")
if err = ClassifyTask(base, task); err != nil {
logger.Error("Classification task failed", "error", err)
}
back_channel <- index
continue
}
back_channel <- index
continue
} else if task.TaskType == int(TASK_TYPE_TRAINING) {
logger.Info("Training Task")
if err = RunTaskTrain(base, task); err != nil {
logger.Error("Failed to tain the model", "error", err)
}
back_channel <- index
continue
}
logger.Error("Do not know how to route task", "task", task)
back_channel <- index
logger.Error("Do not know how to route task", "task", task)
task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Do not know how to route task")
back_channel <- index
}
}
@@ -71,18 +84,24 @@ func attentionSeeker(config Config, back_channel chan int) {
Prefix: "Runner Orchestrator Logger [Attention]",
})
defer func() {
if r := recover(); r != nil {
logger.Error("Attencion seeker dies", "due to", r)
}
}()
logger.Info("Started up")
t, err := time.ParseDuration(config.GpuWorker.Pulling)
if err != nil {
logger.Error("Failed to load", "error", err)
return
}
t, err := time.ParseDuration(config.GpuWorker.Pulling)
if err != nil {
logger.Error("Failed to load", "error", err)
return
}
for true {
back_channel <- 0
time.Sleep(t)
time.Sleep(t)
}
}
@@ -106,16 +125,27 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
// One more to accomudate the Attention Seeker channel
back_channel := make(chan int, gpu_workers+1)
defer func() {
if r := recover(); r != nil {
logger.Error("Recovered in Orchestrator restarting", "due to", r)
for x := range task_runners {
close(task_runners[x])
}
close(back_channel)
go RunnerOrchestrator(db, config)
}
}()
go attentionSeeker(config, back_channel)
// Start the runners
for i := 0; i < gpu_workers; i++ {
task_runners[i] = make(chan Task, 10)
task_runners_used[i] = false
go runner(db, task_runners[i], i+1, back_channel)
go runner(config, db, task_runners[i], i+1, back_channel)
}
var task_to_dispatch *Task = nil
var task_to_dispatch *Task = nil
for i := range back_channel {
@@ -124,34 +154,35 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
task_runners_used[i-1] = false
} else if i < 0 {
logger.Error("Runner died! Restarting!", "runner", i)
task_runners_used[i-1] = false
go runner(db, task_runners[i-1], i, back_channel)
i = int(math.Abs(float64(i)) - 1)
task_runners_used[i] = false
go runner(config, db, task_runners[i], i+1, back_channel)
}
if task_to_dispatch == nil {
var task Task
err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO)
if err != NotFoundError && err != nil{
log.Error("Failed to get tasks from db")
continue
}
if err == NotFoundError {
task_to_dispatch = nil
} else {
task_to_dispatch = &task
}
}
if task_to_dispatch == nil {
var task Task
err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO)
if err != NotFoundError && err != nil {
log.Error("Failed to get tasks from db")
continue
}
if err == NotFoundError {
task_to_dispatch = nil
} else {
task_to_dispatch = &task
}
}
if task_to_dispatch != nil {
for i := 0; i < len(task_runners_used); i += 1 {
if !task_runners_used[i] {
task_runners[i] <- *task_to_dispatch
task_runners_used[i] = true
task_to_dispatch = nil
break
}
}
}
if task_to_dispatch != nil {
for i := 0; i < len(task_runners_used); i += 1 {
if !task_runners_used[i] {
task_runners[i] <- *task_to_dispatch
task_runners_used[i] = true
task_to_dispatch = nil
break
}
}
}
}
}

View File

@@ -16,6 +16,7 @@ type Task struct {
UserConfirmed int `db:"user_confirmed" json:"user_confirmed"`
Compacted int `db:"compacted" json:"compacted"`
TaskType int `db:"task_type" json:"type"`
ExtraTaskInfo string `db:"extra_task_info" json:"extra_task_info"`
Result string `db:"result" json:"result"`
CreatedOn time.Time `db:"created_on" json:"created"`
}
@@ -35,7 +36,8 @@ const (
type TaskType int
const (
TASK_TYPE_CLASSIFICATION TaskType = 1
TASK_TYPE_CLASSIFICATION TaskType = 1 + iota
TASK_TYPE_TRAINING
)
func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) {

View File

@@ -199,6 +199,10 @@ func (c Context) GetLogger() *log.Logger {
return c.Logger
}
func (c Context) GetHost() string {
return c.Handle.Config.Hostname
}
func (c Context) Query(query string, args ...any) (*sql.Rows, error) {
return c.Db.Query(query, args...)
}
@@ -337,11 +341,11 @@ func (c *Context) GetModelFromId(id_path string) (*dbtypes.BaseModel, *Error) {
return model, nil
}
func ModelUpdateStatus(c *Context, id string, status int) {
_, err := c.Db.Exec("update models set status=$1 where id=$2;", status, id)
func ModelUpdateStatus(c dbtypes.BasePack, id string, status int) {
_, err := c.GetDb().Exec("update models set status=$1 where id=$2;", status, id)
if err != nil {
c.Logger.Error("Failed to update model status", "err", err)
c.Logger.Warn("TODO Maybe handle better")
c.GetLogger().Error("Failed to update model status", "err", err)
c.GetLogger().Warn("TODO Maybe handle better")
}
}