fyp/logic/utils/utils.go

283 lines
6.0 KiB
Go
Raw Normal View History

package utils
import (
"database/sql"
"errors"
"fmt"
"io"
"mime"
"net/http"
"net/url"
2024-02-08 18:20:58 +00:00
"reflect"
"strconv"
2024-02-08 18:20:58 +00:00
"strings"
"github.com/google/uuid"
)
// CheckEmpty reports whether form field path is missing from f, or is present
// but its first value is the empty string.
func CheckEmpty(f url.Values, path string) bool {
	if _, present := f[path]; !present {
		return true
	}
	return f.Get(path) == ""
}
// CheckNumber parses the first value of form field path as a base-10 int.
// On success it stores the result in *number and returns true. It returns
// false, leaving *number untouched, when the field is missing, empty, or not a
// valid integer according to strconv.Atoi (e.g. "12a" or an out-of-range value).
func CheckNumber(f url.Values, path string, number *int) bool {
	raw := f.Get(path)
	if raw == "" {
		// Same condition as CheckEmpty: Get yields "" both for an absent key
		// and for an empty first value.
		return false
	}
	n, err := strconv.Atoi(raw)
	if err != nil {
		// Invalid input is reported to the caller via the false result; no
		// debug output is written to stdout.
		return false
	}
	*number = n
	return true
}
2023-10-21 00:26:52 +01:00
// CheckFloat64 parses the first value of form field path as a float64.
// On success it stores the result in *number and returns true. It returns
// false, leaving *number untouched, when the field is missing, empty, or not
// accepted by strconv.ParseFloat(…, 64).
func CheckFloat64(f url.Values, path string, number *float64) bool {
	raw := f.Get(path)
	if raw == "" {
		// Same condition as CheckEmpty: Get yields "" both for an absent key
		// and for an empty first value.
		return false
	}
	n, err := strconv.ParseFloat(raw, 64)
	if err != nil {
		// Invalid input is reported to the caller via the false result; no
		// debug output is written to stdout.
		return false
	}
	*number = n
	return true
}
func CheckId(f url.Values, path string) bool {
2024-02-08 18:20:58 +00:00
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
}
// IsValidUUID reports whether u is accepted by uuid.Parse.
func IsValidUUID(u string) bool {
	if _, parseErr := uuid.Parse(u); parseErr != nil {
		return false
	}
	return true
}
func GetIdFromUrl(r *http.Request, target string) (string, error) {
if !r.URL.Query().Has(target) {
2024-02-08 18:20:58 +00:00
return "", errors.New("Query does not have " + target)
}
id := r.URL.Query().Get("id")
if len(id) == 0 {
2024-02-08 18:20:58 +00:00
return "", errors.New("Query is empty for " + target)
}
if !IsValidUUID(id) {
2024-02-08 18:20:58 +00:00
return "", errors.New("Value of query is not a valid uuid for " + target)
}
2024-02-08 18:20:58 +00:00
return id, nil
}
// maxBytesReader wraps a reader and limits how many bytes may be read through
// it. Once a read goes past the limit, Read returns a *MaxBytesError carrying
// the initial limit, and that error is then returned on every later call.
//
// NOTE(review): no constructor for this type is visible in this file. Request
// bodies wrapped by net/http's http.MaxBytesReader have net/http's own
// unexported type, not this one — confirm where (if anywhere) this is built.
type maxBytesReader struct {
	w   http.ResponseWriter
	r   io.ReadCloser // underlying reader
	i   int64         // max bytes initially, for MaxBytesError
	n   int64         // max bytes remaining
	err error         // sticky error
}
// MaxBytesError is the error maxBytesReader.Read returns once the body-size
// limit has been exceeded. Limit holds that limit in bytes.
type MaxBytesError struct {
	Limit int64
}

// Error implements the error interface. The text matches net/http's
// MaxBytesError message exactly, and must stay unchanged because callers may
// compare against the string.
func (m *MaxBytesError) Error() string {
	const message = "http: request body too large"
	return message
}
// Read implements io.Reader. It forwards to the wrapped reader but lets at most
// l.n more bytes through.
//   - When a read would go past that budget, it returns only the bytes that fit,
//     together with a *MaxBytesError carrying the initial limit l.i.
//   - Any error it records in l.err (including that one) is returned immediately
//     by every subsequent call.
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
	// A previous call recorded an error (EOF, read failure, or limit
	// exceeded): keep returning it without touching the underlying reader.
	if l.err != nil {
		return 0, l.err
	}
	if len(p) == 0 {
		return 0, nil
	}
	// If they asked for a 32KB byte read but only 5 bytes are
	// remaining, no need to read 32KB. 6 bytes will answer the
	// question of whether we hit the limit or go past it.
	// 0 < len(p) < 2^63
	if int64(len(p))-1 > l.n {
		p = p[:l.n+1]
	}
	n, err = l.r.Read(p)
	// Still within budget: charge the bytes read and pass the result through,
	// storing err (nil on a successful read) as the sticky error.
	if int64(n) <= l.n {
		l.n -= int64(n)
		l.err = err
		return n, err
	}
	// The underlying reader returned more than the remaining budget: report
	// only the bytes that fit and switch to the limit-exceeded state.
	n = int(l.n)
	l.n = 0
	// The server code and client code both use
	// maxBytesReader. This "requestTooLarge" check is
	// only used by the server code. To prevent binaries
	// which only using the HTTP Client code (such as
	// cmd/go) from also linking in the HTTP server, don't
	// use a static type assertion to the server
	// "*response" type. Check this interface instead:
	//
	// NOTE(review): this comment was copied from net/http. Here the interface
	// is declared in this package, and an unexported method name declared in
	// another package (such as net/http's requestTooLarge) can never satisfy
	// it. So for net/http's response writers this branch will not fire.
	type requestTooLarger interface {
		requestTooLarge()
	}
	if res, ok := l.w.(requestTooLarger); ok {
		res.requestTooLarge()
	}
	l.err = &MaxBytesError{l.i}
	return n, l.err
}
// Close closes the wrapped reader and returns its error unchanged.
func (l *maxBytesReader) Close() error {
	underlying := l.r
	return underlying.Close()
}
// MyParseForm reads r.Body and, when the request's media type is
// application/x-www-form-urlencoded, decodes it into url.Values.
//
// Behavior:
//   - A nil body yields a "missing form body" error.
//   - A missing Content-Type is treated as application/octet-stream.
//   - Any error from mime.ParseMediaType is returned even when the media type
//     itself was recognised.
//   - Media types other than url-encoded (multipart/form-data included) return
//     nil values and do not read the body.
//   - Unless r.Body is this package's *maxBytesReader, at most 10 MiB is
//     accepted; larger bodies yield "http: POST too large".
func MyParseForm(r *http.Request) (vs url.Values, err error) {
	if r.Body == nil {
		return nil, errors.New("missing form body")
	}

	contentType := r.Header.Get("Content-Type")
	if contentType == "" {
		// RFC 7231, section 3.1.1.5: an empty type MAY be treated as
		// application/octet-stream.
		contentType = "application/octet-stream"
	}

	// ParseMediaType may return a usable media type together with a non-nil
	// error (ErrInvalidMediaParameter), so err is kept and the type is still
	// inspected below.
	var mediaType string
	mediaType, _, err = mime.ParseMediaType(contentType)
	if mediaType != "application/x-www-form-urlencoded" {
		// multipart/form-data and every other type are not decoded here.
		return nil, err
	}

	var body io.Reader = r.Body
	limit := int64(1<<63 - 1)
	if _, limited := r.Body.(*maxBytesReader); !limited {
		limit = 10 << 20 // 10 MB is a lot of text.
		// Read one extra byte so an oversized body can be detected.
		body = io.LimitReader(r.Body, limit+1)
	}

	raw, readErr := io.ReadAll(body)
	if readErr != nil {
		if err == nil {
			err = readErr
		}
		return nil, err
	}
	if int64(len(raw)) > limit {
		return nil, errors.New("http: POST too large")
	}

	parsed, parseErr := url.ParseQuery(string(raw))
	if err == nil {
		err = parseErr
	}
	return parsed, err
}
2024-02-08 18:20:58 +00:00
type (
	// JustId is a row/struct shape carrying only an Id field.
	JustId struct {
		Id string
	}

	// Generic embeds a reflect.Type, exposing its methods directly on Generic.
	Generic struct{ reflect.Type }
)

// NotFoundError is returned by GetDBOnce when the query yields no rows.
var NotFoundError = errors.New("Not found")
func GetDbMultitple[T interface{}](c *Context, tablename string, args ...any) ([]*T, error) {
t := reflect.TypeFor[T]()
nargs := t.NumField()
2024-02-08 18:20:58 +00:00
query := ""
for i := 0; i < nargs; i += 1 {
query += strings.ToLower(t.Field(i).Name) + ","
}
// Remove the last comma
query = query[0 : len(query)-1]
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*T{}
for rows.Next() {
item := new(T)
if err = mapRow(item, rows, nargs); err != nil {
return nil, err
}
list = append(list, item)
}
return list, nil
}
2024-02-08 18:20:58 +00:00
// mapRow scans the current row of rows into the first nargs fields of the
// struct that store points to (or store itself, via reflect.Indirect),
// positionally. It returns rows.Scan's error unchanged.
func mapRow(store interface{}, rows *sql.Rows, nargs int) (err error) {
	target := reflect.Indirect(reflect.ValueOf(store))
	destinations := make([]interface{}, 0, nargs)
	for idx := 0; idx < nargs; idx++ {
		// Scan needs pointers, so hand it the address of each field.
		destinations = append(destinations, target.Field(idx).Addr().Interface())
	}
	return rows.Scan(destinations...)
}
func GetDBOnce(c *Context, store interface{}, tablename string, args ...any) error {
t := reflect.TypeOf(store).Elem()
2024-02-08 18:20:58 +00:00
nargs := t.NumField()
query := ""
for i := 0; i < nargs; i += 1 {
query += strings.ToLower(t.Field(i).Name) + ","
}
// Remove the last comma
2024-02-08 18:20:58 +00:00
query = query[0 : len(query)-1]
rows, err := c.Db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return NotFoundError
}
err = nil
2024-02-08 18:20:58 +00:00
val := reflect.ValueOf(store).Elem()
scan_args := make([]interface{}, nargs);
for i := 0; i < nargs; i++ {
valueField := val.Field(i)
scan_args[i] = valueField.Addr().Interface()
}
err = rows.Scan(scan_args...)
if err != nil {
return err
}
return nil
}