168 lines
4.0 KiB
Go
168 lines
4.0 KiB
Go
package utils
|
|
|
|
import (
|
|
"database/sql"
|
|
"os"
|
|
"strings"
|
|
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
|
|
|
toml "github.com/BurntSushi/toml"
|
|
"github.com/charmbracelet/log"
|
|
)
|
|
|
|
// WorkerConfig is decoded from the [Worker] table of the config file (it is
// the type of Config.GpuWorker).
//
// NumberOfWorkers comes from number_of_workers. Pulling comes from
// pulling_time and is kept as a raw string (LoadConfig's fallback uses
// "500ms"); presumably its consumer parses it as a duration — confirm.
type WorkerConfig struct {
	NumberOfWorkers int    `toml:"number_of_workers"`
	Pulling         string `toml:"pulling_time"`
}
|
|
|
|
// ServiceUser identifies the account the server acts as.
//
// User is the username read from the config file's "user" key. UserId is
// filled in at runtime by Config.GenerateToken; its unusual toml key
// (__user__id__) presumably exists so it is not set by hand in config.toml —
// confirm.
type ServiceUser struct {
	User   string `toml:"user"`
	UserId string `toml:"__user__id__"`
}
|
|
|
|
// DbInfo is decoded from the [DB] table of the config file.
//
// MaxConnections comes from max_connections; presumably it caps the database
// connection pool — confirm at the call site.
type DbInfo struct {
	MaxConnections int `toml:"max_connections"`
}
|
|
|
|
type Config struct {
|
|
Hostname string
|
|
Port int
|
|
NumberOfWorkers int `toml:"number_of_workers"`
|
|
|
|
SupressCuda int `toml:"supress_cuda"`
|
|
CleanUpOnStartup int `toml:"clean_up_on_startup"`
|
|
|
|
GpuWorker WorkerConfig `toml:"Worker"`
|
|
|
|
ServiceUser ServiceUser `toml:"ServiceUser"`
|
|
|
|
DbInfo DbInfo `toml:"DB"`
|
|
}
|
|
|
|
func LoadConfig() Config {
|
|
|
|
log.Info("Loading the config file")
|
|
|
|
dat, err := os.ReadFile("./config.toml")
|
|
if err != nil {
|
|
log.Error("Failed to load config file", "err", err)
|
|
// Use default values
|
|
return Config{
|
|
Hostname: "localhost",
|
|
Port: 8000,
|
|
NumberOfWorkers: 10,
|
|
CleanUpOnStartup: 1,
|
|
SupressCuda: 1,
|
|
GpuWorker: WorkerConfig{
|
|
NumberOfWorkers: 1,
|
|
Pulling: "500ms",
|
|
},
|
|
ServiceUser: ServiceUser{
|
|
User: "Service",
|
|
UserId: "",
|
|
},
|
|
DbInfo: DbInfo{
|
|
MaxConnections: 200,
|
|
},
|
|
}
|
|
}
|
|
|
|
var conf Config
|
|
_, err = toml.Decode(string(dat), &conf)
|
|
|
|
if conf.SupressCuda == 1 {
|
|
log.Warn("Supressing Cuda Messages!")
|
|
os.Setenv("TF_CPP_MIN_VLOG_LEVEL", "3")
|
|
os.Setenv("TF_CPP_MIN_LOG_LEVEL", "3")
|
|
}
|
|
|
|
return conf
|
|
}
|
|
|
|
func failLog(err error) {
|
|
if err == nil {
|
|
return
|
|
}
|
|
log.Fatal("Failed on setup", "error", err)
|
|
}
|
|
|
|
func (c *Config) Cleanup(db *sql.DB) {
|
|
if c.CleanUpOnStartup != 1 {
|
|
return
|
|
}
|
|
|
|
_, err := db.Exec("update models set status=$1 where status=$2", FAILED_PREPARING_ZIP_FILE, PREPARING_ZIP_FILE)
|
|
failLog(err)
|
|
_, err = db.Exec("update models set status=$1 where status=$2", FAILED_PREPARING, PREPARING)
|
|
failLog(err)
|
|
_, err = db.Exec("update tasks set status=$1 where status=$2", TASK_PICKED_UP, TASK_TODO)
|
|
failLog(err)
|
|
|
|
tasks, err := GetDbMultitple[Task](db, "tasks where status=$1", TASK_RUNNING)
|
|
failLog(err)
|
|
|
|
base := BasePackStruct{Db: db, Logger: log.Default()}
|
|
|
|
for i := range tasks {
|
|
if tasks[i].TaskType == int(TASK_TYPE_CLASSIFICATION) {
|
|
tasks[i].UpdateStatus(base, TASK_TODO, "Reseting task")
|
|
continue
|
|
}
|
|
if tasks[i].TaskType == int(TASK_TYPE_RETRAINING) {
|
|
tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again")
|
|
_, err = db.Exec("update models set status=$1 where id=$2", READY_RETRAIN_FAILED, tasks[i].ModelId)
|
|
failLog(err)
|
|
continue
|
|
}
|
|
if tasks[i].TaskType == int(TASK_TYPE_TRAINING) {
|
|
tasks[i].UpdateStatus(base, TASK_FAILED_RUNNING, "Task inturupted by server restart please try again")
|
|
_, err = db.Exec("update models set status=$1 where id=$2", FAILED_TRAINING, tasks[i].ModelId)
|
|
failLog(err)
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Config) GenerateToken(db *sql.DB) {
|
|
if c.ServiceUser.User == "" {
|
|
log.Fatal("A user needs to be set in a configuration file")
|
|
}
|
|
|
|
var user struct {
|
|
Password string `db:"password"`
|
|
UserId string `db:"id"`
|
|
}
|
|
|
|
err := GetDBOnce(db, &user, "users where username=$1;", c.ServiceUser.User)
|
|
if err == NotFoundError {
|
|
var newUser = struct {
|
|
Username string
|
|
Email string
|
|
Salt string
|
|
Password string
|
|
UserType UserType `db:"user_type"`
|
|
}{
|
|
c.ServiceUser.User,
|
|
c.ServiceUser.User,
|
|
"",
|
|
"",
|
|
User_Admin,
|
|
}
|
|
id, err := InsertReturnId(db, &newUser, "users", "id")
|
|
if err != nil {
|
|
log.Fatal("Failed to create user", "err", err)
|
|
}
|
|
c.ServiceUser.UserId = id
|
|
} else if err != nil {
|
|
log.Fatal("To get user name", "err", err)
|
|
return
|
|
} else {
|
|
if len(strings.ReplaceAll(user.Password, " ", "")) != 0 {
|
|
log.Fatal("User already exists and is not the service user", "user", user)
|
|
}
|
|
c.ServiceUser.UserId = user.UserId
|
|
}
|
|
}
|