package utils import ( "bytes" "context" "errors" "fmt" "io" "net/http" "os" "path" "runtime/debug" "strings" "time" "git.andr3h3nriqu3s.com/andr3/fyp/logic/db" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" "github.com/charmbracelet/log" "github.com/go-playground/validator/v10" "github.com/goccy/go-json" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" ) type AnyMap = map[string]interface{} type Error struct { Code int data any } type HandleFunc struct { path string fn func(c *Context) *Error } type Handler interface { New() Startup() Get(fn func(c *Context) *Error) Post(fn func(c *Context) *Error) } type Handle struct { Db db.Db gets []HandleFunc posts []HandleFunc deletes []HandleFunc DataMap map[string]interface{} Config Config validate *validator.Validate } func handleError(err *Error, c *Context) { if err != nil { c.Logger.Warn("Responded with error", "code", err.Code, "data", err.data) c.Writer.WriteHeader(err.Code) var e *Error if err.data != nil { e = c.SendJSON(err.data) } else { e = c.SendJSON(500) } if e != nil { c.Logger.Error("Something went very wrong while trying to send and error message", "stack", string(debug.Stack())) c.Writer.Write([]byte("505")) } } } // This group of functions defines some endpoints func (x *Handle) Get(path string, fn func(c *Context) *Error) { x.gets = append(x.gets, HandleFunc{path, fn}) } func (x *Handle) GetAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(authLevel) { return nil } return fn(c) } x.gets = append(x.gets, HandleFunc{path, inner_fn}) } func (x *Handle) Post(path string, fn func(c *Context) *Error) { x.posts = append(x.posts, HandleFunc{path, fn}) } func (x *Handle) PostAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(authLevel) { return nil } return fn(c) } x.posts = append(x.posts, HandleFunc{path, inner_fn}) } func PostAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, dat *T) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(authLevel) { return nil } obj := new(T) if err := c.ToJSON(obj); err != nil { return err } return fn(c, obj) } x.posts = append(x.posts, HandleFunc{path, inner_fn}) } func PostAuthFormJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, dat *T, file []byte) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(1) { return nil } read_form, err := c.R.MultipartReader() if err != nil { return c.JsonBadRequest("Please provide a valid form data request!") } var json_data string var file []byte for { part, err_part := read_form.NextPart() if err_part == io.EOF { break } else if err_part != nil { return c.JsonBadRequest("Please provide a valid form data request!") } if part.FormName() == "json_data" { buf := new(bytes.Buffer) buf.ReadFrom(part) json_data = buf.String() } if part.FormName() == "file" { buf := new(bytes.Buffer) buf.ReadFrom(part) file = buf.Bytes() } } if !c.CheckAuthLevel(authLevel) { return nil } dat := new(T) decoder := json.NewDecoder(strings.NewReader(json_data)) if err := c.decodeAndValidade(decoder, dat); err != nil { return err } return fn(c, dat, file) } x.posts = append(x.posts, HandleFunc{path, inner_fn}) } func (x *Handle) Delete(path string, fn func(c *Context) *Error) { x.deletes = append(x.deletes, HandleFunc{path, fn}) } func (x *Handle) DeleteAuth(path string, authLevel dbtypes.UserType, fn func(c *Context) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(authLevel) { return nil } return fn(c) } x.deletes = append(x.deletes, HandleFunc{path, inner_fn}) } func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, obj *T) *Error) { inner_fn := func(c *Context) *Error { if !c.CheckAuthLevel(authLevel) { return nil } obj := new(T) if err := c.ToJSON(obj); err != nil { return err } return fn(c, obj) } x.deletes = append(x.deletes, HandleFunc{path, inner_fn}) } // This function handles loop of a list of possible handler functions func handleLoop(array []HandleFunc, context *Context) { defer func() { if r := recover(); r != nil { context.Logger.Error("Something went very wrong", "Error", r, "stack", string(debug.Stack())) handleError(&Error{500, "500"}, context) } }() for _, s := range array { if s.path == context.R.URL.Path { handleError(s.fn(context), context) return } } context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } func (c *Context) CheckAuthLevel(authLevel dbtypes.UserType) bool { if authLevel > 0 { if c.User == nil { contextlessLogoff(c.Writer) return false } if c.User.UserType < int(authLevel) { c.Writer.WriteHeader(http.StatusUnauthorized) e := c.SendJSON("Not Authorized") if e != nil { c.Writer.WriteHeader(http.StatusInternalServerError) c.Writer.Write([]byte("You can not access this resource!")) } return false } } return true } type Context struct { Token *string User *dbtypes.User Logger *log.Logger Db db.Db Writer http.ResponseWriter R *http.Request Tx pgx.Tx ShowMessage bool Handle *Handle } // This is required for this to integrate simealy with my orl func (c Context) GetDb() db.Db { return c.Db } func (c Context) GetLogger() *log.Logger { return c.Logger } func (c Context) GetHost() string { return c.Handle.Config.Hostname } func (c Context) Query(query string, args ...any) (pgx.Rows, error) { if c.Tx == nil { return c.Db.Query(query, args...) } return c.Tx.Query(context.Background(), query, args...) } func (c Context) Exec(query string, args ...any) (pgconn.CommandTag, error) { if c.Tx == nil { return c.Db.Exec(query, args...) } return c.Tx.Exec(context.Background(), query, args...) } func (c Context) Begin() (pgx.Tx, error) { return c.Db.Begin() } var TransactionAlreadyStarted = errors.New("Transaction already started") var TransactionNotStarted = errors.New("Transaction not started") func (c *Context) StartTx() error { if c.Tx != nil { return TransactionAlreadyStarted } var err error = nil c.Tx, err = c.Begin() return err } func (c *Context) CommitTx() error { if c.Tx == nil { return TransactionNotStarted } err := c.Tx.Commit(context.Background()) if err != nil { return err } c.Tx = nil return nil } func (c *Context) RollbackTx() error { if c.Tx == nil { return TransactionNotStarted } err := c.Tx.Rollback(context.Background()) if err != nil { return err } c.Tx = nil return nil } /** * Parse and vailidates the json */ func (c Context) ParseJson(dat any, str string) *Error { decoder := json.NewDecoder(strings.NewReader(str)) return c.decodeAndValidade(decoder, dat) } func (c Context) ToJSON(dat any) *Error { decoder := json.NewDecoder(c.R.Body) return c.decodeAndValidade(decoder, dat) } func (c Context) decodeAndValidade(decoder *json.Decoder, dat any) *Error { err := decoder.Decode(dat) if err != nil { c.Logger.Error("Failed to decode json", "dat", dat, "err", err) return c.JsonBadRequest("Bad Request! Invalid json passed!") } err = c.Handle.validate.Struct(dat) if err != nil { c.Logger.Error("Failed invalid json passed", "dat", dat, "err", err) return c.JsonBadRequest("Bad Request! Invalid json passed!") } return nil } func (c Context) SendJSON(dat any) *Error { c.Writer.Header().Add("content-type", "application/json") text, err := json.Marshal(dat) if err != nil { return c.Error500(err) } if _, err = c.Writer.Write(text); err != nil { return c.Error500(err) } return nil } func (c Context) SendJSONStatus(status int, dat any) *Error { c.Writer.Header().Add("content-type", "application/json") c.Writer.WriteHeader(status) text, err := json.Marshal(dat) if err != nil { return c.Error500(err) } if _, err = c.Writer.Write(text); err != nil { return c.Error500(err) } return nil } func (c Context) JsonBadRequest(dat any) *Error { c.SetReportCaller(true) c.Logger.Warn("Request failed with a bad request", "dat", dat) c.SetReportCaller(false) return c.SendJSONStatus(http.StatusBadRequest, dat) } func (c Context) JsonErrorBadRequest(err error, dat any) *Error { c.SetReportCaller(true) c.Logger.Error("Error while processing request", "err", err, "dat", dat) c.SetReportCaller(false) return c.SendJSONStatus(http.StatusBadRequest, dat) } func (c *Context) GetModelFromId(id_path string) (*dbtypes.BaseModel, *Error) { id, err := GetIdFromUrl(c, id_path) if err != nil { return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found") } model, err := dbtypes.GetBaseModel(c.Db, id) if err == dbtypes.ModelNotFoundError { return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found") } else if err != nil { return nil, c.Error500(err) } return model, nil } func ModelUpdateStatus(c dbtypes.BasePack, id string, status int) { _, err := c.GetDb().Exec("update models set status=$1 where id=$2;", status, id) if err != nil { c.GetLogger().Error("Failed to update model status", "err", err) c.GetLogger().Warn("TODO Maybe handle better") } } func (c Context) SetReportCaller(report bool) { if report { c.Logger.SetCallerOffset(2) c.Logger.SetReportCaller(true) } else { c.Logger.SetCallerOffset(1) c.Logger.SetReportCaller(false) } } func (c Context) ErrorCode(err error, code int, data any) *Error { if code == 400 { c.SetReportCaller(true) c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n") c.SetReportCaller(false) } if err != nil { c.Logger.Error("Something went wrong returning with:", "Error", err) } return &Error{code, data} } // Deprecated: Use the E500M instead func (c Context) Error500(err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, nil) } func (c Context) E500M(msg string, err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, msg) } var LogoffError = errors.New("Invalid token!") func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseWriter) (*Context, error) { var token *string logger := log.NewWithOptions(os.Stdout, log.Options{ ReportCaller: true, ReportTimestamp: true, TimeFormat: time.Kitchen, Prefix: r.URL.Path, }) t := r.Header.Get("token") if t != "" { token = &t } // TODO check that the token is still valid if token == nil { return &Context{ Logger: logger, Db: handler.Db, Writer: w, R: r, ShowMessage: true, Handle: &x, }, nil } user, err := dbtypes.UserFromToken(x.Db, *token) if err != nil { return nil, errors.Join(err, LogoffError) } return &Context{token, user, logger, handler.Db, w, r, nil, true, &x}, nil } func contextlessLogoff(w http.ResponseWriter) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("\"Not Authorized\"")) } func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { http.HandleFunc("/api"+pathTest, func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1) user_path := r.URL.Path[len(pathTest):] found := false index := -1 for i, fileType := range fileTypes { if strings.HasSuffix(user_path, fileType) { found = true index = i break } } if !found { w.WriteHeader(http.StatusNotFound) w.Write([]byte("File not found")) return } bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Failed to load file")) return } w.Header().Set("Content-Type", contentTypes[index]) w.Write(bytes) }) } func NewHandler(db db.Db, config Config) *Handle { var gets []HandleFunc var posts []HandleFunc var deletes []HandleFunc validate := validator.New() x := &Handle{db, gets, posts, deletes, map[string]interface{}{}, config, validate} http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Access-Control-Allow-Origin", "*") w.Header().Add("Access-Control-Allow-Headers", "*") w.Header().Add("Access-Control-Allow-Methods", "*") // Decide answertype /* if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") { w.WriteHeader(500) w.Write([]byte("Please set content-type to application/json or set response-type to application/json\n")) return }*/ if !strings.HasPrefix(r.URL.Path, "/api") { w.WriteHeader(404) w.Write([]byte("Path not found")) return } r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1) //Login state context, err := x.createContext(x, r, w) if err != nil { contextlessLogoff(w) return } // context.Logger.Info("Parsing", "path", r.URL.Path) if r.Method == "GET" { handleLoop(x.gets, context) } else if r.Method == "POST" { handleLoop(x.posts, context) } else if r.Method == "DELETE" { handleLoop(x.deletes, context) } else if r.Method == "OPTIONS" { // do nothing } else { panic("TODO handle method: " + r.Method) } if context.ShowMessage { context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path) } }) return x } func (x Handle) Startup() { log.Info("Starting up!\n") log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", x.Config.Port), nil)) }