584 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			584 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package utils
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"context"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net/http"
 | |
| 	"os"
 | |
| 	"path"
 | |
| 	"runtime/debug"
 | |
| 	"strings"
 | |
| 	"time"
 | |
| 
 | |
| 	"git.andr3h3nriqu3s.com/andr3/fyp/logic/db"
 | |
| 	dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
 | |
| 	"github.com/charmbracelet/log"
 | |
| 	"github.com/go-playground/validator/v10"
 | |
| 	"github.com/goccy/go-json"
 | |
| 	"github.com/jackc/pgx/v5"
 | |
| 	"github.com/jackc/pgx/v5/pgconn"
 | |
| )
 | |
| 
 | |
// AnyMap is shorthand for a string-keyed map holding arbitrary values.
type AnyMap = map[string]any
 | |
| 
 | |
// Error is the value endpoint handlers return to signal a failed request.
//
// Field order matters: the type is built with positional literals elsewhere
// in this file.
type Error struct {
	// Code is the HTTP status code written to the response.
	Code int
	// data is the payload sent as the JSON response body; may be nil.
	data any
}
 | |
| 
 | |
| type HandleFunc struct {
 | |
| 	path string
 | |
| 	fn   func(c *Context) *Error
 | |
| }
 | |
| 
 | |
| type Handler interface {
 | |
| 	New()
 | |
| 	Startup()
 | |
| 	Get(fn func(c *Context) *Error)
 | |
| 	Post(fn func(c *Context) *Error)
 | |
| }
 | |
| 
 | |
| type Handle struct {
 | |
| 	Db       db.Db
 | |
| 	gets     []HandleFunc
 | |
| 	posts    []HandleFunc
 | |
| 	deletes  []HandleFunc
 | |
| 	DataMap  map[string]interface{}
 | |
| 	Config   Config
 | |
| 	validate *validator.Validate
 | |
| }
 | |
| 
 | |
| func handleError(err *Error, c *Context) {
 | |
| 	if err != nil {
 | |
| 		c.Logger.Warn("Responded with error", "code", err.Code, "data", err.data)
 | |
| 		c.Writer.WriteHeader(err.Code)
 | |
| 		var e *Error
 | |
| 		if err.data != nil {
 | |
| 			e = c.SendJSON(err.data)
 | |
| 		} else {
 | |
| 			e = c.SendJSON(500)
 | |
| 		}
 | |
| 		if e != nil {
 | |
| 			c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack()))
 | |
| 			c.Writer.Write([]byte("505"))
 | |
| 		}
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // This group of functions defines some endpoints
 | |
| func (x *Handle) Get(path string, fn func(c *Context) *Error) {
 | |
| 	x.gets = append(x.gets, HandleFunc{path, fn})
 | |
| }
 | |
| 
 | |
| func (x *Handle) GetAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) {
 | |
| 	inner_fn := func(c *Context) *Error {
 | |
| 		if !c.CheckAuthLevel(authLevel) {
 | |
| 			return nil
 | |
| 		}
 | |
| 		return fn(c)
 | |
| 	}
 | |
| 	x.gets = append(x.gets, HandleFunc{path, inner_fn})
 | |
| }
 | |
| func (x *Handle) Post(path string, fn func(c *Context) *Error) {
 | |
| 	x.posts = append(x.posts, HandleFunc{path, fn})
 | |
| }
 | |
| 
 | |
| func (x *Handle) PostAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) {
 | |
| 	inner_fn := func(c *Context) *Error {
 | |
| 		if !c.CheckAuthLevel(authLevel) {
 | |
| 			return nil
 | |
| 		}
 | |
| 		return fn(c)
 | |
| 	}
 | |
| 	x.posts = append(x.posts, HandleFunc{path, inner_fn})
 | |
| }
 | |
| 
 | |
| func PostAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, dat *T) *Error) {
 | |
| 	inner_fn := func(c *Context) *Error {
 | |
| 		if !c.CheckAuthLevel(authLevel) {
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		obj := new(T)
 | |
| 
 | |
| 		if err := c.ToJSON(obj); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(c, obj)
 | |
| 	}
 | |
| 
 | |
| 	x.posts = append(x.posts, HandleFunc{path, inner_fn})
 | |
| }
 | |
| 
 | |
| func PostAuthFormJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, dat *T, file []byte) *Error) {
 | |
| 	inner_fn := func(c *Context) *Error {
 | |
| 		if !c.CheckAuthLevel(1) {
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		read_form, err := c.R.MultipartReader()
 | |
| 		if err != nil {
 | |
| 			return c.JsonBadRequest("Please provide a valid form data request!")
 | |
| 		}
 | |
| 
 | |
| 		var json_data string
 | |
| 		var file []byte
 | |
| 
 | |
| 		for {
 | |
| 			part, err_part := read_form.NextPart()
 | |
| 			if err_part == io.EOF {
 | |
| 				break
 | |
| 			} else if err_part != nil {
 | |
| 				return c.JsonBadRequest("Please provide a valid form data request!")
 | |
| 			}
 | |
| 			if part.FormName() == "json_data" {
 | |
| 				buf := new(bytes.Buffer)
 | |
| 				buf.ReadFrom(part)
 | |
| 				json_data = buf.String()
 | |
| 			}
 | |
| 			if part.FormName() == "file" {
 | |
| 				buf := new(bytes.Buffer)
 | |
| 				buf.ReadFrom(part)
 | |
| 				file = buf.Bytes()
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if !c.CheckAuthLevel(authLevel) {
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		dat := new(T)
 | |
| 
 | |
| 		decoder := json.NewDecoder(strings.NewReader(json_data))
 | |
| 		if err := c.decodeAndValidade(decoder, dat); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(c, dat, file)
 | |
| 	}
 | |
| 
 | |
| 	x.posts = append(x.posts, HandleFunc{path, inner_fn})
 | |
| }
 | |
| 
 | |
| func (x *Handle) Delete(path string, fn func(c *Context) *Error) {
 | |
| 	x.deletes = append(x.deletes, HandleFunc{path, fn})
 | |
| }
 | |
| 
 | |
| func (x *Handle) DeleteAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) {
 | |
| 	inner_fn := func(c *Context) *Error {
 | |
| 		if !c.CheckAuthLevel(authLevel) {
 | |
| 			return nil
 | |
| 		}
 | |
| 		return fn(c)
 | |
| 	}
 | |
| 	x.posts = append(x.posts, HandleFunc{path, inner_fn})
 | |
| }
 | |
| 
 | |
| func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, obj *T) *Error) {
 | |
| 	inner_fn := func(c *Context) *Error {
 | |
| 		if !c.CheckAuthLevel(authLevel) {
 | |
| 			return nil
 | |
| 		}
 | |
| 
 | |
| 		obj := new(T)
 | |
| 
 | |
| 		if err := c.ToJSON(obj); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 
 | |
| 		return fn(c, obj)
 | |
| 	}
 | |
| 
 | |
| 	x.deletes = append(x.deletes, HandleFunc{path, inner_fn})
 | |
| }
 | |
| 
 | |
| // This function handles loop of a list of possible handler functions
 | |
| func handleLoop(array []HandleFunc, context *Context) {
 | |
| 	defer func() {
 | |
| 		if r := recover(); r != nil {
 | |
| 			context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack()))
 | |
| 			handleError(&Error{500, "500"}, context)
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	for _, s := range array {
 | |
| 		if s.path == context.R.URL.Path {
 | |
| 			handleError(s.fn(context), context)
 | |
| 			return
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	context.ShowMessage = false
 | |
| 	handleError(&Error{404, "Endpoint not found"}, context)
 | |
| }
 | |
| 
 | |
| func (c *Context) CheckAuthLevel(authLevel dbtypes.UserType) bool {
 | |
| 	if authLevel > 0 {
 | |
| 		if c.User == nil {
 | |
| 			contextlessLogoff(c.Writer)
 | |
| 			return false
 | |
| 		}
 | |
| 		if c.User.UserType < int(authLevel) {
 | |
| 			c.Writer.WriteHeader(http.StatusUnauthorized)
 | |
| 			e := c.SendJSON("Not Authorized")
 | |
| 			if e != nil {
 | |
| 				c.Writer.WriteHeader(http.StatusInternalServerError)
 | |
| 				c.Writer.Write([]byte("You can not access this resource!"))
 | |
| 			}
 | |
| 			return false
 | |
| 		}
 | |
| 	}
 | |
| 	return true
 | |
| }
 | |
| 
 | |
| type Context struct {
 | |
| 	Token       *string
 | |
| 	User        *dbtypes.User
 | |
| 	Logger      *log.Logger
 | |
| 	Db          db.Db
 | |
| 	Writer      http.ResponseWriter
 | |
| 	R           *http.Request
 | |
| 	Tx          pgx.Tx
 | |
| 	ShowMessage bool
 | |
| 	Handle      *Handle
 | |
| }
 | |
| 
 | |
| // This is required for this to integrate simealy with my orl
 | |
| func (c Context) GetDb() db.Db {
 | |
| 	return c.Db
 | |
| }
 | |
| 
 | |
| func (c Context) GetLogger() *log.Logger {
 | |
| 	return c.Logger
 | |
| }
 | |
| 
 | |
| func (c Context) GetHost() string {
 | |
| 	return c.Handle.Config.Hostname
 | |
| }
 | |
| 
 | |
| func (c Context) Query(query string, args ...any) (pgx.Rows, error) {
 | |
| 	if c.Tx == nil {
 | |
| 		return c.Db.Query(query, args...)
 | |
| 	}
 | |
| 	return c.Tx.Query(context.Background(), query, args...)
 | |
| }
 | |
| 
 | |
| func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) {
 | |
| 	if c.Tx == nil {
 | |
| 		return c.Db.Exec(query, args...)
 | |
| 	}
 | |
| 	return c.Tx.Exec(context.Background(), query, args...)
 | |
| }
 | |
| 
 | |
| func (c Context) Begin() (pgx.Tx, error) {
 | |
| 	return c.Db.Begin()
 | |
| }
 | |
| 
 | |
// Sentinel errors returned by the transaction helpers on Context.
var (
	// TransactionAlreadyStarted is returned by StartTx when a transaction is
	// already open on the context.
	TransactionAlreadyStarted = errors.New("Transaction already started")
	// TransactionNotStarted is returned by CommitTx and RollbackTx when no
	// transaction is open on the context.
	TransactionNotStarted = errors.New("Transaction not started")
)
 | |
| 
 | |
| func (c *Context) StartTx() error {
 | |
| 	if c.Tx != nil {
 | |
| 		return TransactionAlreadyStarted
 | |
| 	}
 | |
| 	var err error = nil
 | |
| 	c.Tx, err = c.Begin()
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (c *Context) CommitTx() error {
 | |
| 	if c.Tx == nil {
 | |
| 		return TransactionNotStarted
 | |
| 	}
 | |
| 	err := c.Tx.Commit(context.Background())
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	c.Tx = nil
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *Context) RollbackTx() error {
 | |
| 	if c.Tx == nil {
 | |
| 		return TransactionNotStarted
 | |
| 	}
 | |
| 	err := c.Tx.Rollback(context.Background())
 | |
| 	if err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 	c.Tx = nil
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| /**
 | |
| * Parse and vailidates the json
 | |
|  */
 | |
| func (c Context) ParseJson(dat any, str string) *Error {
 | |
| 	decoder := json.NewDecoder(strings.NewReader(str))
 | |
| 
 | |
| 	return c.decodeAndValidade(decoder, dat)
 | |
| }
 | |
| 
 | |
| func (c Context) ToJSON(dat any) *Error {
 | |
| 	decoder := json.NewDecoder(c.R.Body)
 | |
| 
 | |
| 	return c.decodeAndValidade(decoder, dat)
 | |
| }
 | |
| 
 | |
| func (c Context) decodeAndValidade(decoder *json.Decoder, dat any) *Error {
 | |
| 	err := decoder.Decode(dat)
 | |
| 	if err != nil {
 | |
| 		c.Logger.Error("Failed to decode json", "dat", dat, "err", err)
 | |
| 		return c.JsonBadRequest("Bad Request! Invalid json passed!")
 | |
| 	}
 | |
| 
 | |
| 	err = c.Handle.validate.Struct(dat)
 | |
| 	if err != nil {
 | |
| 		c.Logger.Error("Failed invalid json passed", "dat", dat, "err", err)
 | |
| 		return c.JsonBadRequest("Bad Request! Invalid json passed!")
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c Context) SendJSON(dat any) *Error {
 | |
| 	c.Writer.Header().Add("content-type", "application/json")
 | |
| 	text, err := json.Marshal(dat)
 | |
| 	if err != nil {
 | |
| 		return c.Error500(err)
 | |
| 	}
 | |
| 	if _, err = c.Writer.Write(text); err != nil {
 | |
| 		return c.Error500(err)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c Context) SendJSONStatus(status int, dat any) *Error {
 | |
| 	c.Writer.Header().Add("content-type", "application/json")
 | |
| 	c.Writer.WriteHeader(status)
 | |
| 	text, err := json.Marshal(dat)
 | |
| 	if err != nil {
 | |
| 		return c.Error500(err)
 | |
| 	}
 | |
| 	if _, err = c.Writer.Write(text); err != nil {
 | |
| 		return c.Error500(err)
 | |
| 	}
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c Context) JsonBadRequest(dat any) *Error {
 | |
| 	c.SetReportCaller(true)
 | |
| 	c.Logger.Warn("Request failed with a bad request", "dat", dat)
 | |
| 	c.SetReportCaller(false)
 | |
| 	return c.ErrorCode(nil, 404, dat)
 | |
| }
 | |
| 
 | |
| func (c Context) JsonErrorBadRequest(err error, dat any) *Error {
 | |
| 	c.SetReportCaller(true)
 | |
| 	c.Logger.Error("Error while processing request", "err", err, "dat", dat)
 | |
| 	c.SetReportCaller(false)
 | |
| 	return c.SendJSONStatus(http.StatusBadRequest, dat)
 | |
| }
 | |
| 
 | |
| func (c *Context) GetModelFromId(id_path string) (*dbtypes.BaseModel, *Error) {
 | |
| 
 | |
| 	id, err := GetIdFromUrl(c, id_path)
 | |
| 	if err != nil {
 | |
| 		return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found")
 | |
| 	}
 | |
| 
 | |
| 	model, err := dbtypes.GetBaseModel(c.Db, id)
 | |
| 	if err == dbtypes.NotFoundError {
 | |
| 		return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found")
 | |
| 	} else if err != nil {
 | |
| 		return nil, c.Error500(err)
 | |
| 	}
 | |
| 
 | |
| 	return model, nil
 | |
| }
 | |
| 
 | |
| func ModelUpdateStatus(c dbtypes.BasePack, id string, status int) {
 | |
| 	_, err := c.GetDb().Exec("update models set status=$1 where id=$2;", status, id)
 | |
| 	if err != nil {
 | |
| 		c.GetLogger().Error("Failed to update model status", "err", err)
 | |
| 		c.GetLogger().Warn("TODO Maybe handle better")
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (c Context) SetReportCaller(report bool) {
 | |
| 	if report {
 | |
| 		c.Logger.SetCallerOffset(2)
 | |
| 		c.Logger.SetReportCaller(true)
 | |
| 	} else {
 | |
| 		c.Logger.SetCallerOffset(1)
 | |
| 		c.Logger.SetReportCaller(false)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (c Context) ErrorCode(err error, code int, data any) *Error {
 | |
| 	if code == 400 {
 | |
| 		c.SetReportCaller(true)
 | |
| 		c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n")
 | |
| 		c.SetReportCaller(false)
 | |
| 	}
 | |
| 	if err != nil {
 | |
| 		c.Logger.Error("Something went wrong returning with:", "Error", err)
 | |
| 	}
 | |
| 	return &Error{code, data}
 | |
| }
 | |
| 
 | |
| // Deprecated: Use the E500M instead
 | |
| func (c Context) Error500(err error) *Error {
 | |
| 	return c.ErrorCode(err, http.StatusInternalServerError, nil)
 | |
| }
 | |
| 
 | |
| func (c Context) E500M(msg string, err error) *Error {
 | |
| 	return c.ErrorCode(err, http.StatusInternalServerError, msg)
 | |
| }
 | |
| 
 | |
// LogoffError is joined onto the error returned by createContext when the
// request token cannot be resolved to a user.
var (
	LogoffError = errors.New("Invalid token!")
)
 | |
| 
 | |
| func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseWriter) (*Context, error) {
 | |
| 
 | |
| 	var token *string
 | |
| 
 | |
| 	logger := log.NewWithOptions(os.Stdout, log.Options{
 | |
| 		ReportCaller:    true,
 | |
| 		ReportTimestamp: true,
 | |
| 		TimeFormat:      time.Kitchen,
 | |
| 		Prefix:          r.URL.Path,
 | |
| 	})
 | |
| 
 | |
| 	t := r.Header.Get("token")
 | |
| 	if t != "" {
 | |
| 		token = &t
 | |
| 	}
 | |
| 
 | |
| 	// TODO check that the token is still valid
 | |
| 	if token == nil {
 | |
| 		return &Context{
 | |
| 			Logger:      logger,
 | |
| 			Db:          handler.Db,
 | |
| 			Writer:      w,
 | |
| 			R:           r,
 | |
| 			ShowMessage: true,
 | |
| 			Handle:      &x,
 | |
| 		}, nil
 | |
| 	}
 | |
| 
 | |
| 	user, err := dbtypes.UserFromToken(x.Db, *token)
 | |
| 	if err != nil {
 | |
| 		return nil, errors.Join(err, LogoffError)
 | |
| 	}
 | |
| 
 | |
| 	return &Context{token, user, logger, handler.Db, w, r, nil, true, &x}, nil
 | |
| }
 | |
| 
 | |
// contextlessLogoff answers with 401 Unauthorized and the JSON string
// "Not Authorized". It needs only the response writer, so it can be used
// before a Context exists.
func contextlessLogoff(w http.ResponseWriter) {
	const body = "\"Not Authorized\""

	w.WriteHeader(http.StatusUnauthorized)
	w.Write([]byte(body))
}
 | |
| 
 | |
| func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) {
 | |
| 	http.HandleFunc("/api"+pathTest, func(w http.ResponseWriter, r *http.Request) {
 | |
| 		r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1)
 | |
| 
 | |
| 		user_path := r.URL.Path[len(pathTest):]
 | |
| 
 | |
| 		found := false
 | |
| 		index := -1
 | |
| 
 | |
| 		for i, fileType := range fileTypes {
 | |
| 			if strings.HasSuffix(user_path, fileType) {
 | |
| 				found = true
 | |
| 				index = i
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if !found {
 | |
| 			w.WriteHeader(http.StatusNotFound)
 | |
| 			w.Write([]byte("File not found"))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path))
 | |
| 		if err != nil {
 | |
| 			fmt.Println(err)
 | |
| 			w.WriteHeader(http.StatusInternalServerError)
 | |
| 			w.Write([]byte("Failed to load file"))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		w.Header().Set("Content-Type", contentTypes[index])
 | |
| 		w.Write(bytes)
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func NewHandler(db db.Db, config Config) *Handle {
 | |
| 	var gets []HandleFunc
 | |
| 	var posts []HandleFunc
 | |
| 	var deletes []HandleFunc
 | |
| 	validate := validator.New()
 | |
| 	x := &Handle{db, gets, posts, deletes, map[string]interface{}{}, config, validate}
 | |
| 
 | |
| 	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
 | |
| 
 | |
| 		w.Header().Add("Access-Control-Allow-Origin", "*")
 | |
| 		w.Header().Add("Access-Control-Allow-Headers", "*")
 | |
| 		w.Header().Add("Access-Control-Allow-Methods", "*")
 | |
| 
 | |
| 		// Decide answertype
 | |
| 		/* if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") {
 | |
| 			w.WriteHeader(500)
 | |
| 			w.Write([]byte("Please set content-type to application/json or set response-type to application/json\n"))
 | |
| 			return
 | |
| 		}*/
 | |
| 
 | |
| 		if !strings.HasPrefix(r.URL.Path, "/api") {
 | |
| 			w.WriteHeader(404)
 | |
| 			w.Write([]byte("Path not found"))
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1)
 | |
| 
 | |
| 		//Login state
 | |
| 		context, err := x.createContext(x, r, w)
 | |
| 		if err != nil {
 | |
| 			contextlessLogoff(w)
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// context.Logger.Info("Parsing", "path", r.URL.Path)
 | |
| 
 | |
| 		if r.Method == "GET" {
 | |
| 			handleLoop(x.gets, context)
 | |
| 		} else if r.Method == "POST" {
 | |
| 			handleLoop(x.posts, context)
 | |
| 		} else if r.Method == "DELETE" {
 | |
| 			handleLoop(x.deletes, context)
 | |
| 		} else if r.Method == "OPTIONS" {
 | |
| 			// do nothing
 | |
| 		} else {
 | |
| 			panic("TODO handle method: " + r.Method)
 | |
| 		}
 | |
| 
 | |
| 		if context.ShowMessage {
 | |
| 			context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
 | |
| 		}
 | |
| 	})
 | |
| 
 | |
| 	return x
 | |
| }
 | |
| 
 | |
| func (x Handle) Startup() {
 | |
| 	log.Info("Starting up!\n")
 | |
| 
 | |
| 	log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", x.Config.Port), nil))
 | |
| }
 |