feat: add tasks closes #74
This commit is contained in:
123
logic/tasks/handleUpload.go
Normal file
123
logic/tasks/handleUpload.go
Normal file
@@ -0,0 +1,123 @@
|
||||
package tasks
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func handleUpload(handler *Handle) {
|
||||
handler.PostAuth("/tasks/start/image", 1, func(c *Context) *Error {
|
||||
|
||||
read_form, err := c.R.MultipartReader()
|
||||
if err != nil {
|
||||
return c.JsonBadRequest("Please provide a valid form data request!")
|
||||
}
|
||||
|
||||
var json_data string
|
||||
var file []byte
|
||||
|
||||
for {
|
||||
part, err_part := read_form.NextPart()
|
||||
if err_part == io.EOF {
|
||||
break
|
||||
} else if err_part != nil {
|
||||
return c.JsonBadRequest("Please provide a valid form data request!")
|
||||
}
|
||||
if part.FormName() == "json_data" {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(part)
|
||||
json_data = buf.String()
|
||||
}
|
||||
if part.FormName() == "file" {
|
||||
buf := new(bytes.Buffer)
|
||||
buf.ReadFrom(part)
|
||||
file = buf.Bytes()
|
||||
}
|
||||
}
|
||||
|
||||
var requestData struct {
|
||||
ModelId string `json:"id" validate:"required"`
|
||||
}
|
||||
|
||||
_err := c.ParseJson(&requestData, json_data)
|
||||
if _err != nil {
|
||||
return _err
|
||||
}
|
||||
|
||||
model, err := GetBaseModel(c.Db, requestData.ModelId)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
|
||||
switch model.Status {
|
||||
case READY:
|
||||
case READY_RETRAIN:
|
||||
case READY_ALTERATION:
|
||||
case READY_ALTERATION_FAILED:
|
||||
case READY_RETRAIN_FAILED:
|
||||
// Model can run
|
||||
|
||||
default:
|
||||
return c.SendJSONStatus(http.StatusBadRequest, "Model not in the correct status to be able to evaludate a model")
|
||||
}
|
||||
|
||||
// TODO Check if the user can use this model
|
||||
|
||||
type CreateNewTask struct {
|
||||
UserId string `db:"user_id"`
|
||||
ModelId string `db:"model_id"`
|
||||
TaskType int `db:"task_type"`
|
||||
Status int `db:"status"`
|
||||
}
|
||||
|
||||
newTask := CreateNewTask{
|
||||
UserId: c.User.Id,
|
||||
ModelId: model.Id,
|
||||
// TODO move this to an enum
|
||||
TaskType: 1,
|
||||
Status: 0,
|
||||
}
|
||||
|
||||
id, err := InsertReturnId(c, &newTask, "tasks", "id")
|
||||
if err != nil {
|
||||
return c.E500M("Error 500", err)
|
||||
}
|
||||
|
||||
save_path := path.Join("savedData", model.Id, "tasks")
|
||||
os.MkdirAll(save_path, os.ModePerm)
|
||||
|
||||
img_path := path.Join(save_path, id+"."+model.Format)
|
||||
|
||||
img_file, err := os.Create(img_path)
|
||||
if err != nil {
|
||||
if _err := UpdateTaskStatus(c,id, -1, "Failed to create the file"); _err != nil {
|
||||
c.Logger.Error("Failed to update tasks")
|
||||
}
|
||||
return c.E500M("Failed to create the file", err)
|
||||
}
|
||||
defer img_file.Close()
|
||||
img_file.Write(file)
|
||||
|
||||
if !TestImgForModel(c, model, img_path) {
|
||||
if _err := UpdateTaskStatus(c, id, -1, "The provided image is not a valid image for this model"); _err != nil {
|
||||
c.Logger.Error("Failed to update tasks")
|
||||
}
|
||||
return c.JsonBadRequest(struct {
|
||||
Message string `json:"message"`
|
||||
Id string `json:"task_id"`
|
||||
} { "Provided image does not match the model", id})
|
||||
}
|
||||
|
||||
UpdateStatus(c, "tasks", id, 1)
|
||||
|
||||
return c.SendJSON(struct {Id string `json:"id"`}{id})
|
||||
})
|
||||
}
|
||||
11
logic/tasks/index.go
Normal file
11
logic/tasks/index.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package tasks
|
||||
|
||||
import (
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func HandleTasks (handle *Handle) {
|
||||
handleUpload(handle)
|
||||
handleList(handle)
|
||||
}
|
||||
|
||||
61
logic/tasks/list.go
Normal file
61
logic/tasks/list.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package tasks
|
||||
|
||||
import (
|
||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
)
|
||||
|
||||
func handleList(handler *Handle) {
|
||||
handler.PostAuth("/tasks/list", 1, func(c *Context) *Error {
|
||||
var err error = nil
|
||||
|
||||
var requestData struct {
|
||||
ModelId string `json:"model_id"`
|
||||
Page int `json:"page"`
|
||||
}
|
||||
|
||||
if _err := c.ToJSON(&requestData); _err != nil {
|
||||
return _err
|
||||
}
|
||||
|
||||
if requestData.ModelId == "" && c.User.UserType < int(dbtypes.User_Admin) {
|
||||
return c.SendJSONStatus(400, "Please provide a model_id")
|
||||
}
|
||||
|
||||
if requestData.ModelId != "" {
|
||||
_, err := GetBaseModel(c.Db, requestData.ModelId)
|
||||
if err == ModelNotFoundError {
|
||||
return c.SendJSONStatus(404, "Model not found!")
|
||||
} else if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
var rows []*Task = nil
|
||||
|
||||
if requestData.ModelId != "" {
|
||||
rows, err = GetDbMultitple[Task](c, "tasks where model_id=$1 order by created_on desc limit 11 offset $2", requestData.ModelId, requestData.Page * 10)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
} else {
|
||||
rows, err = GetDbMultitple[Task](c, "tasks order by created_on desc limit 11 offset $1", requestData.Page * 10)
|
||||
if err != nil {
|
||||
return c.Error500(err)
|
||||
}
|
||||
}
|
||||
|
||||
max_len := min(11, len(rows))
|
||||
|
||||
c.ShowMessage = false
|
||||
return c.SendJSON(struct {
|
||||
TaskList []*Task `json:"task_list"`
|
||||
ShowNext bool `json:"show_next"`
|
||||
} {
|
||||
rows[0:max_len],
|
||||
len(rows) > 10,
|
||||
})
|
||||
})
|
||||
}
|
||||
160
logic/tasks/runner/runner.go
Normal file
160
logic/tasks/runner/runner.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package task_runner
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
)
|
||||
|
||||
/**
|
||||
* Actually runs the code
|
||||
*/
|
||||
func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: fmt.Sprintf("Runner %d", index),
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Recovered in file processor", "processor id", index, "due to", r)
|
||||
back_channel <- -index
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("Started up")
|
||||
|
||||
var err error
|
||||
|
||||
base := BasePackStruct{
|
||||
Db: db,
|
||||
Logger: logger,
|
||||
}
|
||||
|
||||
for task := range task_channel {
|
||||
logger.Info("Got task", "task", task)
|
||||
|
||||
if task.TaskType == int(TASK_TYPE_CLASSIFICATION) {
|
||||
logger.Info("Classification Task")
|
||||
if err = ClassifyTask(base, task); err != nil {
|
||||
logger.Error("Classification task failed", "error", "err")
|
||||
}
|
||||
|
||||
back_channel <- index
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
logger.Error("Do not know how to route task", "task", task)
|
||||
back_channel <- index
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tells the orcchestator to look at the task list from time to time
|
||||
*/
|
||||
func attentionSeeker(config Config, back_channel chan int) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Runner Orchestrator Logger [Attention]",
|
||||
})
|
||||
|
||||
logger.Info("Started up")
|
||||
|
||||
t, err := time.ParseDuration(config.GpuWorker.Pulling)
|
||||
if err != nil {
|
||||
logger.Error("Failed to load", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for true {
|
||||
back_channel <- 0
|
||||
|
||||
time.Sleep(t)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages what worker should to Work
|
||||
*/
|
||||
func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
TimeFormat: time.Kitchen,
|
||||
Prefix: "Runner Orchestrator Logger",
|
||||
})
|
||||
|
||||
gpu_workers := config.GpuWorker.NumberOfWorkers
|
||||
|
||||
logger.Info("Starting runners")
|
||||
|
||||
task_runners := make([]chan Task, gpu_workers)
|
||||
task_runners_used := make([]bool, gpu_workers)
|
||||
// One more to accomudate the Attention Seeker channel
|
||||
back_channel := make(chan int, gpu_workers+1)
|
||||
|
||||
go attentionSeeker(config, back_channel)
|
||||
|
||||
// Start the runners
|
||||
for i := 0; i < gpu_workers; i++ {
|
||||
task_runners[i] = make(chan Task, 10)
|
||||
task_runners_used[i] = false
|
||||
go runner(db, task_runners[i], i+1, back_channel)
|
||||
}
|
||||
|
||||
var task_to_dispatch *Task = nil
|
||||
|
||||
for i := range back_channel {
|
||||
|
||||
if i > 0 {
|
||||
logger.Info("Runner freed", "runner", i)
|
||||
task_runners_used[i-1] = false
|
||||
} else if i < 0 {
|
||||
logger.Error("Runner died! Restarting!", "runner", i)
|
||||
task_runners_used[i-1] = false
|
||||
go runner(db, task_runners[i-1], i, back_channel)
|
||||
}
|
||||
|
||||
if task_to_dispatch == nil {
|
||||
var task Task
|
||||
err := GetDBOnce(db, &task, "tasks where status=$1 limit 1", TASK_TODO)
|
||||
if err != NotFoundError && err != nil{
|
||||
log.Error("Failed to get tasks from db")
|
||||
continue
|
||||
}
|
||||
if err == NotFoundError {
|
||||
task_to_dispatch = nil
|
||||
} else {
|
||||
task_to_dispatch = &task
|
||||
}
|
||||
}
|
||||
|
||||
if task_to_dispatch != nil {
|
||||
for i := 0; i < len(task_runners_used); i += 1 {
|
||||
if !task_runners_used[i] {
|
||||
task_runners[i] <- *task_to_dispatch
|
||||
task_runners_used[i] = true
|
||||
task_to_dispatch = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func StartRunners(db *sql.DB, config Config) {
|
||||
go RunnerOrchestrator(db, config)
|
||||
}
|
||||
68
logic/tasks/utils/utils.go
Normal file
68
logic/tasks/utils/utils.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package tasks_utils
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils"
|
||||
"github.com/goccy/go-json"
|
||||
)
|
||||
|
||||
type Task struct {
|
||||
Id string `db:"id" json:"id"`
|
||||
UserId string `db:"user_id" json:"user_id"`
|
||||
ModelId string `db:"model_id" json:"model_id"`
|
||||
Status int `db:"status" json:"status"`
|
||||
StatusMessage string `db:"status_message" json:"status_message"`
|
||||
UserConfirmed int `db:"user_confirmed" json:"user_confirmed"`
|
||||
Compacted int `db:"compacted" json:"compacted"`
|
||||
TaskType int `db:"task_type" json:"type"`
|
||||
Result string `db:"result" json:"result"`
|
||||
CreatedOn time.Time `db:"created_on" json:"created"`
|
||||
}
|
||||
|
||||
type TaskStatus int
|
||||
|
||||
const (
|
||||
TASK_FAILED_RUNNING TaskStatus = -2
|
||||
TASK_FAILED_CREATION = -1
|
||||
TASK_PREPARING = 0
|
||||
TASK_TODO = 1
|
||||
TASK_PICKED_UP = 2
|
||||
TASK_RUNNING = 3
|
||||
TASK_DONE = 4
|
||||
)
|
||||
|
||||
type TaskType int
|
||||
|
||||
const (
|
||||
TASK_TYPE_CLASSIFICATION TaskType = 1
|
||||
)
|
||||
|
||||
func (t Task) UpdateStatus(base BasePack, status TaskStatus, message string) (err error) {
|
||||
return UpdateTaskStatus(base, t.Id, status, message)
|
||||
}
|
||||
|
||||
/**
|
||||
* Call the UpdateStatus function and logs on the case of failure!
|
||||
* This varient does not return any error message
|
||||
*/
|
||||
func (t Task) UpdateStatusLog(base BasePack, status TaskStatus, message string) {
|
||||
err := t.UpdateStatus(base, status, message)
|
||||
if err != nil {
|
||||
base.GetLogger().Error("Failed to update task status", "error", err, "task", t.Id)
|
||||
}
|
||||
}
|
||||
|
||||
func UpdateTaskStatus(base BasePack, id string, status TaskStatus, message string) (err error) {
|
||||
_, err = base.GetDb().Exec("update tasks set status=$1, status_message=$2 where id=$3", status, message, id)
|
||||
return
|
||||
}
|
||||
|
||||
func (t Task) SetResult(base BasePack, result any) (err error) {
|
||||
text, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = base.GetDb().Exec("update tasks set result=$1 where id=$2", text, t.Id)
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user