moved to psql pool closes #99
This commit is contained in:
48
logic/db/db.go
Normal file
48
logic/db/db.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type DbContainer struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
type Db interface {
|
||||
Query(query string, args ...any) (pgx.Rows, error)
|
||||
Exec(query string, args ...any) (pgconn.CommandTag, error)
|
||||
Begin() (pgx.Tx, error)
|
||||
}
|
||||
|
||||
func StartUp(url string) DbContainer {
|
||||
dbpool, err := pgxpool.New(context.Background(), url)
|
||||
if err != nil {
|
||||
log.Fatal("Cloud not create database pool")
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return DbContainer{
|
||||
pool: dbpool,
|
||||
}
|
||||
}
|
||||
|
||||
func (db DbContainer) Close() {
|
||||
db.pool.Close()
|
||||
}
|
||||
|
||||
func (db DbContainer) Query(query string, args ...any) (pgx.Rows, error) {
|
||||
return db.pool.Query(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (db DbContainer) Exec(query string, args ...any) (pgconn.CommandTag, error) {
|
||||
return db.pool.Exec(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (db DbContainer) Begin() (pgx.Tx, error) {
|
||||
return db.pool.Begin(context.Background())
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -79,7 +80,7 @@ type BaseModel struct {
|
||||
|
||||
var ModelNotFoundError = errors.New("Model not found error")
|
||||
|
||||
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) {
|
||||
func GetBaseModel(db db.Db, id string) (base *BaseModel, err error) {
|
||||
var model BaseModel
|
||||
err = GetDBOnce(db, &model, "models where id=$1", id)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
)
|
||||
|
||||
type UserType int
|
||||
@@ -20,8 +20,7 @@ type User struct {
|
||||
UserType int `db:"u.user_type"`
|
||||
}
|
||||
|
||||
func UserFromToken(db *sql.DB, token string) (*User, error) {
|
||||
|
||||
func UserFromToken(db db.Db, token string) (*User, error) {
|
||||
var user User
|
||||
|
||||
err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package dbtypes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -14,16 +13,19 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
db "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
)
|
||||
|
||||
type BasePack interface {
|
||||
GetDb() *sql.DB
|
||||
GetDb() db.Db
|
||||
GetLogger() *log.Logger
|
||||
GetHost() string
|
||||
}
|
||||
|
||||
type BasePackStruct struct {
|
||||
Db *sql.DB
|
||||
Db db.Db
|
||||
Logger *log.Logger
|
||||
Host string
|
||||
}
|
||||
@@ -32,7 +34,7 @@ func (b BasePackStruct) GetHost() string {
|
||||
return b.Host
|
||||
}
|
||||
|
||||
func (b BasePackStruct) GetDb() *sql.DB {
|
||||
func (b BasePackStruct) GetDb() db.Db {
|
||||
return b.Db
|
||||
}
|
||||
|
||||
@@ -224,24 +226,12 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
|
||||
return
|
||||
}
|
||||
|
||||
type QueryInterface interface {
|
||||
Prepare(str string) (*sql.Stmt, error)
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) {
|
||||
func GetDbMultitple[T interface{}](c db.Db, tablename string, args ...any) ([]*T, error) {
|
||||
t := reflect.TypeFor[T]()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
|
||||
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db_query.Close()
|
||||
|
||||
rows, err := db_query.Query(args...)
|
||||
rows, err := c.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -260,7 +250,7 @@ func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...a
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||
func mapRow(store interface{}, rows pgx.Rows, nargs int) (err error) {
|
||||
err = nil
|
||||
|
||||
val := reflect.Indirect(reflect.ValueOf(store))
|
||||
@@ -278,7 +268,7 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func InsertReturnId(c QueryInterface, store interface{}, tablename string, returnName string) (id string, err error) {
|
||||
func InsertReturnId(c db.Db, store interface{}, tablename string, returnName string) (id string, err error) {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
@@ -315,14 +305,8 @@ func InsertReturnId(c QueryInterface, store interface{}, tablename string, retur
|
||||
return
|
||||
}
|
||||
|
||||
func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename string, args ...any) (*T, error) {
|
||||
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", var_to_extract, tablename))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer db_query.Close()
|
||||
|
||||
rows, err := db_query.Query(args...)
|
||||
func GetDbVar[T interface{}](c db.Db, var_to_extract string, tablename string, args ...any) (*T, error) {
|
||||
rows, err := c.Query(fmt.Sprintf("select %s from %s", var_to_extract, tablename), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -340,7 +324,7 @@ func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename
|
||||
return dat, nil
|
||||
}
|
||||
|
||||
func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error {
|
||||
func GetDBOnce(db db.Db, store interface{}, tablename string, args ...any) error {
|
||||
t := reflect.TypeOf(store).Elem()
|
||||
|
||||
query, nargs := generateQuery(t)
|
||||
@@ -372,7 +356,7 @@ func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...a
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateStatus(c QueryInterface, table string, id string, status int) (err error) {
|
||||
func UpdateStatus(c db.Db, table string, id string, status int) (err error) {
|
||||
_, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package model_classes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
)
|
||||
|
||||
var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error")
|
||||
|
||||
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
|
||||
func AddDataPoint(db db.Db, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
|
||||
id = ""
|
||||
result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode)
|
||||
if err != nil {
|
||||
@@ -24,7 +24,7 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT
|
||||
return
|
||||
}
|
||||
|
||||
func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) {
|
||||
func UpdateDataPointStatus(db db.Db, data_point_id string, status int, message *string) (err error) {
|
||||
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package model_classes
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) {
|
||||
return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id)
|
||||
}
|
||||
|
||||
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||
func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) {
|
||||
result = false
|
||||
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
|
||||
if err != nil {
|
||||
@@ -31,7 +31,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
|
||||
|
||||
var ClassAlreadyExists = errors.New("Class aready exists")
|
||||
|
||||
func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) {
|
||||
func CreateClass(db db.Db, model_id string, order int, name string) (id string, err error) {
|
||||
id = ""
|
||||
rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
|
||||
if err != nil {
|
||||
@@ -57,7 +57,7 @@ func CreateClass(db *sql.DB, model_id string, order int, name string) (id string
|
||||
return
|
||||
}
|
||||
|
||||
func GetNumberOfWrongDataPoints(db *sql.DB, model_id string) (number int, err error) {
|
||||
func GetNumberOfWrongDataPoints(db db.Db, model_id string) (number int, err error) {
|
||||
number = 0
|
||||
rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package models_train
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
@@ -40,7 +40,7 @@ func getDir() string {
|
||||
}
|
||||
|
||||
// This function creates a new model_definition
|
||||
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) {
|
||||
func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) {
|
||||
var NewDefinition = struct {
|
||||
ModelId string `db:"model_id"`
|
||||
TargetAccuracy int `db:"target_accuracy"`
|
||||
@@ -54,12 +54,12 @@ func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionSt
|
||||
return
|
||||
}
|
||||
|
||||
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
|
||||
func MakeLayer(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
|
||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
|
||||
return
|
||||
}
|
||||
|
||||
func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
|
||||
func MakeLayerExpandable(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
|
||||
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type)
|
||||
return
|
||||
}
|
||||
@@ -1398,7 +1398,7 @@ func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, numb
|
||||
return nil
|
||||
}
|
||||
|
||||
func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) {
|
||||
func ExpModelHeadUpdateStatus(db db.Db, id string, status ModelDefinitionStatus) (err error) {
|
||||
_, err = db.Exec("update model_definition set status = $1 where id = $2", status, id)
|
||||
return
|
||||
}
|
||||
@@ -1877,18 +1877,7 @@ func handleTrain(handle *Handle) {
|
||||
|
||||
//Update the classes
|
||||
{
|
||||
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3")
|
||||
err = err2
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
c.Logger.Error("Two errors happended rollback failed", "err", _err)
|
||||
}
|
||||
return failed()
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
_, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
|
||||
if err != nil {
|
||||
_err := c.RollbackTx()
|
||||
if _err != nil {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package task_runner
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
|
||||
"github.com/charmbracelet/log"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
/**
|
||||
* Actually runs the code
|
||||
*/
|
||||
func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) {
|
||||
func runner(config Config, db db.Db, task_channel chan Task, index int, back_channel chan int) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
@@ -125,7 +125,7 @@ func attentionSeeker(config Config, back_channel chan int) {
|
||||
/**
|
||||
* Manages what worker should to Work
|
||||
*/
|
||||
func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
func RunnerOrchestrator(db db.Db, config Config) {
|
||||
logger := log.NewWithOptions(os.Stdout, log.Options{
|
||||
ReportCaller: true,
|
||||
ReportTimestamp: true,
|
||||
@@ -211,6 +211,6 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
|
||||
}
|
||||
}
|
||||
|
||||
func StartRunners(db *sql.DB, config Config) {
|
||||
func StartRunners(db db.Db, config Config) {
|
||||
go RunnerOrchestrator(db, config)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package users
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
|
||||
@@ -44,12 +44,12 @@ func genToken() string {
|
||||
return hex.EncodeToString(token)
|
||||
}
|
||||
|
||||
func deleteToken(db *sql.DB, userId string, time time.Time) (err error) {
|
||||
func deleteToken(db db.Db, userId string, time time.Time) (err error) {
|
||||
_, err = db.Exec("delete from tokens where emit_day=$1 and user_id=$2", time, userId)
|
||||
return
|
||||
}
|
||||
|
||||
func generateToken(db *sql.DB, email string, password string, name string) (string, bool) {
|
||||
func generateToken(db db.Db, email string, password string, name string) (string, bool) {
|
||||
row, err := db.Query("select id, salt, password from users where email = $1;", email)
|
||||
if err != nil || !row.Next() {
|
||||
return "", false
|
||||
@@ -106,7 +106,7 @@ func DeleteUser(base BasePack, task Task) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func UsersEndpints(db *sql.DB, handle *Handle) {
|
||||
func UsersEndpints(db db.Db, handle *Handle) {
|
||||
|
||||
type UserLogin struct {
|
||||
Email string `json:"email"`
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
|
||||
|
||||
@@ -88,7 +88,7 @@ func failLog(err error) {
|
||||
log.Fatal("Failed on setup", "error", err)
|
||||
}
|
||||
|
||||
func (c *Config) Cleanup(db *sql.DB) {
|
||||
func (c *Config) Cleanup(db db.Db) {
|
||||
if c.CleanUpOnStartup != 1 {
|
||||
return
|
||||
}
|
||||
@@ -125,7 +125,7 @@ func (c *Config) Cleanup(db *sql.DB) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Config) GenerateToken(db *sql.DB) {
|
||||
func (c *Config) GenerateToken(db db.Db) {
|
||||
if c.ServiceUser.User == "" {
|
||||
log.Fatal("A user needs to be set in a configuration file")
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -13,10 +13,13 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
|
||||
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
||||
"github.com/charmbracelet/log"
|
||||
"github.com/go-playground/validator/v10"
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
)
|
||||
|
||||
type AnyMap = map[string]interface{}
|
||||
@@ -39,7 +42,7 @@ type Handler interface {
|
||||
}
|
||||
|
||||
type Handle struct {
|
||||
Db *sql.DB
|
||||
Db db.Db
|
||||
gets []HandleFunc
|
||||
posts []HandleFunc
|
||||
deletes []HandleFunc
|
||||
@@ -235,16 +238,16 @@ type Context struct {
|
||||
Token *string
|
||||
User *dbtypes.User
|
||||
Logger *log.Logger
|
||||
Db *sql.DB
|
||||
Db db.Db
|
||||
Writer http.ResponseWriter
|
||||
R *http.Request
|
||||
Tx *sql.Tx
|
||||
Tx pgx.Tx
|
||||
ShowMessage bool
|
||||
Handle *Handle
|
||||
}
|
||||
|
||||
// This is required for this to integrate simealy with my orl
|
||||
func (c Context) GetDb() *sql.DB {
|
||||
func (c Context) GetDb() db.Db {
|
||||
return c.Db
|
||||
}
|
||||
|
||||
@@ -256,19 +259,22 @@ func (c Context) GetHost() string {
|
||||
return c.Handle.Config.Hostname
|
||||
}
|
||||
|
||||
func (c Context) Query(query string, args ...any) (*sql.Rows, error) {
|
||||
return c.Db.Query(query, args...)
|
||||
}
|
||||
|
||||
func (c Context) Exec(query string, args ...any) (sql.Result, error) {
|
||||
return c.Db.Exec(query, args...)
|
||||
}
|
||||
|
||||
func (c Context) Prepare(str string) (*sql.Stmt, error) {
|
||||
func (c Context) Query(query string, args ...any) (pgx.Rows, error) {
|
||||
if c.Tx == nil {
|
||||
return c.Db.Prepare(str)
|
||||
return c.Db.Query(query, args...)
|
||||
}
|
||||
return c.Tx.Prepare(str)
|
||||
return c.Tx.Query(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) {
|
||||
if c.Tx == nil {
|
||||
return c.Db.Exec(query, args...)
|
||||
}
|
||||
return c.Tx.Exec(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (c Context) Begin() (pgx.Tx, error) {
|
||||
return c.Db.Begin()
|
||||
}
|
||||
|
||||
var TransactionAlreadyStarted = errors.New("Transaction already started")
|
||||
@@ -279,7 +285,7 @@ func (c *Context) StartTx() error {
|
||||
return TransactionAlreadyStarted
|
||||
}
|
||||
var err error = nil
|
||||
c.Tx, err = c.Db.Begin()
|
||||
c.Tx, err = c.Begin()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -287,7 +293,7 @@ func (c *Context) CommitTx() error {
|
||||
if c.Tx == nil {
|
||||
return TransactionNotStarted
|
||||
}
|
||||
err := c.Tx.Commit()
|
||||
err := c.Tx.Commit(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -299,7 +305,7 @@ func (c *Context) RollbackTx() error {
|
||||
if c.Tx == nil {
|
||||
return TransactionNotStarted
|
||||
}
|
||||
err := c.Tx.Rollback()
|
||||
err := c.Tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -512,7 +518,7 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType
|
||||
})
|
||||
}
|
||||
|
||||
func NewHandler(db *sql.DB, config Config) *Handle {
|
||||
func NewHandler(db db.Db, config Config) *Handle {
|
||||
var gets []HandleFunc
|
||||
var posts []HandleFunc
|
||||
var deletes []HandleFunc
|
||||
|
||||
Reference in New Issue
Block a user