moved to psql pool closes #99

This commit is contained in:
Andre Henriques 2024-04-17 17:46:43 +01:00
parent 8ece8306dd
commit a3913ccdf5
14 changed files with 132 additions and 98 deletions

4
go.mod
View File

@ -23,16 +23,20 @@ require (
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx v3.6.2+incompatible // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect
github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/reflow v0.3.0 // indirect
github.com/muesli/termenv v0.15.2 // indirect github.com/muesli/termenv v0.15.2 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/rivo/uniseg v0.4.6 // indirect github.com/rivo/uniseg v0.4.6 // indirect
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect golang.org/x/exp v0.0.0-20240119083558-1b970713d09a // indirect
golang.org/x/net v0.21.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.17.0 // indirect golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
google.golang.org/protobuf v1.32.0 // indirect google.golang.org/protobuf v1.32.0 // indirect

8
go.sum
View File

@ -39,8 +39,12 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx v3.6.2+incompatible h1:2zP5OD7kiyR3xzRYMhOcXVvkDZsImVXfj+yIyTQf3/o=
github.com/jackc/pgx v3.6.2+incompatible/go.mod h1:0ZGrqGqkRlliWnWB4zKnWtjbSWbGkVEFm4TeybAXq+I=
github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -60,6 +64,8 @@ github.com/muesli/reflow v0.3.0 h1:IFsN6K9NfGtjeggFP+68I4chLZV2yIKsXJFNZ+eWh6s=
github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8= github.com/muesli/reflow v0.3.0/go.mod h1:pbwTDkVPibjO2kyvBQRBxTWEEGDGq0FlB1BIKtnHY/8=
github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo=
github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
@ -81,6 +87,8 @@ golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRj
golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

48
logic/db/db.go Normal file
View File

@ -0,0 +1,48 @@
package db
import (
"context"
"github.com/charmbracelet/log"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
)
type DbContainer struct {
pool *pgxpool.Pool
}
type Db interface {
Query(query string, args ...any) (pgx.Rows, error)
Exec(query string, args ...any) (pgconn.CommandTag, error)
Begin() (pgx.Tx, error)
}
func StartUp(url string) DbContainer {
dbpool, err := pgxpool.New(context.Background(), url)
if err != nil {
log.Fatal("Cloud not create database pool")
panic(err)
}
return DbContainer{
pool: dbpool,
}
}
func (db DbContainer) Close() {
db.pool.Close()
}
func (db DbContainer) Query(query string, args ...any) (pgx.Rows, error) {
return db.pool.Query(context.Background(), query, args...)
}
func (db DbContainer) Exec(query string, args ...any) (pgconn.CommandTag, error) {
return db.pool.Exec(context.Background(), query, args...)
}
func (db DbContainer) Begin() (pgx.Tx, error) {
return db.pool.Begin(context.Background())
}

View File

@ -1,8 +1,9 @@
package dbtypes package dbtypes
import ( import (
"database/sql"
"errors" "errors"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
) )
const ( const (
@ -79,7 +80,7 @@ type BaseModel struct {
var ModelNotFoundError = errors.New("Model not found error") var ModelNotFoundError = errors.New("Model not found error")
func GetBaseModel(db *sql.DB, id string) (base *BaseModel, err error) { func GetBaseModel(db db.Db, id string) (base *BaseModel, err error) {
var model BaseModel var model BaseModel
err = GetDBOnce(db, &model, "models where id=$1", id) err = GetDBOnce(db, &model, "models where id=$1", id)
if err != nil { if err != nil {

View File

@ -1,7 +1,7 @@
package dbtypes package dbtypes
import ( import (
"database/sql" "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
) )
type UserType int type UserType int
@ -20,8 +20,7 @@ type User struct {
UserType int `db:"u.user_type"` UserType int `db:"u.user_type"`
} }
func UserFromToken(db *sql.DB, token string) (*User, error) { func UserFromToken(db db.Db, token string) (*User, error) {
var user User var user User
err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token) err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token)

View File

@ -1,7 +1,6 @@
package dbtypes package dbtypes
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -14,16 +13,19 @@ import (
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5"
db "git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
) )
type BasePack interface { type BasePack interface {
GetDb() *sql.DB GetDb() db.Db
GetLogger() *log.Logger GetLogger() *log.Logger
GetHost() string GetHost() string
} }
type BasePackStruct struct { type BasePackStruct struct {
Db *sql.DB Db db.Db
Logger *log.Logger Logger *log.Logger
Host string Host string
} }
@ -32,7 +34,7 @@ func (b BasePackStruct) GetHost() string {
return b.Host return b.Host
} }
func (b BasePackStruct) GetDb() *sql.DB { func (b BasePackStruct) GetDb() db.Db {
return b.Db return b.Db
} }
@ -224,24 +226,12 @@ func generateQuery(t reflect.Type) (query string, nargs int) {
return return
} }
type QueryInterface interface { func GetDbMultitple[T interface{}](c db.Db, tablename string, args ...any) ([]*T, error) {
Prepare(str string) (*sql.Stmt, error)
Query(query string, args ...any) (*sql.Rows, error)
Exec(query string, args ...any) (sql.Result, error)
}
func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) {
t := reflect.TypeFor[T]() t := reflect.TypeFor[T]()
query, nargs := generateQuery(t) query, nargs := generateQuery(t)
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename)) rows, err := c.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -260,7 +250,7 @@ func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...a
return list, nil return list, nil
} }
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { func mapRow(store interface{}, rows pgx.Rows, nargs int) (err error) {
err = nil err = nil
val := reflect.Indirect(reflect.ValueOf(store)) val := reflect.Indirect(reflect.ValueOf(store))
@ -278,7 +268,7 @@ func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
return nil return nil
} }
func InsertReturnId(c QueryInterface, store interface{}, tablename string, returnName string) (id string, err error) { func InsertReturnId(c db.Db, store interface{}, tablename string, returnName string) (id string, err error) {
t := reflect.TypeOf(store).Elem() t := reflect.TypeOf(store).Elem()
query, nargs := generateQuery(t) query, nargs := generateQuery(t)
@ -315,14 +305,8 @@ func InsertReturnId(c QueryInterface, store interface{}, tablename string, retur
return return
} }
func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename string, args ...any) (*T, error) { func GetDbVar[T interface{}](c db.Db, var_to_extract string, tablename string, args ...any) (*T, error) {
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", var_to_extract, tablename)) rows, err := c.Query(fmt.Sprintf("select %s from %s", var_to_extract, tablename), args...)
if err != nil {
return nil, err
}
defer db_query.Close()
rows, err := db_query.Query(args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -340,7 +324,7 @@ func GetDbVar[T interface{}](c QueryInterface, var_to_extract string, tablename
return dat, nil return dat, nil
} }
func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error { func GetDBOnce(db db.Db, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem() t := reflect.TypeOf(store).Elem()
query, nargs := generateQuery(t) query, nargs := generateQuery(t)
@ -372,7 +356,7 @@ func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...a
return nil return nil
} }
func UpdateStatus(c QueryInterface, table string, id string, status int) (err error) { func UpdateStatus(c db.Db, table string, id string, status int) (err error) {
_, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) _, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
return return
} }

View File

@ -1,15 +1,15 @@
package model_classes package model_classes
import ( import (
"database/sql"
"errors" "errors"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
) )
var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error") var FailedToGetIdAfterInsertError = errors.New("Failed to Get Id After Insert Error")
func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) { func AddDataPoint(db db.Db, class_id string, file_path string, mode DATA_POINT_MODE) (id string, err error) {
id = "" id = ""
result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode) result, err := db.Query("insert into model_data_point (class_id, file_path, model_mode, status) values ($1, $2, $3, 1) returning id;", class_id, file_path, mode)
if err != nil { if err != nil {
@ -24,7 +24,7 @@ func AddDataPoint(db *sql.DB, class_id string, file_path string, mode DATA_POINT
return return
} }
func UpdateDataPointStatus(db *sql.DB, data_point_id string, status int, message *string) (err error) { func UpdateDataPointStatus(db db.Db, data_point_id string, status int, message *string) (err error) {
_, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id) _, err = db.Exec("update model_data_point set status=$1, status_message=$2 where id=$3", status, message, data_point_id)
return return
} }

View File

@ -1,9 +1,9 @@
package model_classes package model_classes
import ( import (
"database/sql"
"errors" "errors"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
) )
@ -18,7 +18,7 @@ func ListClasses(c BasePack, model_id string) (cls []*ModelClass, err error) {
return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id) return GetDbMultitple[ModelClass](c.GetDb(), "model_classes where model_id=$1", model_id)
} }
func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) { func ModelHasDataPoints(db db.Db, model_id string) (result bool, err error) {
result = false result = false
rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id) rows, err := db.Query("select mdp.id from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 limit 1;", model_id)
if err != nil { if err != nil {
@ -31,7 +31,7 @@ func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) {
var ClassAlreadyExists = errors.New("Class aready exists") var ClassAlreadyExists = errors.New("Class aready exists")
func CreateClass(db *sql.DB, model_id string, order int, name string) (id string, err error) { func CreateClass(db db.Db, model_id string, order int, name string) (id string, err error) {
id = "" id = ""
rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name) rows, err := db.Query("select id from model_classes where model_id=$1 and name=$2;", model_id, name)
if err != nil { if err != nil {
@ -57,7 +57,7 @@ func CreateClass(db *sql.DB, model_id string, order int, name string) (id string
return return
} }
func GetNumberOfWrongDataPoints(db *sql.DB, model_id string) (number int, err error) { func GetNumberOfWrongDataPoints(db db.Db, model_id string) (number int, err error) {
number = 0 number = 0
rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id) rows, err := db.Query("select count(mdp.id) from model_data_point as mdp join model_classes as mc on mc.id = mdp.class_id where mc.model_id=$1 and mdp.status=-1;", model_id)
if err != nil { if err != nil {

View File

@ -1,7 +1,6 @@
package models_train package models_train
import ( import (
"database/sql"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -14,6 +13,7 @@ import (
"strings" "strings"
"text/template" "text/template"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@ -40,7 +40,7 @@ func getDir() string {
} }
// This function creates a new model_definition // This function creates a new model_definition
func MakeDefenition(db *sql.DB, model_id string, target_accuracy int) (id string, err error) { func MakeDefenition(db db.Db, model_id string, target_accuracy int) (id string, err error) {
var NewDefinition = struct { var NewDefinition = struct {
ModelId string `db:"model_id"` ModelId string `db:"model_id"`
TargetAccuracy int `db:"target_accuracy"` TargetAccuracy int `db:"target_accuracy"`
@ -54,12 +54,12 @@ func ModelDefinitionUpdateStatus(c BasePack, id string, status ModelDefinitionSt
return return
} }
func MakeLayer(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string) (err error) { func MakeLayer(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string) (err error) {
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape) _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape) values ($1, $2, $3, $4)", def_id, layer_order, layer_type, shape)
return return
} }
func MakeLayerExpandable(db *sql.DB, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) { func MakeLayerExpandable(db db.Db, def_id string, layer_order int, layer_type LayerType, shape string, exp_type int) (err error) {
_, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type) _, err = db.Exec("insert into model_definition_layer (def_id, layer_order, layer_type, shape, exp_type) values ($1, $2, $3, $4, $5)", def_id, layer_order, layer_type, shape, exp_type)
return return
} }
@ -1398,7 +1398,7 @@ func generateDefinitions(c BasePack, model *BaseModel, target_accuracy int, numb
return nil return nil
} }
func ExpModelHeadUpdateStatus(db *sql.DB, id string, status ModelDefinitionStatus) (err error) { func ExpModelHeadUpdateStatus(db db.Db, id string, status ModelDefinitionStatus) (err error) {
_, err = db.Exec("update model_definition set status = $1 where id = $2", status, id) _, err = db.Exec("update model_definition set status = $1 where id = $2", status, id)
return return
} }
@ -1877,18 +1877,7 @@ func handleTrain(handle *Handle) {
//Update the classes //Update the classes
{ {
stmt, err2 := c.Prepare("update model_classes set status=$1 where status=$2 and model_id=$3") _, err = c.Exec("update model_classes set status=$1 where status=$2 and model_id=$3", MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
err = err2
if err != nil {
_err := c.RollbackTx()
if _err != nil {
c.Logger.Error("Two errors happended rollback failed", "err", _err)
}
return failed()
}
defer stmt.Close()
_, err = stmt.Exec(MODEL_CLASS_STATUS_TRAINING, MODEL_CLASS_STATUS_TO_TRAIN, model.Id)
if err != nil { if err != nil {
_err := c.RollbackTx() _err := c.RollbackTx()
if _err != nil { if _err != nil {

View File

@ -1,7 +1,6 @@
package task_runner package task_runner
import ( import (
"database/sql"
"fmt" "fmt"
"math" "math"
"os" "os"
@ -10,6 +9,7 @@ import (
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/train"
@ -21,7 +21,7 @@ import (
/** /**
* Actually runs the code * Actually runs the code
*/ */
func runner(config Config, db *sql.DB, task_channel chan Task, index int, back_channel chan int) { func runner(config Config, db db.Db, task_channel chan Task, index int, back_channel chan int) {
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true, ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
@ -125,7 +125,7 @@ func attentionSeeker(config Config, back_channel chan int) {
/** /**
* Manages what worker should to Work * Manages what worker should to Work
*/ */
func RunnerOrchestrator(db *sql.DB, config Config) { func RunnerOrchestrator(db db.Db, config Config) {
logger := log.NewWithOptions(os.Stdout, log.Options{ logger := log.NewWithOptions(os.Stdout, log.Options{
ReportCaller: true, ReportCaller: true,
ReportTimestamp: true, ReportTimestamp: true,
@ -211,6 +211,6 @@ func RunnerOrchestrator(db *sql.DB, config Config) {
} }
} }
func StartRunners(db *sql.DB, config Config) { func StartRunners(db db.Db, config Config) {
go RunnerOrchestrator(db, config) go RunnerOrchestrator(db, config)
} }

View File

@ -2,7 +2,6 @@ package users
import ( import (
"crypto/rand" "crypto/rand"
"database/sql"
"encoding/hex" "encoding/hex"
"io" "io"
"net/http" "net/http"
@ -10,6 +9,7 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
@ -44,12 +44,12 @@ func genToken() string {
return hex.EncodeToString(token) return hex.EncodeToString(token)
} }
func deleteToken(db *sql.DB, userId string, time time.Time) (err error) { func deleteToken(db db.Db, userId string, time time.Time) (err error) {
_, err = db.Exec("delete from tokens where emit_day=$1 and user_id=$2", time, userId) _, err = db.Exec("delete from tokens where emit_day=$1 and user_id=$2", time, userId)
return return
} }
func generateToken(db *sql.DB, email string, password string, name string) (string, bool) { func generateToken(db db.Db, email string, password string, name string) (string, bool) {
row, err := db.Query("select id, salt, password from users where email = $1;", email) row, err := db.Query("select id, salt, password from users where email = $1;", email)
if err != nil || !row.Next() { if err != nil || !row.Next() {
return "", false return "", false
@ -106,7 +106,7 @@ func DeleteUser(base BasePack, task Task) (err error) {
return return
} }
func UsersEndpints(db *sql.DB, handle *Handle) { func UsersEndpints(db db.Db, handle *Handle) {
type UserLogin struct { type UserLogin struct {
Email string `json:"email"` Email string `json:"email"`

View File

@ -1,10 +1,10 @@
package utils package utils
import ( import (
"database/sql"
"os" "os"
"strings" "strings"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils"
@ -88,7 +88,7 @@ func failLog(err error) {
log.Fatal("Failed on setup", "error", err) log.Fatal("Failed on setup", "error", err)
} }
func (c *Config) Cleanup(db *sql.DB) { func (c *Config) Cleanup(db db.Db) {
if c.CleanUpOnStartup != 1 { if c.CleanUpOnStartup != 1 {
return return
} }
@ -125,7 +125,7 @@ func (c *Config) Cleanup(db *sql.DB) {
} }
} }
func (c *Config) GenerateToken(db *sql.DB) { func (c *Config) GenerateToken(db db.Db) {
if c.ServiceUser.User == "" { if c.ServiceUser.User == "" {
log.Fatal("A user needs to be set in a configuration file") log.Fatal("A user needs to be set in a configuration file")
} }

View File

@ -2,7 +2,7 @@ package utils
import ( import (
"bytes" "bytes"
"database/sql" "context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -13,10 +13,13 @@ import (
"strings" "strings"
"time" "time"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
) )
type AnyMap = map[string]interface{} type AnyMap = map[string]interface{}
@ -39,7 +42,7 @@ type Handler interface {
} }
type Handle struct { type Handle struct {
Db *sql.DB Db db.Db
gets []HandleFunc gets []HandleFunc
posts []HandleFunc posts []HandleFunc
deletes []HandleFunc deletes []HandleFunc
@ -235,16 +238,16 @@ type Context struct {
Token *string Token *string
User *dbtypes.User User *dbtypes.User
Logger *log.Logger Logger *log.Logger
Db *sql.DB Db db.Db
Writer http.ResponseWriter Writer http.ResponseWriter
R *http.Request R *http.Request
Tx *sql.Tx Tx pgx.Tx
ShowMessage bool ShowMessage bool
Handle *Handle Handle *Handle
} }
// This is required for this to integrate simealy with my orl // This is required for this to integrate simealy with my orl
func (c Context) GetDb() *sql.DB { func (c Context) GetDb() db.Db {
return c.Db return c.Db
} }
@ -256,19 +259,22 @@ func (c Context) GetHost() string {
return c.Handle.Config.Hostname return c.Handle.Config.Hostname
} }
func (c Context) Query(query string, args ...any) (*sql.Rows, error) { func (c Context) Query(query string, args ...any) (pgx.Rows, error) {
if c.Tx == nil {
return c.Db.Query(query, args...) return c.Db.Query(query, args...)
} }
return c.Tx.Query(context.Background(), query, args...)
}
func (c Context) Exec(query string, args ...any) (sql.Result, error) { func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) {
if c.Tx == nil {
return c.Db.Exec(query, args...) return c.Db.Exec(query, args...)
} }
return c.Tx.Exec(context.Background(), query, args...)
func (c Context) Prepare(str string) (*sql.Stmt, error) {
if c.Tx == nil {
return c.Db.Prepare(str)
} }
return c.Tx.Prepare(str)
func (c Context) Begin() (pgx.Tx, error) {
return c.Db.Begin()
} }
var TransactionAlreadyStarted = errors.New("Transaction already started") var TransactionAlreadyStarted = errors.New("Transaction already started")
@ -279,7 +285,7 @@ func (c *Context) StartTx() error {
return TransactionAlreadyStarted return TransactionAlreadyStarted
} }
var err error = nil var err error = nil
c.Tx, err = c.Db.Begin() c.Tx, err = c.Begin()
return err return err
} }
@ -287,7 +293,7 @@ func (c *Context) CommitTx() error {
if c.Tx == nil { if c.Tx == nil {
return TransactionNotStarted return TransactionNotStarted
} }
err := c.Tx.Commit() err := c.Tx.Commit(context.Background())
if err != nil { if err != nil {
return err return err
} }
@ -299,7 +305,7 @@ func (c *Context) RollbackTx() error {
if c.Tx == nil { if c.Tx == nil {
return TransactionNotStarted return TransactionNotStarted
} }
err := c.Tx.Rollback() err := c.Tx.Rollback(context.Background())
if err != nil { if err != nil {
return err return err
} }
@ -512,7 +518,7 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType
}) })
} }
func NewHandler(db *sql.DB, config Config) *Handle { func NewHandler(db db.Db, config Config) *Handle {
var gets []HandleFunc var gets []HandleFunc
var posts []HandleFunc var posts []HandleFunc
var deletes []HandleFunc var deletes []HandleFunc

11
main.go
View File

@ -1,12 +1,12 @@
package main package main
import ( import (
"database/sql"
"fmt" "fmt"
"github.com/charmbracelet/log" "github.com/charmbracelet/log"
_ "github.com/lib/pq" _ "github.com/lib/pq"
"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks"
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/runner"
@ -23,23 +23,18 @@ const (
) )
func main() { func main() {
psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+ psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+
"password=%s dbname=%s sslmode=disable", "password=%s dbname=%s sslmode=disable",
host, port, user, password, dbname) host, port, user, password, dbname)
db, err := sql.Open("postgres", psqlInfo) db := db.StartUp(psqlInfo)
if err != nil {
panic(err)
}
defer db.Close() defer db.Close()
log.Info("Starting server on :5002!")
config := LoadConfig() config := LoadConfig()
log.Info("Config loaded!", "config", config) log.Info("Config loaded!", "config", config)
config.GenerateToken(db) config.GenerateToken(db)
db.SetMaxOpenConns(config.DbInfo.MaxConnections)
StartRunners(db, config) StartRunners(db, config)
//TODO check if file structure exists to save data //TODO check if file structure exists to save data