651 lines
15 KiB
Go
651 lines
15 KiB
Go
package utils
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"html/template"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"path"
|
|
"strings"
|
|
"time"
|
|
|
|
dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types"
|
|
. "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils"
|
|
"github.com/charmbracelet/log"
|
|
"github.com/go-playground/validator/v10"
|
|
"github.com/goccy/go-json"
|
|
)
|
|
|
|
// AnyMap is shorthand for a JSON-style object whose values may be of any type.
type AnyMap = map[string]any
|
|
|
|
// Error is what route handlers return to signal a failed request.
// handleError sends Code as the HTTP status and JSON-encodes data as the
// response body.
//
// Field order matters: values are also built with positional literals such
// as &Error{500, "500"}.
type Error struct {
	// Code is the HTTP status code to respond with.
	Code int

	// data is the response payload; when it is nil, handleError sends the
	// JSON number 500 as the body instead.
	data any
}
|
|
|
|
// HandleFunc pairs an exact request path with the handler registered for it.
// The path is compared against the request path after NewHandler has
// stripped the leading "/api".
type HandleFunc struct {
	// path is matched with == against context.R.URL.Path.
	path string

	// fn handles the request; a non-nil *Error it returns is sent to the
	// client by handleError.
	fn func(c *Context) *Error
}
|
|
|
|
// Handler describes a router with lifecycle and route-registration methods.
//
// NOTE(review): *Handle as defined in this file does not satisfy this
// interface — its Get/Post also take a path argument, and no New method is
// visible here. Confirm whether this interface is still used anywhere.
type Handler interface {
	New()

	Startup()

	Get(fn func(c *Context) *Error)

	Post(fn func(c *Context) *Error)
}
|
|
|
|
// Handle is the application router. It holds the database handle, the
// configuration, the JSON payload validator and the routes registered for
// each HTTP method.
//
// Field order matters: NewHandler builds it with a positional literal.
type Handle struct {
	// Db is the database handle shared by every request.
	Db *sql.DB

	// gets, posts and deletes hold the routes registered for GET, POST and
	// DELETE. Each list is scanned in registration order and the first
	// route whose path equals the request path handles it.
	gets    []HandleFunc
	posts   []HandleFunc
	deletes []HandleFunc

	// Config holds application settings; Startup reads Config.Port.
	Config Config

	// validate checks decoded JSON payloads (used by decodeAndValidade).
	validate *validator.Validate
}
|
|
|
|
// decodeBody reads the whole request body and returns it as a string.
//
// If the body cannot be read, it returns an *Error with status
// 400 Bad Request and an empty string.
func decodeBody(r *http.Request) (string, *Error) {
	body, err := io.ReadAll(r.Body)
	if err != nil {
		// Reading failed (e.g. the client went away mid-upload); the partial
		// body is not trustworthy, so reject the request.
		return "", &Error{Code: http.StatusBadRequest}
	}

	return string(body), nil
}
|
|
|
|
func handleError(err *Error, c *Context) {
|
|
if err != nil {
|
|
c.Logger.Warn("Responded with error", "code", err.Code, "data", err.data)
|
|
c.Writer.WriteHeader(err.Code)
|
|
var e *Error
|
|
if err.data != nil {
|
|
e = c.SendJSON(err.data)
|
|
} else {
|
|
e = c.SendJSON(500)
|
|
}
|
|
if e != nil {
|
|
c.Logger.Error("Something went very wrong while trying to send and error message")
|
|
c.Writer.Write([]byte("505"))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (x *Handle) Get(path string, fn func(c *Context) *Error) {
|
|
x.gets = append(x.gets, HandleFunc{path, fn})
|
|
}
|
|
|
|
func (x *Handle) Post(path string, fn func(c *Context) *Error) {
|
|
x.posts = append(x.posts, HandleFunc{path, fn})
|
|
}
|
|
|
|
func (x *Handle) PostAuth(path string, authLevel int, fn func(c *Context) *Error) {
|
|
inner_fn := func(c *Context) *Error {
|
|
if !c.CheckAuthLevel(authLevel) {
|
|
return nil
|
|
}
|
|
return fn(c)
|
|
}
|
|
x.posts = append(x.posts, HandleFunc{path, inner_fn})
|
|
}
|
|
|
|
func PostAuthJson[T interface{}](x *Handle, path string, authLevel int, fn func(c *Context, obj *T) *Error) {
|
|
inner_fn := func(c *Context) *Error {
|
|
if !c.CheckAuthLevel(authLevel) {
|
|
return nil
|
|
}
|
|
|
|
obj := new(T)
|
|
|
|
if err := c.ToJSON(obj); err != nil {
|
|
return err
|
|
}
|
|
|
|
return fn(c, obj)
|
|
}
|
|
|
|
x.posts = append(x.posts, HandleFunc{path, inner_fn})
|
|
}
|
|
|
|
func (x *Handle) Delete(path string, fn func(c *Context) *Error) {
|
|
x.deletes = append(x.deletes, HandleFunc{path, fn})
|
|
}
|
|
|
|
func (x *Handle) DeleteAuth(path string, authLevel int, fn func(c *Context) *Error) {
|
|
inner_fn := func(c *Context) *Error {
|
|
if !c.CheckAuthLevel(authLevel) {
|
|
return nil
|
|
}
|
|
return fn(c)
|
|
}
|
|
x.posts = append(x.posts, HandleFunc{path, inner_fn})
|
|
}
|
|
|
|
func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel int, fn func(c *Context, obj *T) *Error) {
|
|
inner_fn := func(c *Context) *Error {
|
|
if !c.CheckAuthLevel(authLevel) {
|
|
return nil
|
|
}
|
|
|
|
obj := new(T)
|
|
|
|
if err := c.ToJSON(obj); err != nil {
|
|
return err
|
|
}
|
|
|
|
return fn(c, obj)
|
|
}
|
|
|
|
x.deletes = append(x.deletes, HandleFunc{path, inner_fn})
|
|
}
|
|
|
|
func (x *Handle) handleGets(context *Context) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
context.Logger.Error("Something went very wrong", "Error", r)
|
|
handleError(&Error{500, "500"}, context)
|
|
}
|
|
}()
|
|
|
|
for _, s := range x.gets {
|
|
if s.path == context.R.URL.Path {
|
|
handleError(s.fn(context), context)
|
|
return
|
|
}
|
|
}
|
|
context.ShowMessage = false
|
|
handleError(&Error{404, "Endpoint not found"}, context)
|
|
}
|
|
|
|
func (x *Handle) handlePosts(context *Context) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
context.Logger.Error("Something went very wrong", "Error", r)
|
|
handleError(&Error{500, "500"}, context)
|
|
}
|
|
}()
|
|
|
|
for _, s := range x.posts {
|
|
if s.path == context.R.URL.Path {
|
|
handleError(s.fn(context), context)
|
|
return
|
|
}
|
|
}
|
|
context.ShowMessage = false
|
|
handleError(&Error{404, "Endpoint not found"}, context)
|
|
}
|
|
|
|
func (x *Handle) handleDeletes(context *Context) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
context.Logger.Error("Something went very wrong", "Error", r)
|
|
handleError(&Error{500, "500"}, context)
|
|
}
|
|
}()
|
|
|
|
for _, s := range x.deletes {
|
|
if s.path == context.R.URL.Path {
|
|
handleError(s.fn(context), context)
|
|
return
|
|
}
|
|
}
|
|
context.ShowMessage = false
|
|
handleError(&Error{404, "Endpoint not found"}, context)
|
|
}
|
|
|
|
func (c *Context) CheckAuthLevel(authLevel int) bool {
|
|
if authLevel > 0 {
|
|
if c.requireAuth() {
|
|
c.Logoff()
|
|
return false
|
|
}
|
|
if c.User.UserType < authLevel {
|
|
c.NotAuth()
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// Context carries the per-request state handed to every route handler.
//
// Field order matters: createContext builds it with a positional literal.
type Context struct {
	// Token is the raw "token" request header; nil for anonymous requests.
	Token *string

	// User is the user resolved from Token; nil for anonymous requests.
	User *dbtypes.User

	// Logger is a logger created for this request, prefixed with its path.
	Logger *log.Logger

	// Db is the shared database handle.
	Db *sql.DB

	// Writer is the response writer for this request.
	Writer http.ResponseWriter

	// R is the incoming request; NewHandler strips "/api" from R.URL.Path.
	R *http.Request

	// Tx is the open transaction, managed by StartTx/CommitTx/RollbackTx;
	// Prepare uses it instead of Db while it is non-nil.
	Tx *sql.Tx

	// ShowMessage controls whether NewHandler logs a "Processed" line after
	// the request; it is turned off when no route matches.
	ShowMessage bool

	// Handle gives access to router-wide state such as the validator.
	Handle *Handle
}
|
|
|
|
// GetDb returns the database handle attached to this request context.
func (c Context) GetDb() *sql.DB {
	db := c.Db
	return db
}
|
|
|
|
func (c Context) GetLogger() *log.Logger {
|
|
return c.Logger
|
|
}
|
|
|
|
// Query runs query with args and returns the resulting rows.
//
// Like Prepare, it runs inside the transaction opened by StartTx when one is
// active, so reads see that transaction's uncommitted writes. Otherwise it
// runs directly on the database.
func (c Context) Query(query string, args ...any) (*sql.Rows, error) {
	if c.Tx != nil {
		return c.Tx.Query(query, args...)
	}
	return c.Db.Query(query, args...)
}
|
|
|
|
// Prepare creates a prepared statement for str. It targets the open
// transaction when StartTx has set one, and the database otherwise.
func (c Context) Prepare(str string) (*sql.Stmt, error) {
	if c.Tx != nil {
		return c.Tx.Prepare(str)
	}
	return c.Db.Prepare(str)
}
|
|
|
|
var TransactionAlreadyStarted = errors.New("Transaction already started")
|
|
var TransactionNotStarted = errors.New("Transaction not started")
|
|
|
|
func (c *Context) StartTx() error {
|
|
if c.Tx != nil {
|
|
return TransactionAlreadyStarted
|
|
}
|
|
var err error = nil
|
|
c.Tx, err = c.Db.Begin()
|
|
return err
|
|
}
|
|
|
|
func (c *Context) CommitTx() error {
|
|
if c.Tx == nil {
|
|
return TransactionNotStarted
|
|
}
|
|
err := c.Tx.Commit()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.Tx = nil
|
|
return nil
|
|
}
|
|
|
|
func (c *Context) RollbackTx() error {
|
|
if c.Tx == nil {
|
|
return TransactionNotStarted
|
|
}
|
|
err := c.Tx.Rollback()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.Tx = nil
|
|
return nil
|
|
}
|
|
|
|
// ParseJson decodes the JSON document in str into dat and then validates dat
// with the router's validator. On either failure it returns the *Error
// produced by JsonBadRequest; on success it returns nil.
func (c Context) ParseJson(dat any, str string) *Error {
	return c.decodeAndValidade(json.NewDecoder(strings.NewReader(str)), dat)
}
|
|
|
|
func (c Context) ToJSON(dat any) *Error {
|
|
decoder := json.NewDecoder(c.R.Body)
|
|
|
|
return c.decodeAndValidade(decoder, dat)
|
|
}
|
|
|
|
func (c Context) decodeAndValidade(decoder *json.Decoder, dat any) *Error {
|
|
err := decoder.Decode(dat)
|
|
if err != nil {
|
|
c.Logger.Error("Failed to decode json", "dat", dat, "err", err)
|
|
return c.JsonBadRequest("Bad Request! Invalid json passed!")
|
|
}
|
|
|
|
err = c.Handle.validate.Struct(dat)
|
|
if err != nil {
|
|
c.Logger.Error("Failed invalid json passed", "dat", dat, "err", err)
|
|
return c.JsonBadRequest("Bad Request! Invalid json passed!")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c Context) SendJSON(dat any) *Error {
|
|
c.Writer.Header().Add("content-type", "application/json")
|
|
text, err := json.Marshal(dat)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
if _, err = c.Writer.Write(text); err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c Context) SendJSONStatus(status int, dat any) *Error {
|
|
c.Writer.Header().Add("content-type", "application/json")
|
|
c.Writer.WriteHeader(status)
|
|
text, err := json.Marshal(dat)
|
|
if err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
if _, err = c.Writer.Write(text); err != nil {
|
|
return c.Error500(err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// JsonBadRequest logs dat at warn level, reporting the caller's location. It
// returns an *Error that handleError sends as 400 Bad Request with dat as the
// JSON body.
func (c Context) JsonBadRequest(dat any) *Error {
	c.SetReportCaller(true)
	c.Logger.Warn("Request failed with a bad request", "dat", dat)
	c.SetReportCaller(false)

	// Built directly: ErrorCode logs an extra warning for every 400.
	return &Error{Code: http.StatusBadRequest, data: dat}
}
|
|
|
|
func (c Context) JsonErrorBadRequest(err error, dat any) *Error {
|
|
c.SetReportCaller(true)
|
|
c.Logger.Error("Error while processing request", "err", err, "dat", dat)
|
|
c.SetReportCaller(false)
|
|
return c.SendJSONStatus(http.StatusBadRequest, dat)
|
|
}
|
|
|
|
// GetModelFromId reads a model id from the request (via GetIdFromUrl with
// id_path) and loads the model with GetBaseModel.
//
// On failure it returns a nil model and a non-nil *Error for the caller to
// return; nothing is written to the response here:
//   - 404 "Model not found" when the id cannot be read or no such model exists;
//   - 500 for any other database error.
func (c *Context) GetModelFromId(id_path string) (*BaseModel, *Error) {
	id, err := GetIdFromUrl(c, id_path)
	if err != nil {
		return nil, c.ErrorCode(nil, http.StatusNotFound, "Model not found")
	}

	model, err := GetBaseModel(c.Db, id)
	switch {
	case errors.Is(err, ModelNotFoundError):
		return nil, c.ErrorCode(nil, http.StatusNotFound, "Model not found")
	case err != nil:
		return nil, c.Error500(err)
	}

	return model, nil
}
|
|
|
|
func ModelUpdateStatus(c *Context, id string, status int) {
|
|
_, err := c.Db.Exec("update models set status=$1 where id=$2;", status, id)
|
|
if err != nil {
|
|
c.Logger.Error("Failed to update model status", "err", err)
|
|
c.Logger.Warn("TODO Maybe handle better")
|
|
}
|
|
}
|
|
|
|
func (c Context) SetReportCaller(report bool) {
|
|
if report {
|
|
c.Logger.SetCallerOffset(2)
|
|
c.Logger.SetReportCaller(true)
|
|
} else {
|
|
c.Logger.SetCallerOffset(1)
|
|
c.Logger.SetReportCaller(false)
|
|
}
|
|
}
|
|
|
|
func (c Context) ErrorCode(err error, code int, data any) *Error {
|
|
if code == 400 {
|
|
c.SetReportCaller(true)
|
|
c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n")
|
|
c.SetReportCaller(false)
|
|
}
|
|
if err != nil {
|
|
c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code)
|
|
c.Logger.Error(err)
|
|
}
|
|
return &Error{code, data}
|
|
}
|
|
|
|
func (c Context) Error500(err error) *Error {
|
|
return c.ErrorCode(err, http.StatusInternalServerError, nil)
|
|
}
|
|
|
|
func (c Context) E500M(msg string, err error) *Error {
|
|
return c.ErrorCode(err, http.StatusInternalServerError, msg)
|
|
}
|
|
|
|
func (c *Context) requireAuth() bool {
|
|
if c.User == nil {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// LogoffError is joined onto the lookup error when a request's token cannot
// be resolved to a user.
var LogoffError error = errors.New("Invalid token!")
|
|
|
|
func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseWriter) (*Context, error) {
|
|
|
|
var token *string
|
|
|
|
logger := log.NewWithOptions(os.Stdout, log.Options{
|
|
ReportCaller: true,
|
|
ReportTimestamp: true,
|
|
TimeFormat: time.Kitchen,
|
|
Prefix: r.URL.Path,
|
|
})
|
|
|
|
t := r.Header.Get("token")
|
|
if t != "" {
|
|
token = &t
|
|
}
|
|
|
|
// TODO check that the token is still valid
|
|
if token == nil {
|
|
return &Context{
|
|
Logger: logger,
|
|
Db: handler.Db,
|
|
Writer: w,
|
|
R: r,
|
|
ShowMessage: true,
|
|
Handle: &x,
|
|
}, nil
|
|
}
|
|
|
|
user, err := dbtypes.UserFromToken(x.Db, *token)
|
|
if err != nil {
|
|
return nil, errors.Join(err, LogoffError)
|
|
}
|
|
|
|
return &Context{token, user, logger, handler.Db, w, r, nil, true, &x}, nil
|
|
}
|
|
|
|
func contextlessLogoff(w http.ResponseWriter) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
w.Write([]byte("\"Not Authorized\""))
|
|
}
|
|
|
|
// Logoff rejects the request as unauthenticated by writing a 401 with the
// JSON string "Not Authorized" (see contextlessLogoff).
func (c *Context) Logoff() {
	w := c.Writer
	contextlessLogoff(w)
}
|
|
|
|
// NotAuth rejects the request with 401 Unauthorized and the JSON body
// "Not Authorized". CheckAuthLevel uses it when the user's level is too low.
func (c *Context) NotAuth() {
	// SendJSONStatus sets Content-Type before the status line, so the header
	// actually reaches the client.
	if e := c.SendJSONStatus(http.StatusUnauthorized, "Not Authorized"); e != nil {
		// The status line may already be sent and cannot be changed; the best
		// we can still do is a plain-text explanation.
		c.Writer.Write([]byte("You can not access this resource!"))
	}
}
|
|
|
|
// StaticFiles registers a handler on http.DefaultServeMux for pathTest. It
// serves files from ./views<pathTest> whose names end with fileType.
//
// Each file is parsed as an html/template and executed with nil data.
// Responses carry Content-Type "<contentType>; charset=utf-8". A wrong
// extension answers 404; a parse failure answers 500.
func (x Handle) StaticFiles(pathTest string, fileType string, contentType string) {
	http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) {
		// Named filePath so the "path" package stays usable in this closure.
		filePath := r.URL.Path[len(pathTest):]

		if !strings.HasSuffix(filePath, fileType) {
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte("File not found"))
			return
		}

		// Rooting the user-controlled part before cleaning keeps ".."
		// segments from climbing out of the views directory.
		t, err := template.ParseFiles(path.Join("./views", pathTest, path.Clean("/"+filePath)))
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte("Failed to load template"))
			return
		}

		w.Header().Set("Content-Type", contentType+"; charset=utf-8")
		if err := t.Execute(w, nil); err != nil {
			// Output may already be partially written, so only log.
			log.Error("Failed to execute template", "path", filePath, "err", err)
		}
	})
}
|
|
|
|
// typedFileHandler returns a handler that serves raw files from
// baseFilePath/pathTest. The handler:
//   - takes the part of r.URL.Path after pathTest;
//   - requires it to end with one of fileTypes; the first match selects the
//     Content-Type at the same index in contentTypes;
//   - answers 404 for unknown extensions and missing files, and 500 for any
//     other read failure.
func typedFileHandler(pathTest, baseFilePath string, fileTypes, contentTypes []string) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		userPath := r.URL.Path[len(pathTest):]

		index := -1
		for i, fileType := range fileTypes {
			if strings.HasSuffix(userPath, fileType) {
				index = i
				break
			}
		}

		if index < 0 {
			w.WriteHeader(http.StatusNotFound)
			w.Write([]byte("File not found"))
			return
		}

		// Root and clean the user-controlled part so ".." cannot climb out
		// of baseFilePath/pathTest.
		data, err := os.ReadFile(path.Join(baseFilePath, pathTest, path.Clean("/"+userPath)))
		if err != nil {
			fmt.Println(err)
			if errors.Is(err, os.ErrNotExist) {
				w.WriteHeader(http.StatusNotFound)
				w.Write([]byte("File not found"))
				return
			}
			w.WriteHeader(http.StatusInternalServerError)
			w.Write([]byte("Failed to load file"))
			return
		}

		w.Header().Set("Content-Type", contentTypes[index])
		w.Write(data)
	}
}

// ReadFiles serves files ending in fileType from baseFilePath/pathTest under
// the URL prefix pathTest, with the given Content-Type.
func (x Handle) ReadFiles(pathTest string, baseFilePath string, fileType string, contentType string) {
	http.HandleFunc(pathTest, typedFileHandler(pathTest, baseFilePath, []string{fileType}, []string{contentType}))
}

// TODO remove this
// ReadTypesFiles serves files from baseFilePath/pathTest under the URL
// prefix pathTest. fileTypes[i] is served with contentTypes[i].
func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) {
	http.HandleFunc(pathTest, typedFileHandler(pathTest, baseFilePath, fileTypes, contentTypes))
}

// ReadTypesFilesApi is ReadTypesFiles mounted under "/api"+pathTest. The
// "/api" prefix is stripped from r.URL.Path before the file is resolved.
func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) {
	serve := typedFileHandler(pathTest, baseFilePath, fileTypes, contentTypes)

	http.HandleFunc("/api"+pathTest, func(w http.ResponseWriter, r *http.Request) {
		// The mux pattern guarantees the "/api" prefix is present.
		r.URL.Path = strings.TrimPrefix(r.URL.Path, "/api")
		serve(w, r)
	})
}
|
|
|
|
func NewHandler(db *sql.DB, config Config) *Handle {
|
|
|
|
var gets []HandleFunc
|
|
var posts []HandleFunc
|
|
var deletes []HandleFunc
|
|
validate := validator.New()
|
|
x := &Handle{db, gets, posts, deletes, config, validate}
|
|
|
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
w.Header().Add("Access-Control-Allow-Origin", "*")
|
|
w.Header().Add("Access-Control-Allow-Headers", "*")
|
|
w.Header().Add("Access-Control-Allow-Methods", "*")
|
|
|
|
// Decide answertype
|
|
/* if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") {
|
|
w.WriteHeader(500)
|
|
w.Write([]byte("Please set content-type to application/json or set response-type to application/json\n"))
|
|
return
|
|
}*/
|
|
|
|
if !strings.HasPrefix(r.URL.Path, "/api") {
|
|
w.WriteHeader(404)
|
|
w.Write([]byte("Path not found"))
|
|
return
|
|
}
|
|
|
|
r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1)
|
|
|
|
//Login state
|
|
context, err := x.createContext(x, r, w)
|
|
if err != nil {
|
|
contextlessLogoff(w)
|
|
return
|
|
}
|
|
|
|
// context.Logger.Info("Parsing", "path", r.URL.Path)
|
|
|
|
if r.Method == "GET" {
|
|
x.handleGets(context)
|
|
} else if r.Method == "POST" {
|
|
x.handlePosts(context)
|
|
} else if r.Method == "DELETE" {
|
|
x.handleDeletes(context)
|
|
} else if r.Method == "OPTIONS" {
|
|
// do nothing
|
|
} else {
|
|
panic("TODO handle method: " + r.Method)
|
|
}
|
|
|
|
if context.ShowMessage {
|
|
context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path)
|
|
}
|
|
})
|
|
|
|
return x
|
|
}
|
|
|
|
// Startup serves http.DefaultServeMux on the configured port and blocks.
// If the listener fails, the error is passed to log.Fatal.
func (x Handle) Startup() {
	addr := fmt.Sprintf(":%d", x.Config.Port)

	log.Info("Starting up!\n")

	log.Fatal(http.ListenAndServe(addr, nil))
}
|