363 lines
7.6 KiB
Go
363 lines
7.6 KiB
Go
package utils
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"net/url"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/charmbracelet/log"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
|
|
type BasePack interface {
|
|
GetDb() *sql.DB
|
|
GetLogger() *log.Logger
|
|
}
|
|
|
|
type BasePackStruct struct {
|
|
Db *sql.DB
|
|
Logger *log.Logger
|
|
}
|
|
|
|
func (b BasePackStruct) GetDb() (*sql.DB) {
|
|
return b.Db
|
|
}
|
|
|
|
func (b BasePackStruct) GetLogger() (*log.Logger) {
|
|
return b.Logger
|
|
}
|
|
|
|
// CheckEmpty reports whether form field path is absent, or present with an
// empty first value.
func CheckEmpty(f url.Values, path string) bool {
	if !f.Has(path) {
		return true
	}
	return f.Get(path) == ""
}
|
|
|
|
func CheckNumber(f url.Values, path string, number *int) bool {
|
|
if CheckEmpty(f, path) {
|
|
fmt.Println("here", path)
|
|
fmt.Println(f.Get(path))
|
|
return false
|
|
}
|
|
n, err := strconv.Atoi(f.Get(path))
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return false
|
|
}
|
|
*number = n
|
|
return true
|
|
}
|
|
|
|
func CheckFloat64(f url.Values, path string, number *float64) bool {
|
|
if CheckEmpty(f, path) {
|
|
fmt.Println("here", path)
|
|
fmt.Println(f.Get(path))
|
|
return false
|
|
}
|
|
n, err := strconv.ParseFloat(f.Get(path), 64)
|
|
if err != nil {
|
|
fmt.Println(err)
|
|
return false
|
|
}
|
|
*number = n
|
|
return true
|
|
}
|
|
|
|
func CheckId(f url.Values, path string) bool {
|
|
return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
|
|
}
|
|
|
|
func IsValidUUID(u string) bool {
|
|
_, err := uuid.Parse(u)
|
|
return err == nil
|
|
}
|
|
|
|
func GetIdFromUrl(c *Context, target string) (string, error) {
|
|
if !c.R.URL.Query().Has(target) {
|
|
return "", errors.New("Query does not have " + target)
|
|
}
|
|
|
|
id := c.R.URL.Query().Get("id")
|
|
if len(id) == 0 {
|
|
return "", errors.New("Query is empty for " + target)
|
|
}
|
|
|
|
if !IsValidUUID(id) {
|
|
return "", errors.New("Value of query is not a valid uuid for " + target)
|
|
}
|
|
|
|
return id, nil
|
|
}
|
|
|
|
// maxBytesReader wraps a body reader and caps how many bytes may be read from
// it. It mirrors net/http's unexported reader of the same name.
type maxBytesReader struct {
	w   http.ResponseWriter // probed for a requestTooLarge hook when the cap is exceeded
	r   io.ReadCloser       // wrapped reader
	i   int64               // initial cap, reported as MaxBytesError.Limit
	n   int64               // bytes still allowed
	err error               // sticky error returned by every later Read
}
|
|
|
|
// MaxBytesError is the error maxBytesReader returns once the wrapped body
// turns out to be longer than Limit bytes.
type MaxBytesError struct {
	Limit int64 // the byte cap that was exceeded
}
|
|
|
|
func (e *MaxBytesError) Error() string {
|
|
// Due to Hyrum's law, this text cannot be changed.
|
|
return "http: request body too large"
|
|
}
|
|
|
|
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
|
|
if l.err != nil {
|
|
return 0, l.err
|
|
}
|
|
if len(p) == 0 {
|
|
return 0, nil
|
|
}
|
|
// If they asked for a 32KB byte read but only 5 bytes are
|
|
// remaining, no need to read 32KB. 6 bytes will answer the
|
|
// question of the whether we hit the limit or go past it.
|
|
// 0 < len(p) < 2^63
|
|
if int64(len(p))-1 > l.n {
|
|
p = p[:l.n+1]
|
|
}
|
|
n, err = l.r.Read(p)
|
|
|
|
if int64(n) <= l.n {
|
|
l.n -= int64(n)
|
|
l.err = err
|
|
return n, err
|
|
}
|
|
|
|
n = int(l.n)
|
|
l.n = 0
|
|
|
|
// The server code and client code both use
|
|
// maxBytesReader. This "requestTooLarge" check is
|
|
// only used by the server code. To prevent binaries
|
|
// which only using the HTTP Client code (such as
|
|
// cmd/go) from also linking in the HTTP server, don't
|
|
// use a static type assertion to the server
|
|
// "*response" type. Check this interface instead:
|
|
type requestTooLarger interface {
|
|
requestTooLarge()
|
|
}
|
|
if res, ok := l.w.(requestTooLarger); ok {
|
|
res.requestTooLarge()
|
|
}
|
|
l.err = &MaxBytesError{l.i}
|
|
return n, l.err
|
|
}
|
|
|
|
func (l *maxBytesReader) Close() error {
|
|
return l.r.Close()
|
|
}
|
|
|
|
func MyParseForm(r *http.Request) (vs url.Values, err error) {
|
|
if r.Body == nil {
|
|
err = errors.New("missing form body")
|
|
return
|
|
}
|
|
ct := r.Header.Get("Content-Type")
|
|
// RFC 7231, section 3.1.1.5 - empty type
|
|
// MAY be treated as application/octet-stream
|
|
if ct == "" {
|
|
ct = "application/octet-stream"
|
|
}
|
|
ct, _, err = mime.ParseMediaType(ct)
|
|
switch {
|
|
case ct == "application/x-www-form-urlencoded":
|
|
var reader io.Reader = r.Body
|
|
maxFormSize := int64(1<<63 - 1)
|
|
if _, ok := r.Body.(*maxBytesReader); !ok {
|
|
maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
|
|
reader = io.LimitReader(r.Body, maxFormSize+1)
|
|
}
|
|
b, e := io.ReadAll(reader)
|
|
if e != nil {
|
|
if err == nil {
|
|
err = e
|
|
}
|
|
break
|
|
}
|
|
if int64(len(b)) > maxFormSize {
|
|
err = errors.New("http: POST too large")
|
|
return
|
|
}
|
|
vs, e = url.ParseQuery(string(b))
|
|
if err == nil {
|
|
err = e
|
|
}
|
|
case ct == "multipart/form-data":
|
|
// handled by ParseMultipartForm (which is calling us, or should be)
|
|
// TODO(bradfitz): there are too many possible
|
|
// orders to call too many functions here.
|
|
// Clean this up and write more tests.
|
|
// request_test.go contains the start of this,
|
|
// in TestParseMultipartFormOrder and others.
|
|
}
|
|
return
|
|
}
|
|
|
|
// JustId is a one-field row type. With this package's reflection helpers its
// single untagged field Id maps to the column "id" (field name lower-cased).
type JustId struct {
	Id string
}
|
|
|
|
// Generic wraps a reflect.Type. Embedding promotes every reflect.Type method
// (Kind, Name, NumField, ...) onto Generic itself.
type Generic struct {
	reflect.Type
}
|
|
|
|
// Sentinel errors returned by the database helpers in this package; compare
// them with errors.Is.
var (
	// NotFoundError is returned by GetDBOnce when the query yields no row.
	NotFoundError = errors.New("Not found")
	// CouldNotInsert is returned by InsertReturnId when the insert yields no row.
	CouldNotInsert = errors.New("Could not insert")
)
|
|
|
|
// generateQuery builds the comma-separated column list for struct type t.
//
// Each field contributes its `db` struct tag, or its Go field name when the
// tag is absent, lower-cased. A field whose resolved name is exactly "__nil__"
// is left out of the list.
//
// nargs is always t.NumField(). It counts skipped "__nil__" fields too, so it
// can exceed the number of columns in query. When no field contributes a
// column, query is "" (previously this case panicked while trimming the
// trailing comma).
func generateQuery(t reflect.Type) (query string, nargs int) {
	nargs = t.NumField()

	columns := make([]string, 0, nargs)
	for i := 0; i < nargs; i++ {
		field := t.Field(i)
		name, ok := field.Tag.Lookup("db")
		if !ok {
			name = field.Name
		}
		if name == "__nil__" {
			continue
		}
		columns = append(columns, strings.ToLower(name))
	}

	return strings.Join(columns, ","), nargs
}
|
|
|
|
// QueryInterface is the subset of query methods the helpers in this package
// need. Both *sql.DB and *sql.Tx satisfy it, so the helpers can run inside or
// outside a transaction.
type QueryInterface interface {
	// Query runs a statement that returns rows.
	Query(query string, args ...any) (*sql.Rows, error)
	// Prepare creates a prepared statement for later use.
	Prepare(str string) (*sql.Stmt, error)
}
|
|
|
|
func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) {
|
|
t := reflect.TypeFor[T]()
|
|
|
|
query, nargs := generateQuery(t)
|
|
|
|
db_query, err := c.Prepare(fmt.Sprintf("select %s from %s", query, tablename))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer db_query.Close()
|
|
|
|
rows, err := db_query.Query(args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
list := []*T{}
|
|
|
|
for rows.Next() {
|
|
item := new(T)
|
|
if err = mapRow(item, rows, nargs); err != nil {
|
|
return nil, err
|
|
}
|
|
list = append(list, item)
|
|
}
|
|
|
|
return list, nil
|
|
}
|
|
|
|
// mapRow scans the current row of rows into the struct that store points to.
//
// It walks the first nargs fields of the struct and uses each field's address
// as a Scan destination, in field order. GetDbMultitple passes the nargs
// returned by generateQuery, which is the struct's total field count.
//
// Fields whose column name (the `db` tag, or the field name when untagged) is
// exactly "__nil__" are skipped. generateQuery leaves those out of the SELECT
// list; without the skip, Scan would get more destinations than columns and
// fail.
func mapRow(store interface{}, rows *sql.Rows, nargs int) error {
	val := reflect.Indirect(reflect.ValueOf(store))
	typ := val.Type()

	dest := make([]any, 0, nargs)
	for i := 0; i < nargs; i++ {
		field := typ.Field(i)
		name, ok := field.Tag.Lookup("db")
		if !ok {
			name = field.Name
		}
		if name == "__nil__" {
			continue
		}
		dest = append(dest, val.Field(i).Addr().Interface())
	}

	return rows.Scan(dest...)
}
|
|
|
|
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
|
|
t := reflect.TypeOf(store).Elem()
|
|
|
|
query, nargs := generateQuery(t)
|
|
|
|
query2 := ""
|
|
for i := 0; i < nargs; i += 1 {
|
|
query2 += fmt.Sprintf("$%d,", i+1)
|
|
}
|
|
// Remove last quotation
|
|
query2 = query2[0 : len(query2)-1]
|
|
|
|
val := reflect.ValueOf(store).Elem()
|
|
scan_args := make([]interface{}, nargs)
|
|
for i := 0; i < nargs; i++ {
|
|
valueField := val.Field(i)
|
|
scan_args[i] = valueField.Interface()
|
|
}
|
|
|
|
rows, err := c.Db.Query(fmt.Sprintf("insert into %s (%s) values (%s) returning %s", tablename, query, query2, returnName), scan_args...)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
return "", CouldNotInsert
|
|
}
|
|
|
|
err = rows.Scan(&id)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error {
|
|
t := reflect.TypeOf(store).Elem()
|
|
|
|
query, nargs := generateQuery(t)
|
|
|
|
rows, err := db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
return NotFoundError
|
|
}
|
|
|
|
err = nil
|
|
|
|
val := reflect.ValueOf(store).Elem()
|
|
scan_args := make([]interface{}, nargs)
|
|
for i := 0; i < nargs; i++ {
|
|
valueField := val.Field(i)
|
|
scan_args[i] = valueField.Addr().Interface()
|
|
}
|
|
|
|
err = rows.Scan(scan_args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func UpdateStatus(c *Context, table string, id string, status int) (err error) {
|
|
_, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
|
|
return
|
|
}
|