package utils

import (
	"database/sql"
	"errors"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/url"
	"reflect"
	"strconv"
	"strings"

	"github.com/charmbracelet/log"
	"github.com/google/uuid"
)

// BasePack exposes a database handle and a logger.
type BasePack interface {
	GetDb() *sql.DB
	GetLogger() *log.Logger
}

// BasePackStruct is a plain BasePack implementation that simply stores a
// database handle and a logger.
type BasePackStruct struct {
	Db     *sql.DB
	Logger *log.Logger
}

// Compile-time check that BasePackStruct satisfies BasePack.
var _ BasePack = BasePackStruct{}

// GetDb returns the stored database handle.
func (b BasePackStruct) GetDb() *sql.DB {
	return b.Db
}

// GetLogger returns the stored logger.
func (b BasePackStruct) GetLogger() *log.Logger {
	return b.Logger
}

// CheckEmpty reports whether path is absent from f or its first value is
// the empty string.
func CheckEmpty(f url.Values, path string) bool {
	return !f.Has(path) || f.Get(path) == ""
}

// parseFormValue parses the first value of path in f with parse and stores
// the result in *dst. It returns false, leaving *dst untouched, when the
// value is missing, empty, or rejected by parse.
func parseFormValue[T any](f url.Values, path string, dst *T, parse func(string) (T, error)) bool {
	if CheckEmpty(f, path) {
		return false
	}
	v, err := parse(f.Get(path))
	if err != nil {
		return false
	}
	*dst = v
	return true
}

// CheckNumber parses the first value of path in f with strconv.Atoi and
// stores it in *number. It reports whether that succeeded; on failure
// *number is left unchanged.
func CheckNumber(f url.Values, path string, number *int) bool {
	return parseFormValue(f, path, number, strconv.Atoi)
}

// CheckFloat64 parses the first value of path in f as a float64 and stores
// it in *number. It reports whether that succeeded; on failure *number is
// left unchanged.
func CheckFloat64(f url.Values, path string, number *float64) bool {
	return parseFormValue(f, path, number, func(s string) (float64, error) {
		return strconv.ParseFloat(s, 64)
	})
}

// CheckId reports whether path is present and non-empty in f and its first
// value parses as a UUID.
func CheckId(f url.Values, path string) bool {
	return !CheckEmpty(f, path) && IsValidUUID(f.Get(path))
}

// IsValidUUID reports whether u is accepted by uuid.Parse.
func IsValidUUID(u string) bool {
	_, err := uuid.Parse(u)
	return err == nil
}

// GetIdFromUrl returns the value of the URL query parameter named target,
// or an error if the parameter is missing, empty, or not a valid UUID.
func GetIdFromUrl(c *Context, target string) (string, error) {
	// Parse the query string once and reuse it for every lookup.
	query := c.R.URL.Query()
	if !query.Has(target) {
		return "", errors.New("Query does not have " + target)
	}
	id := query.Get(target)
	if len(id) == 0 {
		return "", errors.New("Query is empty for " + target)
	}
	if !IsValidUUID(id) {
		return "", errors.New("Value of query is not a valid uuid for " + target)
	}
	return id, nil
}

// maxBytesReader caps the number of bytes that can be read from r.
// NOTE(review): this type and its Read method appear to be adapted from
// net/http's unexported maxBytesReader; nothing in this file constructs one.
type maxBytesReader struct {
	w   http.ResponseWriter
	r   io.ReadCloser // underlying reader
	i   int64         // max bytes initially, for MaxBytesError
	n   int64         // max bytes remaining
	err error         // sticky error
}

// MaxBytesError is returned by maxBytesReader once its limit is exceeded.
type MaxBytesError struct {
	Limit int64
}

// Error implements the error interface.
func (e *MaxBytesError) Error() string {
	// Due to Hyrum's law, this text cannot be changed.
	return "http: request body too large"
}

// Read reads from the underlying reader until the byte budget is spent.
// Once the limit is exceeded, it returns (and keeps returning) a
// *MaxBytesError.
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
	if l.err != nil {
		return 0, l.err
	}
	if len(p) == 0 {
		return 0, nil
	}
	// If they asked for a 32KB byte read but only 5 bytes are
	// remaining, no need to read 32KB. 6 bytes will answer the
	// question of whether we hit the limit or go past it.
	// 0 < len(p) < 2^63
	if int64(len(p))-1 > l.n {
		p = p[:l.n+1]
	}
	n, err = l.r.Read(p)

	if int64(n) <= l.n {
		l.n -= int64(n)
		l.err = err
		return n, err
	}

	n = int(l.n)
	l.n = 0

	// The server code and client code both use
	// maxBytesReader. This "requestTooLarge" check is
	// only used by the server code. To prevent binaries
	// which only use the HTTP Client code (such as
	// cmd/go) from also linking in the HTTP server, don't
	// use a static type assertion to the server
	// "*response" type. Check this interface instead:
	//
	// NOTE(review): requestTooLarge is unexported, so only types declared in
	// this package can satisfy requestTooLarger; net/http's own response
	// writer never will.
	type requestTooLarger interface {
		requestTooLarge()
	}
	if res, ok := l.w.(requestTooLarger); ok {
		res.requestTooLarge()
	}
	l.err = &MaxBytesError{l.i}
	return n, l.err
}

// Close closes the underlying reader.
func (l *maxBytesReader) Close() error {
	return l.r.Close()
}

// MyParseForm parses an application/x-www-form-urlencoded request body into
// url.Values.
//
// Behaviour:
//   - A nil body yields an error.
//   - A missing Content-Type is treated as application/octet-stream.
//   - For multipart/form-data or any other media type, nil values are
//     returned, along with any Content-Type parse error.
//   - Bodies that are not this package's *maxBytesReader are read through a
//     10 MB limit. A larger body yields "http: POST too large". Note that
//     bodies wrapped by http.MaxBytesReader are a different type, so they
//     are capped too.
func MyParseForm(r *http.Request) (vs url.Values, err error) {
	if r.Body == nil {
		err = errors.New("missing form body")
		return
	}
	ct := r.Header.Get("Content-Type")
	// RFC 7231, section 3.1.1.5 - empty type
	// MAY be treated as application/octet-stream
	if ct == "" {
		ct = "application/octet-stream"
	}
	// A parse error leaves ct empty, so the switch below matches nothing
	// and err is returned as-is.
	ct, _, err = mime.ParseMediaType(ct)
	switch {
	case ct == "application/x-www-form-urlencoded":
		var reader io.Reader = r.Body
		maxFormSize := int64(1<<63 - 1)
		if _, ok := r.Body.(*maxBytesReader); !ok {
			maxFormSize = int64(10 << 20) // 10 MB is a lot of text.
			// Read one extra byte so an oversized body is detectable.
			reader = io.LimitReader(r.Body, maxFormSize+1)
		}
		b, e := io.ReadAll(reader)
		if e != nil {
			if err == nil {
				err = e
			}
			break
		}
		if int64(len(b)) > maxFormSize {
			err = errors.New("http: POST too large")
			return
		}
		vs, e = url.ParseQuery(string(b))
		if err == nil {
			err = e
		}
	case ct == "multipart/form-data":
		// handled by ParseMultipartForm (which is calling us, or should be)
		// TODO(bradfitz): there are too many possible
		// orders to call too many functions here.
		// Clean this up and write more tests.
		// request_test.go contains the start of this,
		// in TestParseMultipartFormOrder and others.
	}
	return
}

// JustId is a payload carrying a single required "id" field.
type JustId struct {
	Id string `json:"id" validate:"required"`
}

// Generic wraps a reflect.Type.
type Generic struct {
	reflect.Type
}

var (
	// NotFoundError is returned by GetDBOnce when the query yields no rows.
	NotFoundError = errors.New("Not found")
	// CouldNotInsert is returned by InsertReturnId when the insert yields
	// no row.
	CouldNotInsert = errors.New("Could not insert")
)

// dbColumnFields lists the columns that the struct type t maps to, in field
// declaration order.
//
// It returns two parallel slices:
//   - names: the lower-cased column names. A field's column name is its
//     `db` struct tag when present, otherwise its Go field name.
//   - fields: the matching field indices in t.
//
// Fields tagged `db:"__nil__"` are excluded from both slices.
func dbColumnFields(t reflect.Type) (names []string, fields []int) {
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		name, ok := field.Tag.Lookup("db")
		if !ok {
			name = field.Name
		}
		if name == "__nil__" {
			continue
		}
		names = append(names, strings.ToLower(name))
		fields = append(fields, i)
	}
	return names, fields
}

// dbScanTargets returns pointers to the fields of the struct value val at
// the given indices, in order, for use as sql.Rows.Scan destinations. val
// must be addressable.
func dbScanTargets(val reflect.Value, fields []int) []any {
	targets := make([]any, len(fields))
	for i, idx := range fields {
		targets[i] = val.Field(idx).Addr().Interface()
	}
	return targets
}

// generateQuery returns the comma-separated column list for struct type t
// (see dbColumnFields) and the number of columns in it. Fields tagged
// `db:"__nil__"` are excluded from both results, so nargs always matches
// the column count.
func generateQuery(t reflect.Type) (query string, nargs int) {
	names, _ := dbColumnFields(t)
	return strings.Join(names, ","), len(names)
}

// QueryInterface is the subset of *sql.DB / *sql.Tx used by the select
// helpers in this file.
type QueryInterface interface {
	Prepare(str string) (*sql.Stmt, error)
	Query(query string, args ...any) (*sql.Rows, error)
}

// GetDbMultitple runs "select <columns of T> from <tablename>" as a
// prepared statement with args and returns one *T per row.
//
// The result is a non-nil (possibly empty) slice on success. Any
// scan or row-iteration error aborts and is returned.
//
// tablename is interpolated verbatim into the SQL text and must never
// contain untrusted input. NOTE(review): presumably it may carry a WHERE
// clause referencing args as $1, $2, ... — verify with callers.
func GetDbMultitple[T interface{}](c QueryInterface, tablename string, args ...any) ([]*T, error) {
	columns, fields := dbColumnFields(reflect.TypeFor[T]())

	stmt, err := c.Prepare(fmt.Sprintf("select %s from %s", strings.Join(columns, ","), tablename))
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, err := stmt.Query(args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	list := []*T{}
	for rows.Next() {
		item := new(T)
		if err := rows.Scan(dbScanTargets(reflect.ValueOf(item).Elem(), fields)...); err != nil {
			return nil, err
		}
		list = append(list, item)
	}
	// rows.Next returns false on iteration errors too; surface them.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return list, nil
}

// mapRow scans the current row of rows into the struct that store points
// to. Destinations are the struct's column fields (see dbColumnFields) in
// declaration order, limited to the first nargs of them.
func mapRow(store interface{}, rows *sql.Rows, nargs int) error {
	val := reflect.Indirect(reflect.ValueOf(store))
	_, fields := dbColumnFields(val.Type())
	if nargs < len(fields) {
		fields = fields[:nargs]
	}
	return rows.Scan(dbScanTargets(val, fields)...)
}

// InsertReturnId inserts the struct pointed to by store into tablename and
// returns the value of the returnName column from the inserted row.
//
// The SQL executed is:
//
//	insert into <tablename> (<columns>) values ($1, ...) returning <returnName>
//
// Column values are taken from the struct's column fields (see
// dbColumnFields). CouldNotInsert is returned when no row comes back.
//
// tablename and returnName are interpolated verbatim into the SQL text and
// must never contain untrusted input.
func InsertReturnId(c *Context, store interface{}, tablename string, returnName string) (id string, err error) {
	val := reflect.ValueOf(store).Elem()
	columns, fields := dbColumnFields(val.Type())

	placeholders := make([]string, len(fields))
	values := make([]any, len(fields))
	for i, idx := range fields {
		placeholders[i] = "$" + strconv.Itoa(i+1)
		values[i] = val.Field(idx).Interface()
	}

	query := fmt.Sprintf("insert into %s (%s) values (%s) returning %s",
		tablename, strings.Join(columns, ","), strings.Join(placeholders, ","), returnName)
	rows, err := c.Db.Query(query, values...)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	if !rows.Next() {
		// Distinguish a driver/iteration error from a genuinely empty result.
		if err := rows.Err(); err != nil {
			return "", err
		}
		return "", CouldNotInsert
	}
	if err := rows.Scan(&id); err != nil {
		return "", err
	}
	return id, nil
}

// GetDBOnce runs "select <columns> from <tablename>" with args and scans the
// first resulting row into the struct pointed to by store.
//
// NotFoundError is returned when the query yields no rows.
//
// tablename is interpolated verbatim into the SQL text and must never
// contain untrusted input.
func GetDBOnce(db QueryInterface, store interface{}, tablename string, args ...any) error {
	query, nargs := generateQuery(reflect.TypeOf(store).Elem())
	rows, err := db.Query(fmt.Sprintf("select %s from %s", query, tablename), args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if !rows.Next() {
		// Distinguish a driver/iteration error from a genuinely empty result.
		if err := rows.Err(); err != nil {
			return err
		}
		return NotFoundError
	}
	return mapRow(store, rows, nargs)
}

// UpdateStatus sets the status column to status for the row whose id
// equals id in table.
//
// table is interpolated verbatim into the SQL text and must never contain
// untrusted input.
func UpdateStatus(c *Context, table string, id string, status int) (err error) {
	_, err = c.Db.Exec(fmt.Sprintf("update %s set status = $1 where id = $2", table), status, id)
	return
}