moved to psql pool closes #99
This commit is contained in:
parent
8ece8306dd
commit
a3913ccdf5
4
go.mod
4
go.mod
@ -23,16 +23,20 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx v3.6.2+incompatible // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.15 // indirect
|
||||
github.com/muesli/reflow v0.3.0 // indirect
|
||||
github.com/muesli/termenv v0.15.2 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.6 // indirect
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sync v0.1.0 // indirect
|
||||
golang.org/x/sys v0.17.0 // indirect
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.32.0 // indirect
|
||||
|
8
go.sum
8
go.sum
@ -39,8 +39,12 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o=
|
||||
github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
|
||||
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
|
||||
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
|
||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
@ -60,6 +64,8 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
|
||||
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
|
||||
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
|
||||
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
@ -81,6 +87,8 @@ golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRj
|
||||
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
|
48
logic/db/db.go
Normal file
48
logic/db/db.go
Normal file
@ -0,0 +1,48 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type DbContainer struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
type Db interface {
|
||||
Query(query string, args ...any) (pgx.Rows, error)
|
||||
Exec(query string, args ...any) (pgconn.CommandTag, error)
|
||||
Begin() (pgx.Tx, error)
|
||||
}
|
||||
|
||||
func StartUp(url string) DbContainer {
|
||||
dbpool, err := pgxpool.New(context.Background(), url)
|
||||
if err != nil {
|
||||
log.Fatal("Cloud not create database pool")
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return DbContainer{
|
||||
pool: dbpool,
|
||||
}
|
||||
}
|
||||
|
||||
func (db DbContainer) Close() {
|
||||
db.pool.Close()
|
||||
}
|
||||
|
||||
func (db DbContainer) Query(query string, args ...any) (pgx.Rows, error) {
|
||||
return db.pool.Query(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (db DbContainer) Exec(query string, args ...any) (pgconn.CommandTag, error) {
|
||||
return db.pool.Exec(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (db DbContainer) Begin() (pgx.Tx, error) {
|
||||
return db.pool.Begin(context.Background())
|
||||
}
|
@ -1,8 +1,9 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -79,7 +80,7 @@ type BaseModel struct {
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
func GetBaseModel(db db.Db, id string) (base *BaseModel, err error) {
|
||||
var model BaseModel
|
||||
err = GetDBOnce(db, &model, "models where id=$1", id)
|
||||
if err != nil {
|
||||
|
@ -1,7 +1,7 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
)
|
||||
|
||||
type UserType int
|
||||
@ -20,8 +20,7 @@ type User struct {
|
||||
UserType int `db:"u.user_type"`
|
||||
}
|
||||
|
||||
func UserFromToken(db *sql.DB, token string) (*User, error) {
|
||||
|
||||
func UserFromToken(db db.Db, token string) (*User, error) {
|
||||
var user User
|
||||
|
||||
err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token)
|
||||
|
@ -1,7 +1,6 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -14,16 +13,19 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
db "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
)
|
||||
|
||||
type BasePack interface {
|
||||
GetDb() *sql.DB
|
||||
GetDb() db.Db
|
||||
GetLogger() *log.Logger
|
||||
GetHost() string
|
||||
}
|
||||
|
||||
type BasePackStruct struct {
|
||||
Db *sql.DB
|
||||
Db db.Db
|
||||
Logger *log.Logger
|
||||
Host string
|
||||
}
|
||||
@ -32,7 +34,7 @@ func (b BasePackStruct) GetHost() string {
|
||||
return b.Host
|
||||
}
|
||||
|
||||
func (b BasePackStruct) GetDb() *sql.DB {
|
||||
func (b BasePackStruct) GetDb() db.Db {
|
||||
return b.Db
|
||||
}
|
||||
|
||||
@ -224,24 +226,12 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
|
||||
return
|
||||
}
|
||||
|
||||
type QueryInterface interface {
|
||||
Prepare(str string) (*sql.Stmt, error)
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) {
|
||||
func GetDbMultitple[T interface{}](c db.Db, tablename string, args ...any) ([]*T, error) {
|
||||
t := reflect.TypeFor[T]()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
|
||||
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db_query.Close()
|
||||
|
||||
rows, err := db_query.Query(args...)
|
||||
rows, err := c.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -260,7 +250,7 @@ func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...a
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||
func mapRow(store interface{}, rows pgx.Rows, nargs int) (err error) {
|
||||
err = nil
|
||||
|
||||
val := reflect.Indirect(reflect.ValueOf(store))
|
||||
@ -278,7 +268,7 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertReturnId(c QueryInterface, store interface{}, tablename string, returnName string) (id string, err error) {
|
||||
func InsertReturnId(c db.Db, store interface{}, tablename string, returnName string) (id string, err error) {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
@ -315,14 +305,8 @@ func InsertReturnId(c QueryInterface, store interface{}, tablename string, retur
|
||||
return
|
||||
}
|
||||
|
||||
func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename string, args ...any) (*T, error) {
|
||||
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", var_to_extract, tablename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db_query.Close()
|
||||
|
||||
rows, err := db_query.Query(args...)
|
||||
func GetDbVar[T interface{}](c db.Db, var_to_extract string, tablename string, args ...any) (*T, error) {
|
||||
rows, err := c.Query(fmt.Sprintf("select %s from %s", var_to_extract, tablename), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -340,7 +324,7 @@ func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename
|
||||
return dat, nil
|
||||
}
|
||||
|
||||
func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error {
|
||||
func GetDBOnce(db db.Db, store interface{}, tablename string, args ...any) error {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
@ -372,7 +356,7 @@ func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...a
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateStatus(c QueryInterface, table string, id string, status int) (err error) {
|
||||
func UpdateStatus(c db.Db, table string, id string, status int) (err error) {
|
||||
_, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
|
||||
return
|
||||
}
|
||||
|
@ -1,15 +1,15 @@
|
||||
package model_classes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
)
|
||||
|
||||
var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error")
|
||||
|
||||
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
|
||||
func AddDataPoint(db db.Db, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
|
||||
id = ""
|
||||
result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode)
|
||||
if err != nil {
|
||||
@ -24,7 +24,7 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) {
|
||||
func UpdateDataPointStatus(db db.Db, data_point_id string, status int, message *string) (err error) {
|
||||
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
|
||||
return
|
||||
}
|
||||
|
@ -1,9 +1,9 @@
|
||||
package model_classes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
)
|
||||
|
||||
@ -18,7 +18,7 @@ func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) {
|
||||
return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id)
|
||||
}
|
||||
|
||||
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||
func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) {
|
||||
result = false
|
||||
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
|
||||
if err != nil {
|
||||
@ -31,7 +31,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||
|
||||
var ClassAlreadyExists = errors.New("Class aready exists")
|
||||
|
||||
func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) {
|
||||
func CreateClass(db db.Db, model_id string, order int, name string) (id string, err error) {
|
||||
id = ""
|
||||
rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
|
||||
if err != nil {
|
||||
@ -57,7 +57,7 @@ func CreateClass(db *sql.DB, model_id string, order int, name string) (id string
|
||||
return
|
||||
}
|
||||
|
||||
func GetNumberOfWrongDataPoints(db *sql.DB, model_id string) (number int, err error) {
|
||||
func GetNumberOfWrongDataPoints(db db.Db, model_id string) (number int, err error) {
|
||||
number = 0
|
||||
rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id)
|
||||
if err != nil {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -14,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
@ -40,7 +40,7 @@ func getDir() string {
|
||||
}
|
||||
|
||||
// This function creates a new model_definition
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) {
|
||||
var NewDefinition = struct {
|
||||
ModelId string `db:"model_id"`
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
@ -54,12 +54,12 @@ func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionSt
|
||||
return
|
||||
}
|
||||
|
||||
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
|
||||
func MakeLayer(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
|
||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
||||
return
|
||||
}
|
||||
|
||||
func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
|
||||
func MakeLayerExpandable(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
|
||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type)
|
||||
return
|
||||
}
|
||||
@ -1398,7 +1398,7 @@ func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, numb
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) {
|
||||
func ExpModelHeadUpdateStatus(db db.Db, id string, status ModelDefinitionStatus) (err error) {
|
||||
_, err = db.Exec("update model_definition set status = $1 where id = $2", status, id)
|
||||
return
|
||||
}
|
||||
@ -1877,18 +1877,7 @@ func handleTrain(handle *Handle) {
|
||||
|
||||
//Update the classes
|
||||
{
|
||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||
err = err2
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
_, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
|
@ -1,7 +1,6 @@
|
||||
package task_runner
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
@ -10,6 +9,7 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||
@ -21,7 +21,7 @@ import (
|
||||
/**
|
||||
* Actually runs the code
|
||||
*/
|
||||
func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
|
||||
func runner(config Config, db db.Db, task_channel chan Task, index int, back_channel chan int) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
@ -125,7 +125,7 @@ func attentionSeeker(config Config, back_channel chan int) {
|
||||
/**
|
||||
* Manages what worker should to Work
|
||||
*/
|
||||
func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
func RunnerOrchestrator(db db.Db, config Config) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
@ -211,6 +211,6 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
}
|
||||
}
|
||||
|
||||
func StartRunners(db *sql.DB, config Config) {
|
||||
func StartRunners(db db.Db, config Config) {
|
||||
go RunnerOrchestrator(db, config)
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package users
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -10,6 +9,7 @@ import (
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
@ -44,12 +44,12 @@ func genToken() string {
|
||||
return hex.EncodeToString(token)
|
||||
}
|
||||
|
||||
func deleteToken(db *sql.DB, userId string, time time.Time) (err error) {
|
||||
func deleteToken(db db.Db, userId string, time time.Time) (err error) {
|
||||
_, err = db.Exec("delete from tokens where emit_day=$1 and user_id=$2", time, userId)
|
||||
return
|
||||
}
|
||||
|
||||
func generateToken(db *sql.DB, email string, password string, name string) (string, bool) {
|
||||
func generateToken(db db.Db, email string, password string, name string) (string, bool) {
|
||||
row, err := db.Query("select id, salt, password from users where email = $1;", email)
|
||||
if err != nil || !row.Next() {
|
||||
return "", false
|
||||
@ -106,7 +106,7 @@ func DeleteUser(base BasePack, task Task) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func UsersEndpints(db *sql.DB, handle *Handle) {
|
||||
func UsersEndpints(db db.Db, handle *Handle) {
|
||||
|
||||
type UserLogin struct {
|
||||
Email string `json:"email"`
|
||||
|
@ -1,10 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
|
||||
@ -88,7 +88,7 @@ func failLog(err error) {
|
||||
log.Fatal("Failed on setup", "error", err)
|
||||
}
|
||||
|
||||
func (c *Config) Cleanup(db *sql.DB) {
|
||||
func (c *Config) Cleanup(db db.Db) {
|
||||
if c.CleanUpOnStartup != 1 {
|
||||
return
|
||||
}
|
||||
@ -125,7 +125,7 @@ func (c *Config) Cleanup(db *sql.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) GenerateToken(db *sql.DB) {
|
||||
func (c *Config) GenerateToken(db db.Db) {
|
||||
if c.ServiceUser.User == "" {
|
||||
log.Fatal("A user needs to be set in a configuration file")
|
||||
}
|
||||
|
@ -2,7 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -13,10 +13,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type AnyMap = map[string]interface{}
|
||||
@ -39,7 +42,7 @@ type Handler interface {
|
||||
}
|
||||
|
||||
type Handle struct {
|
||||
Db *sql.DB
|
||||
Db db.Db
|
||||
gets []HandleFunc
|
||||
posts []HandleFunc
|
||||
deletes []HandleFunc
|
||||
@ -235,16 +238,16 @@ type Context struct {
|
||||
Token *string
|
||||
User *dbtypes.User
|
||||
Logger *log.Logger
|
||||
Db *sql.DB
|
||||
Db db.Db
|
||||
Writer http.ResponseWriter
|
||||
R *http.Request
|
||||
Tx *sql.Tx
|
||||
Tx pgx.Tx
|
||||
ShowMessage bool
|
||||
Handle *Handle
|
||||
}
|
||||
|
||||
// This is required for this to integrate simealy with my orl
|
||||
func (c Context) GetDb() *sql.DB {
|
||||
func (c Context) GetDb() db.Db {
|
||||
return c.Db
|
||||
}
|
||||
|
||||
@ -256,19 +259,22 @@ func (c Context) GetHost() string {
|
||||
return c.Handle.Config.Hostname
|
||||
}
|
||||
|
||||
func (c Context) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
func (c Context) Query(query string, args ...any) (pgx.Rows, error) {
|
||||
if c.Tx == nil {
|
||||
return c.Db.Query(query, args...)
|
||||
}
|
||||
return c.Tx.Query(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (c Context) Exec(query string, args ...any) (sql.Result, error) {
|
||||
func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) {
|
||||
if c.Tx == nil {
|
||||
return c.Db.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (c Context) Prepare(str string) (*sql.Stmt, error) {
|
||||
if c.Tx == nil {
|
||||
return c.Db.Prepare(str)
|
||||
return c.Tx.Exec(context.Background(), query, args...)
|
||||
}
|
||||
return c.Tx.Prepare(str)
|
||||
|
||||
func (c Context) Begin() (pgx.Tx, error) {
|
||||
return c.Db.Begin()
|
||||
}
|
||||
|
||||
var TransactionAlreadyStarted = errors.New("Transaction already started")
|
||||
@ -279,7 +285,7 @@ func (c *Context) StartTx() error {
|
||||
return TransactionAlreadyStarted
|
||||
}
|
||||
var err error = nil
|
||||
c.Tx, err = c.Db.Begin()
|
||||
c.Tx, err = c.Begin()
|
||||
return err
|
||||
}
|
||||
|
||||
@ -287,7 +293,7 @@ func (c *Context) CommitTx() error {
|
||||
if c.Tx == nil {
|
||||
return TransactionNotStarted
|
||||
}
|
||||
err := c.Tx.Commit()
|
||||
err := c.Tx.Commit(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -299,7 +305,7 @@ func (c *Context) RollbackTx() error {
|
||||
if c.Tx == nil {
|
||||
return TransactionNotStarted
|
||||
}
|
||||
err := c.Tx.Rollback()
|
||||
err := c.Tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -512,7 +518,7 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType
|
||||
})
|
||||
}
|
||||
|
||||
func NewHandler(db *sql.DB, config Config) *Handle {
|
||||
func NewHandler(db db.Db, config Config) *Handle {
|
||||
var gets []HandleFunc
|
||||
var posts []HandleFunc
|
||||
var deletes []HandleFunc
|
||||
|
11
main.go
11
main.go
@ -1,12 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
_ "github.com/lib/pq"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner"
|
||||
@ -23,23 +23,18 @@ const (
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
|
||||
"password=%s dbname=%s sslmode=disable",
|
||||
host, port, user, password, dbname)
|
||||
|
||||
db, err := sql.Open("postgres", psqlInfo)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db := db.StartUp(psqlInfo)
|
||||
defer db.Close()
|
||||
log.Info("Starting server on :5002!")
|
||||
|
||||
config := LoadConfig()
|
||||
log.Info("Config loaded!", "config", config)
|
||||
config.GenerateToken(db)
|
||||
|
||||
db.SetMaxOpenConns(config.DbInfo.MaxConnections)
|
||||
|
||||
StartRunners(db, config)
|
||||
|
||||
//TODO check if file structure exists to save data
|
||||
|
Loading…
Reference in New Issue
Block a user