moved to psql pool closes #99

This commit is contained in:
2024-04-17 17:46:43 +01:00
parent 8ece8306dd
commit a3913ccdf5
14 changed files with 132 additions and 98 deletions

48
logic/db/db.go Normal file
View File

@@ -0,0 +1,48 @@
package db
import (
"context"
"github.com/charmbracelet/log"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type DbContainer struct {
pool *pgxpool.Pool
}
type Db interface {
Query(query string, args ...any) (pgx.Rows, error)
Exec(query string, args ...any) (pgconn.CommandTag, error)
Begin() (pgx.Tx, error)
}
func StartUp(url string) DbContainer {
dbpool, err := pgxpool.New(context.Background(), url)
if err != nil {
log.Fatal("Cloud not create database pool")
panic(err)
}
return DbContainer{
pool: dbpool,
}
}
func (db DbContainer) Close() {
db.pool.Close()
}
func (db DbContainer) Query(query string, args ...any) (pgx.Rows, error) {
return db.pool.Query(context.Background(), query, args...)
}
func (db DbContainer) Exec(query string, args ...any) (pgconn.CommandTag, error) {
return db.pool.Exec(context.Background(), query, args...)
}
func (db DbContainer) Begin() (pgx.Tx, error) {
return db.pool.Begin(context.Background())
}

View File

@@ -1,8 +1,9 @@
package dbtypes
import (
"database/sql"
"errors"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
)
const (
@@ -79,7 +80,7 @@ type BaseModel struct {
var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
func GetBaseModel(db db.Db, id string) (base *BaseModel, err error) {
var model BaseModel
err = GetDBOnce(db, &model, "models where id=$1", id)
if err != nil {

View File

@@ -1,7 +1,7 @@
package dbtypes
import (
"database/sql"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
)
type UserType int
@@ -20,8 +20,7 @@ type User struct {
UserType int `db:"u.user_type"`
}
func UserFromToken(db *sql.DB, token string) (*User, error) {
func UserFromToken(db db.Db, token string) (*User, error) {
var user User
err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token)

View File

@@ -1,7 +1,6 @@
package dbtypes
import (
"database/sql"
"errors"
"fmt"
"io"
@@ -14,16 +13,19 @@ import (
"github.com/charmbracelet/log"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
db "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
)
type BasePack interface {
GetDb() *sql.DB
GetDb() db.Db
GetLogger() *log.Logger
GetHost() string
}
type BasePackStruct struct {
Db *sql.DB
Db db.Db
Logger *log.Logger
Host string
}
@@ -32,7 +34,7 @@ func (b BasePackStruct) GetHost() string {
return b.Host
}
func (b BasePackStruct) GetDb() *sql.DB {
func (b BasePackStruct) GetDb() db.Db {
return b.Db
}
@@ -224,24 +226,12 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
return
}
type QueryInterface interface {
Prepare(str string) (*sql.Stmt, error)
Query(query string, args ...any) (*sql.Rows, error)
Exec(query string, args ...any) (sql.Result, error)
}
func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) {
func GetDbMultitple[T interface{}](c db.Db, tablename string, args ...any) ([]*T, error) {
t := reflect.TypeFor[T]()
query, nargs := generateQuery(t)
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
rows, err := c.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil {
return nil, err
}
@@ -260,7 +250,7 @@ func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...a
return list, nil
}
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
func mapRow(store interface{}, rows pgx.Rows, nargs int) (err error) {
err = nil
val := reflect.Indirect(reflect.ValueOf(store))
@@ -278,7 +268,7 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
return nil
}
func InsertReturnId(c QueryInterface, store interface{}, tablename string, returnName string) (id string, err error) {
func InsertReturnId(c db.Db, store interface{}, tablename string, returnName string) (id string, err error) {
t := reflect.TypeOf(store).Elem()
query, nargs := generateQuery(t)
@@ -315,14 +305,8 @@ func InsertReturnId(c QueryInterface, store interface{}, tablename string, retur
return
}
func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename string, args ...any) (*T, error) {
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", var_to_extract, tablename))
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
func GetDbVar[T interface{}](c db.Db, var_to_extract string, tablename string, args ...any) (*T, error) {
rows, err := c.Query(fmt.Sprintf("select %s from %s", var_to_extract, tablename), args...)
if err != nil {
return nil, err
}
@@ -340,7 +324,7 @@ func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename
return dat, nil
}
func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error {
func GetDBOnce(db db.Db, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem()
query, nargs := generateQuery(t)
@@ -372,7 +356,7 @@ func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...a
return nil
}
func UpdateStatus(c QueryInterface, table string, id string, status int) (err error) {
func UpdateStatus(c db.Db, table string, id string, status int) (err error) {
_, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
return
}

View File

@@ -1,15 +1,15 @@
package model_classes
import (
"database/sql"
"errors"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
)
var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error")
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
func AddDataPoint(db db.Db, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
id = ""
result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode)
if err != nil {
@@ -24,7 +24,7 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT
return
}
func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) {
func UpdateDataPointStatus(db db.Db, data_point_id string, status int, message *string) (err error) {
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
return
}

View File

@@ -1,9 +1,9 @@
package model_classes
import (
"database/sql"
"errors"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
)
@@ -18,7 +18,7 @@ func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) {
return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id)
}
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) {
result = false
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
if err != nil {
@@ -31,7 +31,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
var ClassAlreadyExists = errors.New("Class aready exists")
func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) {
func CreateClass(db db.Db, model_id string, order int, name string) (id string, err error) {
id = ""
rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
if err != nil {
@@ -57,7 +57,7 @@ func CreateClass(db *sql.DB, model_id string, order int, name string) (id string
return
}
func GetNumberOfWrongDataPoints(db *sql.DB, model_id string) (number int, err error) {
func GetNumberOfWrongDataPoints(db db.Db, model_id string) (number int, err error) {
number = 0
rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id)
if err != nil {

View File

@@ -1,7 +1,6 @@
package models_train
import (
"database/sql"
"errors"
"fmt"
"io"
@@ -14,6 +13,7 @@ import (
"strings"
"text/template"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@@ -40,7 +40,7 @@ func getDir() string {
}
// This function creates a new model_definition
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) {
var NewDefinition = struct {
ModelId string `db:"model_id"`
TargetAccuracy int `db:"target_accuracy"`
@@ -54,12 +54,12 @@ func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionSt
return
}
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
func MakeLayer(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
return
}
func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
func MakeLayerExpandable(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type)
return
}
@@ -1398,7 +1398,7 @@ func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, numb
return nil
}
func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) {
func ExpModelHeadUpdateStatus(db db.Db, id string, status ModelDefinitionStatus) (err error) {
_, err = db.Exec("update model_definition set status = $1 where id = $2", status, id)
return
}
@@ -1877,18 +1877,7 @@ func handleTrain(handle *Handle) {
//Update the classes
{
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
_, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil {
_err := c.RollbackTx()
if _err != nil {

View File

@@ -1,7 +1,6 @@
package task_runner
import (
"database/sql"
"fmt"
"math"
"os"
@@ -10,6 +9,7 @@ import (
"github.com/charmbracelet/log"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
@@ -21,7 +21,7 @@ import (
/**
* Actually runs the code
*/
func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
func runner(config Config, db db.Db, task_channel chan Task, index int, back_channel chan int) {
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true,
ReportTimestamp: true,
@@ -125,7 +125,7 @@ func attentionSeeker(config Config, back_channel chan int) {
/**
* Manages what worker should to Work
*/
func RunnerOrchestrator(db *sql.DB, config Config) {
func RunnerOrchestrator(db db.Db, config Config) {
logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true,
ReportTimestamp: true,
@@ -211,6 +211,6 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
}
}
func StartRunners(db *sql.DB, config Config) {
func StartRunners(db db.Db, config Config) {
go RunnerOrchestrator(db, config)
}

View File

@@ -2,7 +2,6 @@ package users
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"io"
"net/http"
@@ -10,6 +9,7 @@ import (
"golang.org/x/crypto/bcrypt"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
@@ -44,12 +44,12 @@ func genToken() string {
return hex.EncodeToString(token)
}
func deleteToken(db *sql.DB, userId string, time time.Time) (err error) {
func deleteToken(db db.Db, userId string, time time.Time) (err error) {
_, err = db.Exec("delete from tokens where emit_day=$1 and user_id=$2", time, userId)
return
}
func generateToken(db *sql.DB, email string, password string, name string) (string, bool) {
func generateToken(db db.Db, email string, password string, name string) (string, bool) {
row, err := db.Query("select id, salt, password from users where email = $1;", email)
if err != nil || !row.Next() {
return "", false
@@ -106,7 +106,7 @@ func DeleteUser(base BasePack, task Task) (err error) {
return
}
func UsersEndpints(db *sql.DB, handle *Handle) {
func UsersEndpints(db db.Db, handle *Handle) {
type UserLogin struct {
Email string `json:"email"`

View File

@@ -1,10 +1,10 @@
package utils
import (
"database/sql"
"os"
"strings"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@@ -88,7 +88,7 @@ func failLog(err error) {
log.Fatal("Failed on setup", "error", err)
}
func (c *Config) Cleanup(db *sql.DB) {
func (c *Config) Cleanup(db db.Db) {
if c.CleanUpOnStartup != 1 {
return
}
@@ -125,7 +125,7 @@ func (c *Config) Cleanup(db *sql.DB) {
}
}
func (c *Config) GenerateToken(db *sql.DB) {
func (c *Config) GenerateToken(db db.Db) {
if c.ServiceUser.User == "" {
log.Fatal("A user needs to be set in a configuration file")
}

View File

@@ -2,7 +2,7 @@ package utils
import (
"bytes"
"database/sql"
"context"
"errors"
"fmt"
"io"
@@ -13,10 +13,13 @@ import (
"strings"
"time"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
"github.com/charmbracelet/log"
"github.com/go-playground/validator/v10"
"github.com/goccy/go-json"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
)
type AnyMap = map[string]interface{}
@@ -39,7 +42,7 @@ type Handler interface {
}
type Handle struct {
Db *sql.DB
Db db.Db
gets []HandleFunc
posts []HandleFunc
deletes []HandleFunc
@@ -235,16 +238,16 @@ type Context struct {
Token *string
User *dbtypes.User
Logger *log.Logger
Db *sql.DB
Db db.Db
Writer http.ResponseWriter
R *http.Request
Tx *sql.Tx
Tx pgx.Tx
ShowMessage bool
Handle *Handle
}
// This is required for this to integrate simealy with my orl
func (c Context) GetDb() *sql.DB {
func (c Context) GetDb() db.Db {
return c.Db
}
@@ -256,19 +259,22 @@ func (c Context) GetHost() string {
return c.Handle.Config.Hostname
}
func (c Context) Query(query string, args ...any) (*sql.Rows, error) {
return c.Db.Query(query, args...)
}
func (c Context) Exec(query string, args ...any) (sql.Result, error) {
return c.Db.Exec(query, args...)
}
func (c Context) Prepare(str string) (*sql.Stmt, error) {
func (c Context) Query(query string, args ...any) (pgx.Rows, error) {
if c.Tx == nil {
return c.Db.Prepare(str)
return c.Db.Query(query, args...)
}
return c.Tx.Prepare(str)
return c.Tx.Query(context.Background(), query, args...)
}
func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) {
if c.Tx == nil {
return c.Db.Exec(query, args...)
}
return c.Tx.Exec(context.Background(), query, args...)
}
func (c Context) Begin() (pgx.Tx, error) {
return c.Db.Begin()
}
var TransactionAlreadyStarted = errors.New("Transaction already started")
@@ -279,7 +285,7 @@ func (c *Context) StartTx() error {
return TransactionAlreadyStarted
}
var err error = nil
c.Tx, err = c.Db.Begin()
c.Tx, err = c.Begin()
return err
}
@@ -287,7 +293,7 @@ func (c *Context) CommitTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Commit()
err := c.Tx.Commit(context.Background())
if err != nil {
return err
}
@@ -299,7 +305,7 @@ func (c *Context) RollbackTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Rollback()
err := c.Tx.Rollback(context.Background())
if err != nil {
return err
}
@@ -512,7 +518,7 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType
})
}
func NewHandler(db *sql.DB, config Config) *Handle {
func NewHandler(db db.Db, config Config) *Handle {
var gets []HandleFunc
var posts []HandleFunc
var deletes []HandleFunc