From a3913ccdf533705b8d3b3e31f28faec886d5c23e Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Wed, 17 Apr 2024 17:46:43 +0100 Subject: [PATCH] moved to psql pool closes #99 --- go.mod | 4 +++ go.sum | 8 +++++ logic/db/db.go | 48 ++++++++++++++++++++++++++++++ logic/db_types/types.go | 5 ++-- logic/db_types/user.go | 5 ++-- logic/db_types/utils.go | 44 +++++++++------------------ logic/models/classes/data_point.go | 6 ++-- logic/models/classes/main.go | 8 ++--- logic/models/train/train.go | 23 ++++---------- logic/tasks/runner/runner.go | 8 ++--- logic/users/users.go | 8 ++--- logic/utils/config.go | 6 ++-- logic/utils/handler.go | 46 +++++++++++++++------------- main.go | 11 ++----- 14 files changed, 132 insertions(+), 98 deletions(-) create mode 100644 logic/db/db.go diff --git a/go.mod b/go.mod index 1b0929f..9ac9940 100644 --- a/go.mod +++ b/go.mod @@ -23,16 +23,20 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx v3.6.2+incompatible // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect + github.com/pkg/errors v0.9.1 // indirect github.com/rivo/uniseg v0.4.6 // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/net v0.21.0 // indirect + golang.org/x/sync v0.1.0 // indirect golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect google.golang.org/protobuf v1.32.0 // indirect diff --git a/go.sum b/go.sum index 7d6727e..3a312a2 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,12 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o= +github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= @@ -60,6 +64,8 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= @@ -81,6 +87,8 @@ golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRj golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/logic/db/db.go b/logic/db/db.go new file mode 100644 index 0000000..e4b5c94 --- /dev/null +++ b/logic/db/db.go @@ -0,0 +1,48 @@ +package db + +import ( + "context" + + "github.com/charmbracelet/log" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgxpool" +) + +type DbContainer struct { + pool *pgxpool.Pool +} + +type Db interface { + Query(query string, args ...any) (pgx.Rows, error) + Exec(query string, args ...any) (pgconn.CommandTag, error) + Begin() (pgx.Tx, error) +} + +func StartUp(url string) DbContainer { + dbpool, err := pgxpool.New(context.Background(), url) + if err != nil { + log.Fatal("Cloud not create database pool") + panic(err) + } + + return DbContainer{ + pool: dbpool, + } +} + +func (db DbContainer) Close() { + db.pool.Close() +} + +func (db DbContainer) Query(query string, args ...any) (pgx.Rows, error) { + return db.pool.Query(context.Background(), query, args...) +} + +func (db DbContainer) Exec(query string, args ...any) (pgconn.CommandTag, error) { + return db.pool.Exec(context.Background(), query, args...) +} + +func (db DbContainer) Begin() (pgx.Tx, error) { + return db.pool.Begin(context.Background()) +} diff --git a/logic/db_types/types.go b/logic/db_types/types.go index 7200369..97fb993 100644 --- a/logic/db_types/types.go +++ b/logic/db_types/types.go @@ -1,8 +1,9 @@ package dbtypes import ( - "database/sql" "errors" + + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" ) const ( @@ -79,7 +80,7 @@ type BaseModel struct { var ModelNotFoundError = errors.New("Model not found error") -func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { +func GetBaseModel(db db.Db, id string) (base *BaseModel, err error) { var model BaseModel err = GetDBOnce(db, &model, "models where id=$1", id) if err != nil { diff --git a/logic/db_types/user.go b/logic/db_types/user.go index ee91e64..1996241 100644 --- a/logic/db_types/user.go +++ b/logic/db_types/user.go @@ -1,7 +1,7 @@ package dbtypes import ( - "database/sql" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" ) type UserType int @@ -20,8 +20,7 @@ type User struct { UserType int `db:"u.user_type"` } -func UserFromToken(db *sql.DB, token string) (*User, error) { - +func UserFromToken(db db.Db, token string) (*User, error) { var user User err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token) diff --git a/logic/db_types/utils.go b/logic/db_types/utils.go index 771f139..352cacd 100644 --- a/logic/db_types/utils.go +++ b/logic/db_types/utils.go @@ -1,7 +1,6 @@ package dbtypes import ( - "database/sql" "errors" "fmt" "io" @@ -14,16 +13,19 @@ import ( "github.com/charmbracelet/log" "github.com/google/uuid" + "github.com/jackc/pgx/v5" + + db "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" ) type BasePack interface { - GetDb() *sql.DB + GetDb() db.Db GetLogger() *log.Logger GetHost() string } type BasePackStruct struct { - Db *sql.DB + Db db.Db Logger *log.Logger Host string } @@ -32,7 +34,7 @@ func (b BasePackStruct) GetHost() string { return b.Host } -func (b BasePackStruct) GetDb() *sql.DB { +func (b BasePackStruct) GetDb() db.Db { return b.Db } @@ -224,24 +226,12 @@ func generateQuery(t reflect.Type) (query string, nargs int) { return } -type QueryInterface interface { - Prepare(str string) (*sql.Stmt, error) - Query(query string, args ...any) (*sql.Rows, error) - Exec(query string, args ...any) (sql.Result, error) -} - -func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) { +func GetDbMultitple[T interface{}](c db.Db, tablename string, args ...any) ([]*T, error) { t := reflect.TypeFor[T]() query, nargs := generateQuery(t) - db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename)) - if err != nil { - return nil, err - } - defer db_query.Close() - - rows, err := db_query.Query(args...) + rows, err := c.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) if err != nil { return nil, err } @@ -260,7 +250,7 @@ func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...a return list, nil } -func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { +func mapRow(store interface{}, rows pgx.Rows, nargs int) (err error) { err = nil val := reflect.Indirect(reflect.ValueOf(store)) @@ -278,7 +268,7 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { return nil } -func InsertReturnId(c QueryInterface, store interface{}, tablename string, returnName string) (id string, err error) { +func InsertReturnId(c db.Db, store interface{}, tablename string, returnName string) (id string, err error) { t := reflect.TypeOf(store).Elem() query, nargs := generateQuery(t) @@ -315,14 +305,8 @@ func InsertReturnId(c QueryInterface, store interface{}, tablename string, retur return } -func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename string, args ...any) (*T, error) { - db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", var_to_extract, tablename)) - if err != nil { - return nil, err - } - defer db_query.Close() - - rows, err := db_query.Query(args...) +func GetDbVar[T interface{}](c db.Db, var_to_extract string, tablename string, args ...any) (*T, error) { + rows, err := c.Query(fmt.Sprintf("select %s from %s", var_to_extract, tablename), args...) if err != nil { return nil, err } @@ -340,7 +324,7 @@ func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename return dat, nil } -func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error { +func GetDBOnce(db db.Db, store interface{}, tablename string, args ...any) error { t := reflect.TypeOf(store).Elem() query, nargs := generateQuery(t) @@ -372,7 +356,7 @@ func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...a return nil } -func UpdateStatus(c QueryInterface, table string, id string, status int) (err error) { +func UpdateStatus(c db.Db, table string, id string, status int) (err error) { _, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) return } diff --git a/logic/models/classes/data_point.go b/logic/models/classes/data_point.go index ecd16fe..7edc654 100644 --- a/logic/models/classes/data_point.go +++ b/logic/models/classes/data_point.go @@ -1,15 +1,15 @@ package model_classes import ( - "database/sql" "errors" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" ) var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error") -func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) { +func AddDataPoint(db db.Db, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) { id = "" result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode) if err != nil { @@ -24,7 +24,7 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT return } -func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) { +func UpdateDataPointStatus(db db.Db, data_point_id string, status int, message *string) (err error) { _, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id) return } diff --git a/logic/models/classes/main.go b/logic/models/classes/main.go index 017074a..81d7563 100644 --- a/logic/models/classes/main.go +++ b/logic/models/classes/main.go @@ -1,9 +1,9 @@ package model_classes import ( - "database/sql" "errors" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" ) @@ -18,7 +18,7 @@ func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) { return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id) } -func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) { +func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) { result = false rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id) if err != nil { @@ -31,7 +31,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) { var ClassAlreadyExists = errors.New("Class aready exists") -func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) { +func CreateClass(db db.Db, model_id string, order int, name string) (id string, err error) { id = "" rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name) if err != nil { @@ -57,7 +57,7 @@ func CreateClass(db *sql.DB, model_id string, order int, name string) (id string return } -func GetNumberOfWrongDataPoints(db *sql.DB, model_id string) (number int, err error) { +func GetNumberOfWrongDataPoints(db db.Db, model_id string) (number int, err error) { number = 0 rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id) if err != nil { diff --git a/logic/models/train/train.go b/logic/models/train/train.go index e695267..fa0ef6a 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -1,7 +1,6 @@ package models_train import ( - "database/sql" "errors" "fmt" "io" @@ -14,6 +13,7 @@ import ( "strings" "text/template" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" @@ -40,7 +40,7 @@ func getDir() string { } // This function creates a new model_definition -func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { +func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) { var NewDefinition = struct { ModelId string `db:"model_id"` TargetAccuracy int `db:"target_accuracy"` @@ -54,12 +54,12 @@ func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionSt return } -func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) { +func MakeLayer(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string) (err error) { _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape) return } -func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) { +func MakeLayerExpandable(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) { _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type) return } @@ -1398,7 +1398,7 @@ func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, numb return nil } -func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) { +func ExpModelHeadUpdateStatus(db db.Db, id string, status ModelDefinitionStatus) (err error) { _, err = db.Exec("update model_definition set status = $1 where id = $2", status, id) return } @@ -1877,18 +1877,7 @@ func handleTrain(handle *Handle) { //Update the classes { - stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3") - err = err2 - if err != nil { - _err := c.RollbackTx() - if _err != nil { - c.Logger.Error("Two errors happended rollback failed", "err", _err) - } - return failed() - } - defer stmt.Close() - - _, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id) + _, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id) if err != nil { _err := c.RollbackTx() if _err != nil { diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go index 8bae0da..4e6f967 100644 --- a/logic/tasks/runner/runner.go +++ b/logic/tasks/runner/runner.go @@ -1,7 +1,6 @@ package task_runner import ( - "database/sql" "fmt" "math" "os" @@ -10,6 +9,7 @@ import ( "github.com/charmbracelet/log" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" @@ -21,7 +21,7 @@ import ( /** * Actually runs the code */ -func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) { +func runner(config Config, db db.Db, task_channel chan Task, index int, back_channel chan int) { logger := log.NewWithOptions(os.Stdout, log.Options{ ReportCaller: true, ReportTimestamp: true, @@ -125,7 +125,7 @@ func attentionSeeker(config Config, back_channel chan int) { /** * Manages what worker should to Work */ -func RunnerOrchestrator(db *sql.DB, config Config) { +func RunnerOrchestrator(db db.Db, config Config) { logger := log.NewWithOptions(os.Stdout, log.Options{ ReportCaller: true, ReportTimestamp: true, @@ -211,6 +211,6 @@ func RunnerOrchestrator(db *sql.DB, config Config) { } } -func StartRunners(db *sql.DB, config Config) { +func StartRunners(db db.Db, config Config) { go RunnerOrchestrator(db, config) } diff --git a/logic/users/users.go b/logic/users/users.go index 084842c..5c2c9d3 100644 --- a/logic/users/users.go +++ b/logic/users/users.go @@ -2,7 +2,6 @@ package users import ( "crypto/rand" - "database/sql" "encoding/hex" "io" "net/http" @@ -10,6 +9,7 @@ import ( "golang.org/x/crypto/bcrypt" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" @@ -44,12 +44,12 @@ func genToken() string { return hex.EncodeToString(token) } -func deleteToken(db *sql.DB, userId string, time time.Time) (err error) { +func deleteToken(db db.Db, userId string, time time.Time) (err error) { _, err = db.Exec("delete from tokens where emit_day=$1 and user_id=$2", time, userId) return } -func generateToken(db *sql.DB, email string, password string, name string) (string, bool) { +func generateToken(db db.Db, email string, password string, name string) (string, bool) { row, err := db.Query("select id, salt, password from users where email = $1;", email) if err != nil || !row.Next() { return "", false @@ -106,7 +106,7 @@ func DeleteUser(base BasePack, task Task) (err error) { return } -func UsersEndpints(db *sql.DB, handle *Handle) { +func UsersEndpints(db db.Db, handle *Handle) { type UserLogin struct { Email string `json:"email"` diff --git a/logic/utils/config.go b/logic/utils/config.go index 4c9fbee..085aa2b 100644 --- a/logic/utils/config.go +++ b/logic/utils/config.go @@ -1,10 +1,10 @@ package utils import ( - "database/sql" "os" "strings" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" @@ -88,7 +88,7 @@ func failLog(err error) { log.Fatal("Failed on setup", "error", err) } -func (c *Config) Cleanup(db *sql.DB) { +func (c *Config) Cleanup(db db.Db) { if c.CleanUpOnStartup != 1 { return } @@ -125,7 +125,7 @@ func (c *Config) Cleanup(db *sql.DB) { } } -func (c *Config) GenerateToken(db *sql.DB) { +func (c *Config) GenerateToken(db db.Db) { if c.ServiceUser.User == "" { log.Fatal("A user needs to be set in a configuration file") } diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 6bed6fc..328ebda 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -2,7 +2,7 @@ package utils import ( "bytes" - "database/sql" + "context" "errors" "fmt" "io" @@ -13,10 +13,13 @@ import ( "strings" "time" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" "github.com/charmbracelet/log" "github.com/go-playground/validator/v10" "github.com/goccy/go-json" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) type AnyMap = map[string]interface{} @@ -39,7 +42,7 @@ type Handler interface { } type Handle struct { - Db *sql.DB + Db db.Db gets []HandleFunc posts []HandleFunc deletes []HandleFunc @@ -235,16 +238,16 @@ type Context struct { Token *string User *dbtypes.User Logger *log.Logger - Db *sql.DB + Db db.Db Writer http.ResponseWriter R *http.Request - Tx *sql.Tx + Tx pgx.Tx ShowMessage bool Handle *Handle } // This is required for this to integrate simealy with my orl -func (c Context) GetDb() *sql.DB { +func (c Context) GetDb() db.Db { return c.Db } @@ -256,19 +259,22 @@ func (c Context) GetHost() string { return c.Handle.Config.Hostname } -func (c Context) Query(query string, args ...any) (*sql.Rows, error) { - return c.Db.Query(query, args...) -} - -func (c Context) Exec(query string, args ...any) (sql.Result, error) { - return c.Db.Exec(query, args...) -} - -func (c Context) Prepare(str string) (*sql.Stmt, error) { +func (c Context) Query(query string, args ...any) (pgx.Rows, error) { if c.Tx == nil { - return c.Db.Prepare(str) + return c.Db.Query(query, args...) } - return c.Tx.Prepare(str) + return c.Tx.Query(context.Background(), query, args...) +} + +func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) { + if c.Tx == nil { + return c.Db.Exec(query, args...) + } + return c.Tx.Exec(context.Background(), query, args...) +} + +func (c Context) Begin() (pgx.Tx, error) { + return c.Db.Begin() } var TransactionAlreadyStarted = errors.New("Transaction already started") @@ -279,7 +285,7 @@ func (c *Context) StartTx() error { return TransactionAlreadyStarted } var err error = nil - c.Tx, err = c.Db.Begin() + c.Tx, err = c.Begin() return err } @@ -287,7 +293,7 @@ func (c *Context) CommitTx() error { if c.Tx == nil { return TransactionNotStarted } - err := c.Tx.Commit() + err := c.Tx.Commit(context.Background()) if err != nil { return err } @@ -299,7 +305,7 @@ func (c *Context) RollbackTx() error { if c.Tx == nil { return TransactionNotStarted } - err := c.Tx.Rollback() + err := c.Tx.Rollback(context.Background()) if err != nil { return err } @@ -512,7 +518,7 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType }) } -func NewHandler(db *sql.DB, config Config) *Handle { +func NewHandler(db db.Db, config Config) *Handle { var gets []HandleFunc var posts []HandleFunc var deletes []HandleFunc diff --git a/main.go b/main.go index cf53776..35f5daf 100644 --- a/main.go +++ b/main.go @@ -1,12 +1,12 @@ package main import ( - "database/sql" "fmt" "github.com/charmbracelet/log" _ "github.com/lib/pq" + "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner" @@ -23,23 +23,18 @@ const ( ) func main() { + psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+ "password=%s dbname=%s sslmode=disable", host, port, user, password, dbname) - db, err := sql.Open("postgres", psqlInfo) - if err != nil { - panic(err) - } + db := db.StartUp(psqlInfo) defer db.Close() - log.Info("Starting server on :5002!") config := LoadConfig() log.Info("Config loaded!", "config", config) config.GenerateToken(db) - db.SetMaxOpenConns(config.DbInfo.MaxConnections) - StartRunners(db, config) //TODO check if file structure exists to save data