diff --git a/logic/db_types/user.go b/logic/db_types/user.go index 08838b4..735a443 100644 --- a/logic/db_types/user.go +++ b/logic/db_types/user.go @@ -2,7 +2,6 @@ package dbtypes import ( "database/sql" - "errors" ) type UserType int @@ -14,34 +13,20 @@ const ( ) type User struct { - Id string - Username string - Email string - UserType int + Id string `db:"u.id"` + Username string `db:"u.username"` + Email string `db:"u.email"` + UserType int `db:"u.user_type"` } -var ErrUserNotFound = errors.New("User Not found") - func UserFromToken(db *sql.DB, token string) (*User, error) { - rows, err := db.Query("select users.id, users.username, users.email, users.user_type from users inner join tokens on tokens.user_id = users.id where tokens.token = $1;", token) - if err != nil { - return nil, err - } - defer rows.Close() - var id string - var username string - var email string - var user_type int + var user User - if !rows.Next() { - return nil, ErrUserNotFound - } - - err = rows.Scan(&id, &username, &email, &user_type) + err := GetDBOnce(db, &user, "users as u inner join tokens as t on t.user_id = u.id where t.token = $1;", token) if err != nil { return nil, err } - return &User{id, username, email, user_type}, nil + return &user, nil } diff --git a/logic/db_types/utils.go b/logic/db_types/utils.go new file mode 100644 index 0000000..80466ec --- /dev/null +++ b/logic/db_types/utils.go @@ -0,0 +1,347 @@ +package dbtypes + +import ( + "database/sql" + "errors" + "fmt" + "io" + "mime" + "net/http" + "net/url" + "reflect" + "strconv" + "strings" + + "github.com/charmbracelet/log" + "github.com/google/uuid" +) + +type BasePack interface { + GetDb() *sql.DB + GetLogger() *log.Logger +} + +type BasePackStruct struct { + Db *sql.DB + Logger *log.Logger +} + +func (b BasePackStruct) GetDb() *sql.DB { + return b.Db +} + +func (b BasePackStruct) GetLogger() *log.Logger { + return b.Logger +} + +func CheckEmpty(f url.Values, path string) bool { + return !f.Has(path) || f.Get(path) == "" +} + +func CheckNumber(f url.Values, path string, number *int) bool { + if CheckEmpty(f, path) { + fmt.Println("here", path) + fmt.Println(f.Get(path)) + return false + } + n, err := strconv.Atoi(f.Get(path)) + if err != nil { + fmt.Println(err) + return false + } + *number = n + return true +} + +func CheckFloat64(f url.Values, path string, number *float64) bool { + if CheckEmpty(f, path) { + fmt.Println("here", path) + fmt.Println(f.Get(path)) + return false + } + n, err := strconv.ParseFloat(f.Get(path), 64) + if err != nil { + fmt.Println(err) + return false + } + *number = n + return true +} + +func CheckId(f url.Values, path string) bool { + return !CheckEmpty(f, path) && IsValidUUID(f.Get(path)) +} + +func IsValidUUID(u string) bool { + _, err := uuid.Parse(u) + return err == nil +} + +type maxBytesReader struct { + w http.ResponseWriter + r io.ReadCloser // underlying reader + i int64 // max bytes initially, for MaxBytesError + n int64 // max bytes remaining + err error // sticky error +} + +type MaxBytesError struct { + Limit int64 +} + +func (e *MaxBytesError) Error() string { + // Due to Hyrum's law, this text cannot be changed. + return "http: request body too large" +} + +func (l *maxBytesReader) Read(p []byte) (n int, err error) { + if l.err != nil { + return 0, l.err + } + if len(p) == 0 { + return 0, nil + } + // If they asked for a 32KB byte read but only 5 bytes are + // remaining, no need to read 32KB. 6 bytes will answer the + // question of the whether we hit the limit or go past it. + // 0 < len(p) < 2^63 + if int64(len(p))-1 > l.n { + p = p[:l.n+1] + } + n, err = l.r.Read(p) + + if int64(n) <= l.n { + l.n -= int64(n) + l.err = err + return n, err + } + + n = int(l.n) + l.n = 0 + + // The server code and client code both use + // maxBytesReader. This "requestTooLarge" check is + // only used by the server code. To prevent binaries + // which only using the HTTP Client code (such as + // cmd/go) from also linking in the HTTP server, don't + // use a static type assertion to the server + // "*response" type. Check this interface instead: + type requestTooLarger interface { + requestTooLarge() + } + if res, ok := l.w.(requestTooLarger); ok { + res.requestTooLarge() + } + l.err = &MaxBytesError{l.i} + return n, l.err +} + +func (l *maxBytesReader) Close() error { + return l.r.Close() +} + +func MyParseForm(r *http.Request) (vs url.Values, err error) { + if r.Body == nil { + err = errors.New("missing form body") + return + } + ct := r.Header.Get("Content-Type") + // RFC 7231, section 3.1.1.5 - empty type + // MAY be treated as application/octet-stream + if ct == "" { + ct = "application/octet-stream" + } + ct, _, err = mime.ParseMediaType(ct) + switch { + case ct == "application/x-www-form-urlencoded": + var reader io.Reader = r.Body + maxFormSize := int64(1<<63 - 1) + if _, ok := r.Body.(*maxBytesReader); !ok { + maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + reader = io.LimitReader(r.Body, maxFormSize+1) + } + b, e := io.ReadAll(reader) + if e != nil { + if err == nil { + err = e + } + break + } + if int64(len(b)) > maxFormSize { + err = errors.New("http: POST too large") + return + } + vs, e = url.ParseQuery(string(b)) + if err == nil { + err = e + } + case ct == "multipart/form-data": + // handled by ParseMultipartForm (which is calling us, or should be) + // TODO(bradfitz): there are too many possible + // orders to call too many functions here. + // Clean this up and write more tests. + // request_test.go contains the start of this, + // in TestParseMultipartFormOrder and others. + } + return +} + +type JustId struct { + Id string `json:"id" validate:"required"` +} + +type Generic struct{ reflect.Type } + +var NotFoundError = errors.New("Not found") +var CouldNotInsert = errors.New("Could not insert") + +func generateQuery(t reflect.Type) (query string, nargs int) { + nargs = t.NumField() + query = "" + + for i := 0; i < nargs; i += 1 { + field := t.Field(i) + name, ok := field.Tag.Lookup("db") + if !ok { + name = field.Name + } + + if name == "__nil__" { + continue + } + query += strings.ToLower(name) + "," + } + + // Remove the last comma + query = query[0 : len(query)-1] + + return +} + +type QueryInterface interface { + Prepare(str string) (*sql.Stmt, error) + Query(query string, args ...any) (*sql.Rows, error) + Exec(query string, args ...any) (sql.Result, error) +} + +func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) { + t := reflect.TypeFor[T]() + + query, nargs := generateQuery(t) + + db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename)) + if err != nil { + return nil, err + } + defer db_query.Close() + + rows, err := db_query.Query(args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*T{} + + for rows.Next() { + item := new(T) + if err = mapRow(item, rows, nargs); err != nil { + return nil, err + } + list = append(list, item) + } + + return list, nil +} + +func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { + err = nil + + val := reflect.Indirect(reflect.ValueOf(store)) + scan_args := make([]interface{}, nargs) + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Addr().Interface() + } + + err = rows.Scan(scan_args...) + if err != nil { + return + } + + return nil +} + +func InsertReturnId(c QueryInterface, store interface{}, tablename string, returnName string) (id string, err error) { + t := reflect.TypeOf(store).Elem() + + query, nargs := generateQuery(t) + + query2 := "" + for i := 0; i < nargs; i += 1 { + query2 += fmt.Sprintf("$%d,", i+1) + } + // Remove last quotation + query2 = query2[0 : len(query2)-1] + + val := reflect.ValueOf(store).Elem() + scan_args := make([]interface{}, nargs) + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Interface() + } + + rows, err := c.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...) + if err != nil { + return + } + defer rows.Close() + + if !rows.Next() { + return "", CouldNotInsert + } + + err = rows.Scan(&id) + if err != nil { + return + } + + return +} + +func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error { + t := reflect.TypeOf(store).Elem() + + query, nargs := generateQuery(t) + + rows, err := db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return NotFoundError + } + + err = nil + + val := reflect.ValueOf(store).Elem() + scan_args := make([]interface{}, nargs) + for i := 0; i < nargs; i++ { + valueField := val.Field(i) + scan_args[i] = valueField.Addr().Interface() + } + + err = rows.Scan(scan_args...) + if err != nil { + return err + } + + return nil +} + +func UpdateStatus(c QueryInterface, table string, id string, status int) (err error) { + _, err = c.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) + return +} diff --git a/logic/models/classes/list.go b/logic/models/classes/list.go index d8e63c6..735d555 100644 --- a/logic/models/classes/list.go +++ b/logic/models/classes/list.go @@ -3,7 +3,7 @@ package model_classes import ( "strconv" - "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) @@ -33,7 +33,7 @@ func HandleList(handle *Handle) { Model_id string } - err = utils.GetDBOnce(c, &class_row, "model_classes where id=$1", id) + err = GetDBOnce(c, &class_row, "model_classes where id=$1", id) if err == NotFoundError { return c.JsonBadRequest("Model Class not found!") } else if err != nil { @@ -47,7 +47,7 @@ func HandleList(handle *Handle) { Status int `json:"status"` } - rows, err := utils.GetDbMultitple[baserow](c, "model_data_point where class_id=$1 limit 11 offset $2", id, page*10) + rows, err := GetDbMultitple[baserow](c, "model_data_point where class_id=$1 limit 11 offset $2", id, page*10) if err != nil { return c.Error500(err) } @@ -60,7 +60,7 @@ func HandleList(handle *Handle) { max_len := min(11, len(rows)) - c.ShowMessage = false; + c.ShowMessage = false return c.SendJSON(ReturnType{ ImageList: rows[0:max_len], Page: page, diff --git a/logic/models/classes/main.go b/logic/models/classes/main.go index 22ac648..a13485a 100644 --- a/logic/models/classes/main.go +++ b/logic/models/classes/main.go @@ -4,7 +4,8 @@ import ( "database/sql" "errors" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) type ModelClass struct { @@ -15,7 +16,7 @@ type ModelClass struct { } func ListClasses(c *Context, model_id string) (cls []*ModelClass, err error) { - return GetDbMultitple[ModelClass](c, "model_classes where model_id=$1", model_id) + return GetDbMultitple[ModelClass](c, "model_classes where model_id=$1", model_id) } func ModelHasDataPoints(db *sql.DB, model_id string) (result bool, err error) { diff --git a/logic/models/data.go b/logic/models/data.go index b5e5dd4..20f88ab 100644 --- a/logic/models/data.go +++ b/logic/models/data.go @@ -12,6 +12,7 @@ import ( "sort" "strings" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" diff --git a/logic/models/delete.go b/logic/models/delete.go index 7d6e346..dd54d65 100644 --- a/logic/models/delete.go +++ b/logic/models/delete.go @@ -5,9 +5,9 @@ import ( "os" "path" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" - utils "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) func deleteModelJSON(c *Context, id string) *Error { @@ -47,7 +47,7 @@ func handleDelete(handle *Handle) { Status int } - err := utils.GetDBOnce(c, &model, "models where id=$1 and user_id=$2;", dat.Id, c.User.Id) + err := GetDBOnce(c, &model, "models where id=$1 and user_id=$2;", dat.Id, c.User.Id) if err == NotFoundError { return c.SendJSONStatus(http.StatusNotFound, "Model not found!") } else if err != nil { diff --git a/logic/models/edit.go b/logic/models/edit.go index 3f83567..f07f1a3 100644 --- a/logic/models/edit.go +++ b/logic/models/edit.go @@ -3,9 +3,9 @@ package models import ( "fmt" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" - "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) @@ -15,10 +15,10 @@ func handleEdit(handle *Handle) { return nil } - model, err_ := c.GetModelFromId("id") - if err_ != nil { - return err_ - } + model, err_ := c.GetModelFromId("id") + if err_ != nil { + return err_ + } wrong_number, err := model_classes.GetNumberOfWrongDataPoints(c.Db, model.Id) if err != nil { @@ -41,7 +41,7 @@ func handleEdit(handle *Handle) { NumberOfInvalidImages int `json:"number_of_invalid_images"` } - c.ShowMessage = false; + c.ShowMessage = false return c.SendJSON(ReturnType{ Classes: cls, HasData: has_data, @@ -54,10 +54,10 @@ func handleEdit(handle *Handle) { return nil } - model, err_ := c.GetModelFromId("id") - if err_ != nil { - return err_ - } + model, err_ := c.GetModelFromId("id") + if err_ != nil { + return err_ + } type defrow struct { Id string @@ -180,8 +180,8 @@ func handleEdit(handle *Handle) { Layers: lay, } } - - c.ShowMessage = false; + + c.ShowMessage = false return c.SendJSON(defsToReturn) }) @@ -207,14 +207,14 @@ func handleEdit(handle *Handle) { } var model rowmodel = rowmodel{} - err = utils.GetDBOnce(c, &model, "models where id=$1 and user_id=$2", id, c.User.Id) + err = GetDBOnce(c, &model, "models where id=$1 and user_id=$2", id, c.User.Id) if err == NotFoundError { return c.SendJSONStatus(404, "Model not found") } else if err != nil { return c.Error500(err) } - - c.ShowMessage = false + + c.ShowMessage = false return c.SendJSON(model) }) } diff --git a/logic/models/list.go b/logic/models/list.go index 2a5c4c9..ddafc1c 100644 --- a/logic/models/list.go +++ b/logic/models/list.go @@ -1,32 +1,32 @@ package models import ( - "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) // Auth level set when path is definied as 1 func handleStats(c *Context) *Error { - var b struct { - Id string `json:"id" validate:"required"` - } + var b struct { + Id string `json:"id" validate:"required"` + } - if _err := c.ToJSON(&b); _err != nil { - return _err; - } + if _err := c.ToJSON(&b); _err != nil { + return _err + } - type Row struct { - Name string `db:"mc.name" json:"name"` - Training string `db:"count(mdp.id) filter (where mdp.model_mode=1)" json:"training"` - Testing string `db:"count(mdp.id) filter (where mdp.model_mode=2)" json:"testing"` - } + type Row struct { + Name string `db:"mc.name" json:"name"` + Training string `db:"count(mdp.id) filter (where mdp.model_mode=1)" json:"training"` + Testing string `db:"count(mdp.id) filter (where mdp.model_mode=2)" json:"testing"` + } - rows, err := GetDbMultitple[Row](c, "model_data_point as mdp inner join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 group by mc.name order by mc.name asc;", b.Id) - if err != nil { - return c.Error500(err) - } - - c.ShowMessage = false + rows, err := GetDbMultitple[Row](c, "model_data_point as mdp inner join model_classes as mc on mc.id=mdp.class_id where mc.model_id=$1 group by mc.name order by mc.name asc;", b.Id) + if err != nil { + return c.Error500(err) + } + + c.ShowMessage = false return c.SendJSON(rows) } @@ -41,12 +41,12 @@ func handleList(handle *Handle) { Id string `json:"id"` } - got, err := utils.GetDbMultitple[Row](c, "models where user_id=$1", c.User.Id) + got, err := GetDbMultitple[Row](c, "models where user_id=$1", c.User.Id) if err != nil { return c.Error500(nil) } - - c.ShowMessage = true + + c.ShowMessage = true return c.SendJSON(got) }) diff --git a/logic/models/run.go b/logic/models/run.go index 7a064c9..1d72bac 100644 --- a/logic/models/run.go +++ b/logic/models/run.go @@ -5,9 +5,9 @@ import ( "os" "path" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" tf "github.com/galeone/tensorflow/tensorflow/go" "github.com/galeone/tensorflow/tensorflow/go/op" @@ -119,7 +119,7 @@ func runModelExp(base BasePack, model *BaseModel, def_id string, inputImage *tf. } func ClassifyTask(base BasePack, task Task) (err error) { - task.UpdateStatusLog(base, TASK_RUNNING, "Runner running task") + task.UpdateStatusLog(base, TASK_RUNNING, "Runner running task") model, err := GetBaseModel(base.GetDb(), task.ModelId) if err != nil { @@ -206,13 +206,13 @@ func ClassifyTask(base BasePack, task Task) (err error) { Confidence: confidence, } - err = task.SetResult(base, returnValue) - if err != nil { + err = task.SetResult(base, returnValue) + if err != nil { task.UpdateStatusLog(base, TASK_FAILED_RUNNING, "Failed to save model results") return - } + } task.UpdateStatusLog(base, TASK_DONE, "Model ran successfully") - return + return } diff --git a/logic/models/train/train.go b/logic/models/train/train.go index 0f1cde1..72370ba 100644 --- a/logic/models/train/train.go +++ b/logic/models/train/train.go @@ -14,6 +14,7 @@ import ( "strings" "text/template" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" model_classes "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/classes" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" @@ -336,7 +337,7 @@ func generateCvsExpandExp(c *Context, run_path string, model_id string, offset i // This is to load some extra data so that the model has more things to train on // - data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count * 10) + data_other, err := c.Db.Query("select mdp.id, mc.class_order, mdp.file_path from model_data_point as mdp inner join model_classes as mc on mc.id = mdp.class_id where mc.model_id = $1 and mdp.model_mode=$2 and mc.status=$3 limit $4;", model_id, model_classes.DATA_POINT_MODE_TRAINING, MODEL_CLASS_STATUS_TRAINED, count*10) if err != nil { return } @@ -1905,7 +1906,7 @@ func handleTrain(handle *Handle) { return c.Error500(err) } - c.ShowMessage = false + c.ShowMessage = false return nil }) @@ -1955,7 +1956,7 @@ func handleTrain(handle *Handle) { return c.Error500(err) } - c.ShowMessage = false + c.ShowMessage = false return nil }) } diff --git a/logic/tasks/handleUpload.go b/logic/tasks/handleUpload.go index c80352d..ead9947 100644 --- a/logic/tasks/handleUpload.go +++ b/logic/tasks/handleUpload.go @@ -7,6 +7,7 @@ import ( "os" "path" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" @@ -46,7 +47,7 @@ func handleUpload(handler *Handle) { var requestData struct { ModelId string `json:"id" validate:"required"` } - + _err := c.ParseJson(&requestData, json_data) if _err != nil { return _err @@ -115,7 +116,7 @@ func handleUpload(handler *Handle) { Id string `json:"task_id"` } { "Provided image does not match the model", id}) } - + UpdateStatus(c, "tasks", id, 1) return c.SendJSON(struct {Id string `json:"id"`}{id}) diff --git a/logic/tasks/list.go b/logic/tasks/list.go index 0343ae4..3af8dba 100644 --- a/logic/tasks/list.go +++ b/logic/tasks/list.go @@ -1,6 +1,7 @@ package tasks import ( + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" diff --git a/logic/tasks/runner/runner.go b/logic/tasks/runner/runner.go index 521d940..24d5b8f 100644 --- a/logic/tasks/runner/runner.go +++ b/logic/tasks/runner/runner.go @@ -8,9 +8,10 @@ import ( "github.com/charmbracelet/log" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/tasks/utils" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) /** @@ -36,13 +37,13 @@ func runner(db *sql.DB, task_channel chan Task, index int, back_channel chan int var err error base := BasePackStruct{ - Db: db, + Db: db, Logger: logger, } for task := range task_channel { logger.Info("Got task", "task", task) - + if task.TaskType == int(TASK_TYPE_CLASSIFICATION) { logger.Info("Classification Task") if err = ClassifyTask(base, task); err != nil { @@ -80,7 +81,7 @@ func attentionSeeker(config Config, back_channel chan int) { for true { back_channel <- 0 - + time.Sleep(t) } } diff --git a/logic/tasks/utils/utils.go b/logic/tasks/utils/utils.go index c9b44c3..2c05ca3 100644 --- a/logic/tasks/utils/utils.go +++ b/logic/tasks/utils/utils.go @@ -3,7 +3,7 @@ package tasks_utils import ( "time" - . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" "github.com/goccy/go-json" ) diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 09b1c69..17c0675 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -204,6 +204,10 @@ func (c Context) Query(query string, args ...any) (*sql.Rows, error) { return c.Db.Query(query, args...) } +func (c Context) Exec(query string, args ...any) (sql.Result, error) { + return c.Db.Exec(query, args...) +} + func (c Context) Prepare(str string) (*sql.Stmt, error) { if c.Tx == nil { return c.Db.Prepare(str) diff --git a/logic/utils/utils.go b/logic/utils/utils.go index 7ced1b7..58d0a20 100644 --- a/logic/utils/utils.go +++ b/logic/utils/utils.go @@ -1,83 +1,11 @@ package utils import ( - "database/sql" "errors" - "fmt" - "io" - "mime" - "net/http" - "net/url" - "reflect" - "strconv" - "strings" - "github.com/charmbracelet/log" - "github.com/google/uuid" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" ) - -type BasePack interface { - GetDb() *sql.DB - GetLogger() *log.Logger -} - -type BasePackStruct struct { - Db *sql.DB - Logger *log.Logger -} - -func (b BasePackStruct) GetDb() (*sql.DB) { - return b.Db -} - -func (b BasePackStruct) GetLogger() (*log.Logger) { - return b.Logger -} - -func CheckEmpty(f url.Values, path string) bool { - return !f.Has(path) || f.Get(path) == "" -} - -func CheckNumber(f url.Values, path string, number *int) bool { - if CheckEmpty(f, path) { - fmt.Println("here", path) - fmt.Println(f.Get(path)) - return false - } - n, err := strconv.Atoi(f.Get(path)) - if err != nil { - fmt.Println(err) - return false - } - *number = n - return true -} - -func CheckFloat64(f url.Values, path string, number *float64) bool { - if CheckEmpty(f, path) { - fmt.Println("here", path) - fmt.Println(f.Get(path)) - return false - } - n, err := strconv.ParseFloat(f.Get(path), 64) - if err != nil { - fmt.Println(err) - return false - } - *number = n - return true -} - -func CheckId(f url.Values, path string) bool { - return !CheckEmpty(f, path) && IsValidUUID(f.Get(path)) -} - -func IsValidUUID(u string) bool { - _, err := uuid.Parse(u) - return err == nil -} - func GetIdFromUrl(c *Context, target string) (string, error) { if !c.R.URL.Query().Has(target) { return "", errors.New("Query does not have " + target) @@ -94,269 +22,3 @@ func GetIdFromUrl(c *Context, target string) (string, error) { return id, nil } - -type maxBytesReader struct { - w http.ResponseWriter - r io.ReadCloser // underlying reader - i int64 // max bytes initially, for MaxBytesError - n int64 // max bytes remaining - err error // sticky error -} - -type MaxBytesError struct { - Limit int64 -} - -func (e *MaxBytesError) Error() string { - // Due to Hyrum's law, this text cannot be changed. - return "http: request body too large" -} - -func (l *maxBytesReader) Read(p []byte) (n int, err error) { - if l.err != nil { - return 0, l.err - } - if len(p) == 0 { - return 0, nil - } - // If they asked for a 32KB byte read but only 5 bytes are - // remaining, no need to read 32KB. 6 bytes will answer the - // question of the whether we hit the limit or go past it. - // 0 < len(p) < 2^63 - if int64(len(p))-1 > l.n { - p = p[:l.n+1] - } - n, err = l.r.Read(p) - - if int64(n) <= l.n { - l.n -= int64(n) - l.err = err - return n, err - } - - n = int(l.n) - l.n = 0 - - // The server code and client code both use - // maxBytesReader. This "requestTooLarge" check is - // only used by the server code. To prevent binaries - // which only using the HTTP Client code (such as - // cmd/go) from also linking in the HTTP server, don't - // use a static type assertion to the server - // "*response" type. Check this interface instead: - type requestTooLarger interface { - requestTooLarge() - } - if res, ok := l.w.(requestTooLarger); ok { - res.requestTooLarge() - } - l.err = &MaxBytesError{l.i} - return n, l.err -} - -func (l *maxBytesReader) Close() error { - return l.r.Close() -} - -func MyParseForm(r *http.Request) (vs url.Values, err error) { - if r.Body == nil { - err = errors.New("missing form body") - return - } - ct := r.Header.Get("Content-Type") - // RFC 7231, section 3.1.1.5 - empty type - // MAY be treated as application/octet-stream - if ct == "" { - ct = "application/octet-stream" - } - ct, _, err = mime.ParseMediaType(ct) - switch { - case ct == "application/x-www-form-urlencoded": - var reader io.Reader = r.Body - maxFormSize := int64(1<<63 - 1) - if _, ok := r.Body.(*maxBytesReader); !ok { - maxFormSize = int64(10 << 20) // 10 MB is a lot of text. - reader = io.LimitReader(r.Body, maxFormSize+1) - } - b, e := io.ReadAll(reader) - if e != nil { - if err == nil { - err = e - } - break - } - if int64(len(b)) > maxFormSize { - err = errors.New("http: POST too large") - return - } - vs, e = url.ParseQuery(string(b)) - if err == nil { - err = e - } - case ct == "multipart/form-data": - // handled by ParseMultipartForm (which is calling us, or should be) - // TODO(bradfitz): there are too many possible - // orders to call too many functions here. - // Clean this up and write more tests. - // request_test.go contains the start of this, - // in TestParseMultipartFormOrder and others. - } - return -} - -type JustId struct{ Id string `json:"id" validate:"required"` } - -type Generic struct{ reflect.Type } - -var NotFoundError = errors.New("Not found") -var CouldNotInsert = errors.New("Could not insert") - -func generateQuery(t reflect.Type) (query string, nargs int) { - nargs = t.NumField() - query = "" - - for i := 0; i < nargs; i += 1 { - field := t.Field(i) - name, ok := field.Tag.Lookup("db") - if !ok { - name = field.Name - } - - if name == "__nil__" { - continue - } - query += strings.ToLower(name) + "," - } - - // Remove the last comma - query = query[0 : len(query)-1] - - return -} - -type QueryInterface interface { - Prepare(str string) (*sql.Stmt, error) - Query(query string, args ...any) (*sql.Rows, error) -} - -func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) { - t := reflect.TypeFor[T]() - - query, nargs := generateQuery(t) - - db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename)) - if err != nil { - return nil, err - } - defer db_query.Close() - - rows, err := db_query.Query(args...) - if err != nil { - return nil, err - } - defer rows.Close() - - list := []*T{} - - for rows.Next() { - item := new(T) - if err = mapRow(item, rows, nargs); err != nil { - return nil, err - } - list = append(list, item) - } - - return list, nil -} - -func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) { - err = nil - - val := reflect.Indirect(reflect.ValueOf(store)) - scan_args := make([]interface{}, nargs) - for i := 0; i < nargs; i++ { - valueField := val.Field(i) - scan_args[i] = valueField.Addr().Interface() - } - - err = rows.Scan(scan_args...) - if err != nil { - return - } - - return nil -} - -func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) { - t := reflect.TypeOf(store).Elem() - - query, nargs := generateQuery(t) - - query2 := "" - for i := 0; i < nargs; i += 1 { - query2 += fmt.Sprintf("$%d,", i+1) - } - // Remove last quotation - query2 = query2[0 : len(query2)-1] - - val := reflect.ValueOf(store).Elem() - scan_args := make([]interface{}, nargs) - for i := 0; i < nargs; i++ { - valueField := val.Field(i) - scan_args[i] = valueField.Interface() - } - - rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...) - if err != nil { - return - } - defer rows.Close() - - if !rows.Next() { - return "", CouldNotInsert - } - - err = rows.Scan(&id) - if err != nil { - return - } - - return -} - -func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error { - t := reflect.TypeOf(store).Elem() - - query, nargs := generateQuery(t) - - rows, err := db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...) - if err != nil { - return err - } - defer rows.Close() - - if !rows.Next() { - return NotFoundError - } - - err = nil - - val := reflect.ValueOf(store).Elem() - scan_args := make([]interface{}, nargs) - for i := 0; i < nargs; i++ { - valueField := val.Field(i) - scan_args[i] = valueField.Addr().Interface() - } - - err = rows.Scan(scan_args...) - if err != nil { - return err - } - - return nil -} - -func UpdateStatus(c *Context, table string, id string, status int) (err error) { - _, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id) - return -} diff --git a/users.go b/users.go index 4feef5a..21dc06f 100644 --- a/users.go +++ b/users.go @@ -10,8 +10,8 @@ import ( "golang.org/x/crypto/bcrypt" + . "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" - "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/utils" ) @@ -226,7 +226,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { Id string } - err := utils.GetDBOnce(c, &data, "users where id=$1", dat.Id) + err := GetDBOnce(c, &data, "users where id=$1", dat.Id) if err == NotFoundError { return c.JsonBadRequest("User does not exist") } else if err != nil { @@ -235,7 +235,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { } var data JustId - err := utils.GetDBOnce(c, &data, "users where email=$1", dat.Email) + err := GetDBOnce(c, &data, "users where email=$1", dat.Email) if err != nil && err != NotFoundError { return c.E500M("Falied to get data for user", err) } @@ -260,7 +260,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { User_Type int } - err = utils.GetDBOnce(c, &user, "users where id=$1", dat.Id) + err = GetDBOnce(c, &user, "users where id=$1", dat.Id) if err != nil { return c.E500M("Failed to get user data", err) } diff --git a/webpage/src/lib/requests.svelte.ts b/webpage/src/lib/requests.svelte.ts index 831261b..2f98462 100644 --- a/webpage/src/lib/requests.svelte.ts +++ b/webpage/src/lib/requests.svelte.ts @@ -1,101 +1,101 @@ import { goto } from '$app/navigation'; import { userStore } from 'routes/UserStore.svelte'; -const API = "/api"; +const API = '/api'; export async function get(url: string) { const headers = new Headers(); //headers.append('content-type', 'application/json'); headers.append('response-type', 'application/json'); - if (userStore.user) { - headers.append('token', userStore.user.token); - } + if (userStore.user) { + headers.append('token', userStore.user.token); + } - let r = await fetch(`${API}/${url}`, { + const r = await fetch(`${API}/${url}`, { method: 'GET', - headers: headers, + headers: headers }); - if (r.status === 401) { - userStore.user = undefined; - goto("/login") - } else if (r.status !== 200) { - throw r; - } + if (r.status === 401) { + userStore.user = undefined; + goto('/login'); + } else if (r.status !== 200) { + throw r; + } - return r.json(); + return r.json(); } export async function post(url: string, body: any) { const headers = new Headers(); headers.append('content-type', 'application/json'); - if (userStore.user) { - headers.append('token', userStore.user.token); - } + if (userStore.user) { + headers.append('token', userStore.user.token); + } - let r = await fetch(`${API}/${url}`, { + const r = await fetch(`${API}/${url}`, { method: 'POST', headers: headers, - body: JSON.stringify(body), + body: JSON.stringify(body) }); - - if (r.status === 401) { - userStore.user = undefined; - goto("/login") - throw r; - } else if (r.status !== 200) { - throw r; - } - return r.json(); + if (r.status === 401) { + userStore.user = undefined; + goto('/login'); + throw r; + } else if (r.status !== 200) { + throw r; + } + + return r.json(); } export async function rdelete(url: string, body: any) { const headers = new Headers(); headers.append('content-type', 'application/json'); - if (userStore.user) { - headers.append('token', userStore.user.token); - } + if (userStore.user) { + headers.append('token', userStore.user.token); + } - let r = await fetch(`${API}/${url}`, { + const r = await fetch(`${API}/${url}`, { method: 'DELETE', headers: headers, - body: JSON.stringify(body), + body: JSON.stringify(body) }); - if (r.status === 401) { - userStore.user = undefined; - goto("/login") - } else if (r.status !== 200) { - throw r; - } + if (r.status === 401) { + userStore.user = undefined; + goto('/login'); + } else if (r.status !== 200) { + throw r; + } - return r.json(); + return r.json(); } export async function postFormData(url: string, body: FormData) { const headers = new Headers(); //headers.append('content-type', 'multipart/form-data'); headers.append('response-type', 'application/json'); - if (userStore.user) { - headers.append('token', userStore.user.token); - } + if (userStore.user) { + headers.append('token', userStore.user.token); + } - let r = await fetch(`${API}/${url}`, { + const r = await fetch(`${API}/${url}`, { method: 'POST', headers: headers, - body: body, + body: body }); - if (r.status == 401) { - userStore.user = undefined; - goto('/login'); - throw new Error("Redirect"); - } + if (r.status == 401) { + userStore.user = undefined; + goto('/login'); + throw new Error('Redirect'); + } - if (r.status !== 200) { - throw r; - } + if (r.status !== 200) { + throw r; + } - return r.json(); + return r.json(); }