fyp/logic/utils/handler.go

577 lines
13 KiB
Go
Raw Normal View History

package utils
2023-09-18 00:26:42 +01:00
import (
"bytes"
"database/sql"
"errors"
2023-09-18 00:26:42 +01:00
"fmt"
"io"
2023-09-18 00:26:42 +01:00
"net/http"
"os"
"path"
"runtime/debug"
2023-09-18 00:26:42 +01:00
"strings"
"time"
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
2023-10-20 12:37:56 +01:00
"github.com/charmbracelet/log"
"github.com/go-playground/validator/v10"
2024-02-23 23:49:23 +00:00
"github.com/goccy/go-json"
2023-09-18 00:26:42 +01:00
)
// AnyMap is a shorthand alias for a string-keyed map of arbitrary values.
type AnyMap = map[string]any
2023-09-18 00:26:42 +01:00
// Error is what a route handler returns to request an error response.
// handleError writes Code as the HTTP status and, when data is non-nil,
// sends data as the JSON body.
type Error struct {
	Code int // HTTP status code sent to the client
	data any // optional payload serialized as the JSON response body
}
type HandleFunc struct {
path string
2024-03-09 10:52:08 +00:00
fn func(c *Context) *Error
}
2023-09-18 00:26:42 +01:00
type Handler interface {
New()
Startup()
2024-03-09 10:52:08 +00:00
Get(fn func(c *Context) *Error)
Post(fn func(c *Context) *Error)
2023-09-18 00:26:42 +01:00
}
// Handle is the application's HTTP router. Routes are kept per HTTP method
// and matched by exact path; Config and the shared validator are reachable
// from every request Context.
type Handle struct {
	Db       *sql.DB
	gets     []HandleFunc // routes served for GET requests
	posts    []HandleFunc // routes served for POST requests
	deletes  []HandleFunc // routes served for DELETE requests
	Config   Config
	validate *validator.Validate // checks decoded request bodies in decodeAndValidade
}
2023-09-18 00:26:42 +01:00
2024-03-09 10:52:08 +00:00
// handleError writes err as the HTTP response; a nil err is a no-op.
//
// The body is err.data as JSON, or the JSON number 500 when err carries no
// payload. Status and body are written through SendJSONStatus so the
// content-type header is set before the status line is committed. Headers
// added after WriteHeader are silently dropped by net/http, which is what
// happened when WriteHeader was followed by SendJSON.
func handleError(err *Error, c *Context) {
	if err == nil {
		return
	}
	c.Logger.Warn("Responded with error", "code", err.Code, "data", err.data)

	var body any = 500
	if err.data != nil {
		body = err.data
	}
	if e := c.SendJSONStatus(err.Code, body); e != nil {
		c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
		// The status may already be on the wire; only a plain-text fallback
		// body can still be sent. ("505" was a typo for the 500 fallback.)
		c.Writer.Write([]byte("500"))
	}
}
2024-04-13 23:55:01 +01:00
// The functions below register route handlers for each HTTP method, optionally guarded by an auth level.
2024-03-09 10:52:08 +00:00
// Get registers fn to serve GET requests whose path equals path exactly.
// No authentication check is applied.
func (x *Handle) Get(path string, fn func(c *Context) *Error) {
	route := HandleFunc{path: path, fn: fn}
	x.gets = append(x.gets, route)
}
2024-04-14 15:19:32 +01:00
func (x *Handle) GetAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) {
2024-04-13 23:55:01 +01:00
inner_fn := func(c *Context) *Error {
if !c.CheckAuthLevel(authLevel) {
return nil
}
return fn(c)
}
x.gets = append(x.gets, HandleFunc{path, inner_fn})
}
2024-03-09 10:52:08 +00:00
// Post registers fn to serve POST requests whose path equals path exactly.
// No authentication check is applied.
func (x *Handle) Post(path string, fn func(c *Context) *Error) {
	route := HandleFunc{path: path, fn: fn}
	x.posts = append(x.posts, route)
}
2024-04-14 15:19:32 +01:00
// PostAuth registers fn for POST requests on path. fn only runs once
// CheckAuthLevel(authLevel) succeeds; otherwise the rejection response
// written by CheckAuthLevel stands and nothing else is sent.
func (x *Handle) PostAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) {
	x.posts = append(x.posts, HandleFunc{
		path: path,
		fn: func(c *Context) *Error {
			if c.CheckAuthLevel(authLevel) {
				return fn(c)
			}
			return nil
		},
	})
}
func PostAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, dat *T) *Error) {
inner_fn := func(c *Context) *Error {
2024-04-14 15:19:32 +01:00
if !c.CheckAuthLevel(authLevel) {
return nil
}
obj := new(T)
if err := c.ToJSON(obj); err != nil {
return err
}
return fn(c, obj)
}
x.posts = append(x.posts, HandleFunc{path, inner_fn})
}
// PostAuthFormJson registers a POST route on path that accepts a
// multipart/form-data body with two parts:
//   - "json_data": a JSON document, decoded and validated into a new T;
//   - "file": raw bytes passed to fn unchanged (nil when absent).
//
// Authorization is checked once, before any of the (possibly large) body is
// read. As before, a logged-in user is always required: authLevel values
// below 1 are raised to 1. Previously the full upload was buffered in memory
// before the real authLevel check ran. Malformed multipart data, including
// read errors while draining a part, is answered with a bad-request error.
func PostAuthFormJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, dat *T, file []byte) *Error) {
	required := authLevel
	if required < 1 {
		required = 1
	}

	inner_fn := func(c *Context) *Error {
		if !c.CheckAuthLevel(required) {
			return nil
		}

		reader, err := c.R.MultipartReader()
		if err != nil {
			return c.JsonBadRequest("Please provide a valid form data request!")
		}

		var jsonData string
		var file []byte
		for {
			part, errPart := reader.NextPart()
			if errPart == io.EOF {
				break
			}
			if errPart != nil {
				return c.JsonBadRequest("Please provide a valid form data request!")
			}

			switch part.FormName() {
			case "json_data":
				var buf bytes.Buffer
				if _, err := buf.ReadFrom(part); err != nil {
					return c.JsonBadRequest("Please provide a valid form data request!")
				}
				jsonData = buf.String()
			case "file":
				var buf bytes.Buffer
				if _, err := buf.ReadFrom(part); err != nil {
					return c.JsonBadRequest("Please provide a valid form data request!")
				}
				file = buf.Bytes()
			}
		}

		dat := new(T)
		decoder := json.NewDecoder(strings.NewReader(jsonData))
		if err := c.decodeAndValidade(decoder, dat); err != nil {
			return err
		}
		return fn(c, dat, file)
	}
	x.posts = append(x.posts, HandleFunc{path, inner_fn})
}
2024-03-09 10:52:08 +00:00
// Delete registers fn to serve DELETE requests whose path equals path exactly.
// No authentication check is applied.
func (x *Handle) Delete(path string, fn func(c *Context) *Error) {
	route := HandleFunc{path: path, fn: fn}
	x.deletes = append(x.deletes, route)
}
2024-04-14 15:19:32 +01:00
// DeleteAuth registers fn for DELETE requests on path, guarded by
// CheckAuthLevel(authLevel).
//
// The route goes into the DELETE table. It used to be appended to x.posts,
// which made it answer POST requests while DELETE requests to path got 404.
func (x *Handle) DeleteAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) {
	inner_fn := func(c *Context) *Error {
		if !c.CheckAuthLevel(authLevel) {
			return nil
		}
		return fn(c)
	}
	x.deletes = append(x.deletes, HandleFunc{path, inner_fn})
}
2024-04-14 15:19:32 +01:00
// DeleteAuthJson registers a DELETE route on path whose JSON request body is
// decoded and validated into a fresh T before fn runs.
//
// The caller must first pass CheckAuthLevel(authLevel). A decode or
// validation failure from ToJSON becomes the route's error response.
func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, obj *T) *Error) {
	route := HandleFunc{path: path}
	route.fn = func(c *Context) *Error {
		if !c.CheckAuthLevel(authLevel) {
			return nil
		}
		var obj T
		if jsonErr := c.ToJSON(&obj); jsonErr != nil {
			return jsonErr
		}
		return fn(c, &obj)
	}
	x.deletes = append(x.deletes, route)
}
2024-04-13 23:55:01 +01:00
// This function handles loop of a list of possible handler functions
func handleLoop(array []HandleFunc, context *Context) {
defer func() {
if r := recover(); r != nil {
context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
handleError(&Error{500, "500"}, context)
}
}()
2024-04-12 20:36:23 +01:00
2024-04-13 23:55:01 +01:00
for _, s := range array {
2024-03-09 10:52:08 +00:00
if s.path == context.R.URL.Path {
handleError(s.fn(context), context)
return
}
}
context.ShowMessage = false
2024-03-09 10:52:08 +00:00
handleError(&Error{404, "Endpoint not found"}, context)
}
2024-04-14 15:19:32 +01:00
// CheckAuthLevel reports whether the current request may access a route that
// requires authLevel. Any authLevel <= 0 admits every request.
//
// On failure the response has already been written:
//   - no user on the context: contextlessLogoff (401);
//   - user type below authLevel: 401 with the JSON string "Not Authorized".
//
// The 401 and its JSON body go out through SendJSONStatus, so the
// content-type header is set before the status is committed. With
// WriteHeader followed by SendJSON, net/http silently dropped that header.
func (c *Context) CheckAuthLevel(authLevel dbtypes.UserType) bool {
	if authLevel <= 0 {
		return true
	}
	if c.User == nil {
		contextlessLogoff(c.Writer)
		return false
	}
	if c.User.UserType < int(authLevel) {
		if e := c.SendJSONStatus(http.StatusUnauthorized, "Not Authorized"); e != nil {
			// The 401 status may already be on the wire; writing another
			// status would be ignored, so only a plain-text body is sent.
			c.Writer.Write([]byte("You can not access this resource!"))
		}
		return false
	}
	return true
}
type Context struct {
2024-04-08 14:17:13 +01:00
Token *string
User *dbtypes.User
Logger *log.Logger
Db *sql.DB
Writer http.ResponseWriter
R *http.Request
Tx *sql.Tx
ShowMessage bool
Handle *Handle
2024-04-08 14:17:13 +01:00
}
2024-04-13 23:55:01 +01:00
// The accessors below are required for Context to integrate seamlessly with my ORM.
// GetDb returns the context's database handle (c.Db).
func (c Context) GetDb() *sql.DB {
	db := c.Db
	return db
}
func (c Context) GetLogger() *log.Logger {
return c.Logger
2024-04-12 20:36:23 +01:00
}
// GetHost returns the Hostname from the owning handler's configuration.
func (c Context) GetHost() string {
	handler := c.Handle
	return handler.Config.Hostname
}
2024-04-12 20:36:23 +01:00
// Query runs query with args directly on c.Db and returns the result rows.
//
// NOTE(review): unlike Prepare, this ignores an open c.Tx, so it runs outside
// any transaction started with StartTx — confirm that is intended.
func (c Context) Query(query string, args ...any) (*sql.Rows, error) {
	db := c.Db
	return db.Query(query, args...)
}
2024-04-14 14:51:16 +01:00
// Exec runs query with args directly on c.Db without returning rows.
//
// NOTE(review): unlike Prepare, this ignores an open c.Tx, so it runs outside
// any transaction started with StartTx — confirm that is intended.
func (c Context) Exec(query string, args ...any) (sql.Result, error) {
	db := c.Db
	return db.Exec(query, args...)
}
2024-04-08 14:17:13 +01:00
// Prepare creates a prepared statement for str. It uses the open transaction
// when StartTx has been called, and the plain database handle otherwise.
func (c Context) Prepare(str string) (*sql.Stmt, error) {
	if c.Tx != nil {
		return c.Tx.Prepare(str)
	}
	return c.Db.Prepare(str)
}
var (
	// TransactionAlreadyStarted is returned by StartTx when the context
	// already holds an open transaction.
	TransactionAlreadyStarted = errors.New("Transaction already started")
	// TransactionNotStarted is returned by CommitTx and RollbackTx when the
	// context has no open transaction.
	TransactionNotStarted = errors.New("Transaction not started")
)
// StartTx begins a database transaction and stores it in c.Tx. It returns
// TransactionAlreadyStarted if one is already open. On a Begin failure c.Tx
// ends up nil, because database/sql's Begin returns a nil *Tx alongside the
// error.
func (c *Context) StartTx() error {
	if c.Tx != nil {
		return TransactionAlreadyStarted
	}
	tx, err := c.Db.Begin()
	c.Tx = tx
	return err
}
// CommitTx commits the transaction opened by StartTx and clears c.Tx. It
// returns TransactionNotStarted when no transaction is open.
//
// c.Tx is cleared even when Commit fails. database/sql marks a Tx as done as
// soon as Commit is called, so later operations on it return ErrTxDone.
// Keeping the dead handle would only make every later StartTx fail with
// TransactionAlreadyStarted.
func (c *Context) CommitTx() error {
	if c.Tx == nil {
		return TransactionNotStarted
	}
	tx := c.Tx
	c.Tx = nil
	return tx.Commit()
}
func (c *Context) RollbackTx() error {
if c.Tx == nil {
return TransactionNotStarted
}
err := c.Tx.Rollback()
if err != nil {
return err
}
c.Tx = nil
return nil
2023-10-06 09:45:47 +01:00
}
2024-04-12 20:36:23 +01:00
// ParseJson decodes str as JSON into dat and validates the result with the
// handler's validator. On failure it returns the bad-request *Error produced
// by decodeAndValidade.
func (c Context) ParseJson(dat any, str string) *Error {
	source := strings.NewReader(str)
	return c.decodeAndValidade(json.NewDecoder(source), dat)
}
2024-02-23 23:49:23 +00:00
2024-04-12 20:36:23 +01:00
func (c Context) ToJSON(dat any) *Error {
2024-03-09 10:52:08 +00:00
decoder := json.NewDecoder(c.R.Body)
return c.decodeAndValidade(decoder, dat)
2024-04-12 20:36:23 +01:00
}
2024-02-23 23:49:23 +00:00
2024-04-12 20:36:23 +01:00
func (c Context) decodeAndValidade(decoder *json.Decoder, dat any) *Error {
2024-03-09 10:52:08 +00:00
err := decoder.Decode(dat)
if err != nil {
c.Logger.Error("Failed to decode json", "dat", dat, "err", err)
return c.JsonBadRequest("Bad Request! Invalid json passed!")
2024-03-09 10:52:08 +00:00
}
err = c.Handle.validate.Struct(dat)
if err != nil {
c.Logger.Error("Failed invalid json passed", "dat", dat, "err", err)
return c.JsonBadRequest("Bad Request! Invalid json passed!")
}
2024-03-09 10:52:08 +00:00
return nil
2024-02-23 23:49:23 +00:00
}
func (c Context) SendJSON(dat any) *Error {
2024-03-09 10:52:08 +00:00
c.Writer.Header().Add("content-type", "application/json")
text, err := json.Marshal(dat)
if err != nil {
return c.Error500(err)
}
if _, err = c.Writer.Write(text); err != nil {
return c.Error500(err)
}
return nil
2024-02-23 23:49:23 +00:00
}
func (c Context) SendJSONStatus(status int, dat any) *Error {
2024-03-09 10:52:08 +00:00
c.Writer.Header().Add("content-type", "application/json")
c.Writer.WriteHeader(status)
text, err := json.Marshal(dat)
if err != nil {
return c.Error500(err)
}
if _, err = c.Writer.Write(text); err != nil {
return c.Error500(err)
}
return nil
2024-02-23 23:49:23 +00:00
}
func (c Context) JsonBadRequest(dat any) *Error {
2024-03-09 10:52:08 +00:00
c.SetReportCaller(true)
c.Logger.Warn("Request failed with a bad request", "dat", dat)
c.SetReportCaller(false)
return c.ErrorCode(nil, 404, dat)
}
func (c Context) JsonErrorBadRequest(err error, dat any) *Error {
2023-10-20 12:37:56 +01:00
c.SetReportCaller(true)
2024-03-09 10:52:08 +00:00
c.Logger.Error("Error while processing request", "err", err, "dat", dat)
2023-10-20 12:37:56 +01:00
c.SetReportCaller(false)
2024-03-09 10:52:08 +00:00
return c.SendJSONStatus(http.StatusBadRequest, dat)
}
2024-04-14 15:19:32 +01:00
func (c *Context) GetModelFromId(id_path string) (*dbtypes.BaseModel, *Error) {
2024-03-09 10:52:08 +00:00
id, err := GetIdFromUrl(c, id_path)
2023-10-06 10:46:45 +01:00
if err != nil {
2024-03-09 10:52:08 +00:00
return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found")
2023-10-06 10:46:45 +01:00
}
2024-04-14 15:19:32 +01:00
model, err := dbtypes.GetBaseModel(c.Db, id)
if err == dbtypes.ModelNotFoundError {
2024-03-09 10:52:08 +00:00
return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found")
} else if err != nil {
return nil, c.Error500(err)
2023-10-20 12:37:56 +01:00
}
2023-10-06 10:46:45 +01:00
2024-04-08 14:17:13 +01:00
return model, nil
2024-03-09 10:52:08 +00:00
}
func ModelUpdateStatus(c dbtypes.BasePack, id string, status int) {
_, err := c.GetDb().Exec("update models set status=$1 where id=$2;", status, id)
2024-03-09 10:52:08 +00:00
if err != nil {
c.GetLogger().Error("Failed to update model status", "err", err)
c.GetLogger().Warn("TODO Maybe handle better")
2024-03-09 10:52:08 +00:00
}
2023-10-06 10:46:45 +01:00
}
func (c Context) SetReportCaller(report bool) {
2023-10-20 12:37:56 +01:00
if report {
c.Logger.SetCallerOffset(2)
c.Logger.SetReportCaller(true)
} else {
c.Logger.SetCallerOffset(1)
c.Logger.SetReportCaller(false)
}
2023-10-06 10:46:45 +01:00
}
2024-03-09 10:52:08 +00:00
func (c Context) ErrorCode(err error, code int, data any) *Error {
2023-10-20 12:37:56 +01:00
if code == 400 {
c.SetReportCaller(true)
c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n")
c.SetReportCaller(false)
}
2023-10-06 09:45:47 +01:00
if err != nil {
c.Logger.Error("Something went wrong returning with:", "Error", err)
2023-10-06 09:45:47 +01:00
}
2024-03-09 10:52:08 +00:00
return &Error{code, data}
2023-10-12 12:08:12 +01:00
}
2024-04-13 23:55:01 +01:00
// Deprecated: Use the E500M instead
2023-10-06 10:46:45 +01:00
func (c Context) Error500(err error) *Error {
return c.ErrorCode(err, http.StatusInternalServerError, nil)
}
2024-04-12 20:36:23 +01:00
func (c Context) E500M(msg string, err error) *Error {
return c.ErrorCode(err, http.StatusInternalServerError, msg)
}
// LogoffError is joined onto token-lookup failures by createContext.
// NewHandler answers any createContext error with a logoff (401) response.
var LogoffError = errors.New("Invalid token!")
2024-03-09 10:52:08 +00:00
func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseWriter) (*Context, error) {
var token *string
2023-10-20 12:37:56 +01:00
logger := log.NewWithOptions(os.Stdout, log.Options{
2024-02-08 18:20:58 +00:00
ReportCaller: true,
2023-10-20 12:37:56 +01:00
ReportTimestamp: true,
TimeFormat: time.Kitchen,
Prefix: r.URL.Path,
})
2023-10-12 12:08:12 +01:00
2024-03-09 10:52:08 +00:00
t := r.Header.Get("token")
if t != "" {
token = &t
}
// TODO check that the token is still valid
if token == nil {
return &Context{
Logger: logger,
Db: handler.Db,
Writer: w,
R: r,
ShowMessage: true,
Handle: &x,
}, nil
}
user, err := dbtypes.UserFromToken(x.Db, *token)
if err != nil {
return nil, errors.Join(err, LogoffError)
}
return &Context{token, user, logger, handler.Db, w, r, nil, true, &x}, nil
}
2024-03-09 10:52:08 +00:00
func contextlessLogoff(w http.ResponseWriter) {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("\"Not Authorized\""))
}
2024-03-02 12:45:49 +00:00
func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) {
2024-03-09 10:52:08 +00:00
http.HandleFunc("/api"+pathTest, func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1)
2024-03-02 12:45:49 +00:00
user_path := r.URL.Path[len(pathTest):]
found := false
index := -1
for i, fileType := range fileTypes {
if strings.HasSuffix(user_path, fileType) {
found = true
index = i
2023-10-20 12:37:56 +01:00
break
}
}
if !found {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("File not found"))
return
}
bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path))
if err != nil {
fmt.Println(err)
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("Failed to load file"))
return
}
w.Header().Set("Content-Type", contentTypes[index])
w.Write(bytes)
})
}
func NewHandler(db *sql.DB, config Config) *Handle {
var gets []HandleFunc
var posts []HandleFunc
var deletes []HandleFunc
validate := validator.New()
x := &Handle{db, gets, posts, deletes, config, validate}
2023-09-18 00:26:42 +01:00
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
2024-03-09 10:52:08 +00:00
w.Header().Add("Access-Control-Allow-Origin", "*")
w.Header().Add("Access-Control-Allow-Headers", "*")
w.Header().Add("Access-Control-Allow-Methods", "*")
// Decide answertype
/* if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") {
2024-03-09 10:52:08 +00:00
w.WriteHeader(500)
w.Write([]byte("Please set content-type to application/json or set response-type to application/json\n"))
return
}*/
2024-03-09 10:52:08 +00:00
if !strings.HasPrefix(r.URL.Path, "/api") {
w.WriteHeader(404)
w.Write([]byte("Path not found"))
return
}
2024-03-09 10:52:08 +00:00
r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1)
//Login state
2024-03-09 10:52:08 +00:00
context, err := x.createContext(x, r, w)
if err != nil {
2024-03-09 10:52:08 +00:00
contextlessLogoff(w)
return
}
2024-03-09 10:52:08 +00:00
// context.Logger.Info("Parsing", "path", r.URL.Path)
2024-02-23 23:49:23 +00:00
if r.Method == "GET" {
2024-04-13 23:55:01 +01:00
handleLoop(x.gets, context)
2024-04-08 14:17:13 +01:00
} else if r.Method == "POST" {
2024-04-13 23:55:01 +01:00
handleLoop(x.posts, context)
2024-04-08 14:17:13 +01:00
} else if r.Method == "DELETE" {
2024-04-13 23:55:01 +01:00
handleLoop(x.deletes, context)
2024-04-08 14:17:13 +01:00
} else if r.Method == "OPTIONS" {
// do nothing
} else {
panic("TODO handle method: " + r.Method)
}
if context.ShowMessage {
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
}
})
return x
2023-09-18 00:26:42 +01:00
}
func (x Handle) Startup() {
2024-02-02 16:16:26 +00:00
log.Info("Starting up!\n")
2023-10-24 22:35:11 +01:00
log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", x.Config.Port), nil))
2023-09-18 00:26:42 +01:00
}