package utils import ( "database/sql" "errors" "fmt" "html/template" "io" "net/http" "os" "path" "strings" "time" dbtypes "git.andr3h3nriqu3s.com/andr3/fyp/logic/db_types" . "git.andr3h3nriqu3s.com/andr3/fyp/logic/models/utils" "github.com/charmbracelet/log" "github.com/goccy/go-json" ) type AnyMap = map[string]interface{} type Error struct { Code int data any } type HandleFunc struct { path string fn func(c *Context) *Error } type Handler interface { New() Startup() Get(fn func(c *Context) *Error) Post(fn func(c *Context) *Error) } type Handle struct { Db *sql.DB gets []HandleFunc posts []HandleFunc deletes []HandleFunc } func decodeBody(r *http.Request) (string, *Error) { body, err := io.ReadAll(r.Body) if err == nil { return "", &Error{Code: http.StatusBadRequest} } return string(body[:]), nil } func handleError(err *Error, c *Context) { if err != nil { c.Logger.Warn("Responded with error", "code", err.Code, "data", err.data) c.Writer.WriteHeader(err.Code) var e *Error if err.data != nil { e = c.SendJSON(err.data) } else { e = c.SendJSON(500) } if e != nil { c.Logger.Error("Something went very wront while trying to send and error message") c.Writer.Write([]byte("505")) } } } func (x *Handle) Get(path string, fn func(c *Context) *Error) { x.gets = append(x.gets, HandleFunc{path, fn}) } func (x *Handle) Post(path string, fn func(c *Context) *Error) { x.posts = append(x.posts, HandleFunc{path, fn}) } func (x *Handle) Delete(path string, fn func(c *Context) *Error) { x.deletes = append(x.deletes, HandleFunc{path, fn}) } func (x *Handle) handleGets(context *Context) { for _, s := range x.gets { if s.path == context.R.URL.Path { handleError(s.fn(context), context) return } } context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } func (x *Handle) handlePosts(context *Context) { for _, s := range x.posts { if s.path == context.R.URL.Path { handleError(s.fn(context), context) return } } context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } func (x *Handle) handleDeletes(context *Context) { for _, s := range x.deletes { if s.path == context.R.URL.Path { handleError(s.fn(context), context) return } } context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } func (c *Context) CheckAuthLevel(authLevel int) bool { if authLevel > 0 { if c.requireAuth() { c.Logoff() return false } if c.User.UserType < authLevel { c.NotAuth() return false } } return true } type Context struct { Token *string User *dbtypes.User Logger *log.Logger Db *sql.DB Writer http.ResponseWriter R *http.Request Tx *sql.Tx ShowMessage bool } func (c Context) Prepare(str string) (*sql.Stmt, error) { if c.Tx == nil { return c.Db.Prepare(str) } return c.Tx.Prepare(str) } var TransactionAlreadyStarted = errors.New("Transaction already started") var TransactionNotStarted = errors.New("Transaction not started") func (c *Context) StartTx() error { if c.Tx != nil { return TransactionAlreadyStarted } var err error = nil c.Tx, err = c.Db.Begin() return err } func (c *Context) CommitTx() error { if c.Tx == nil { return TransactionNotStarted } err := c.Tx.Commit() if err != nil { return err } c.Tx = nil return nil } func (c *Context) RollbackTx() error { if c.Tx == nil { return TransactionNotStarted } err := c.Tx.Rollback() if err != nil { return err } c.Tx = nil return nil } func (c Context) ToJSON(dat any) *Error { decoder := json.NewDecoder(c.R.Body) err := decoder.Decode(dat) if err != nil { return c.Error500(err) } return nil } func (c Context) SendJSON(dat any) *Error { c.Writer.Header().Add("content-type", "application/json") text, err := json.Marshal(dat) if err != nil { return c.Error500(err) } if _, err = c.Writer.Write(text); err != nil { return c.Error500(err) } return nil } func (c Context) SendJSONStatus(status int, dat any) *Error { c.Writer.Header().Add("content-type", "application/json") c.Writer.WriteHeader(status) text, err := json.Marshal(dat) if err != nil { return c.Error500(err) } if _, err = c.Writer.Write(text); err != nil { return c.Error500(err) } return nil } func (c Context) JsonBadRequest(dat any) *Error { c.SetReportCaller(true) c.Logger.Warn("Request failed with a bad request", "dat", dat) c.SetReportCaller(false) return c.SendJSONStatus(http.StatusBadRequest, dat) } func (c Context) JsonErrorBadRequest(err error, dat any) *Error { c.SetReportCaller(true) c.Logger.Error("Error while processing request", "err", err, "dat", dat) c.SetReportCaller(false) return c.SendJSONStatus(http.StatusBadRequest, dat) } func (c *Context) GetModelFromId(id_path string) (*BaseModel, *Error) { id, err := GetIdFromUrl(c, id_path) if err != nil { return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found") } model, err := GetBaseModel(c.Db, id) if err == ModelNotFoundError { return nil, c.SendJSONStatus(http.StatusNotFound, "Model not found") } else if err != nil { return nil, c.Error500(err) } return model, nil } func ModelUpdateStatus(c *Context, id string, status int) { _, err := c.Db.Exec("update models set status=$1 where id=$2;", status, id) if err != nil { c.Logger.Error("Failed to update model status", "err", err) c.Logger.Warn("TODO Maybe handle better") } } func (c Context) SetReportCaller(report bool) { if report { c.Logger.SetCallerOffset(2) c.Logger.SetReportCaller(true) } else { c.Logger.SetCallerOffset(1) c.Logger.SetReportCaller(false) } } func (c Context) ErrorCode(err error, code int, data any) *Error { if code == 400 { c.SetReportCaller(true) c.Logger.Warn("When returning BadRequest(400) please use context.Error400\n") c.SetReportCaller(false) } if err != nil { c.Logger.Errorf("Something went wrong returning with: %d\n.Err:\n", code) c.Logger.Error(err) } return &Error{code, data} } func (c Context) Error500(err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, nil) } func (c *Context) requireAuth() bool { if c.User == nil { return true } return false } var LogoffError = errors.New("Invalid token!") func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseWriter) (*Context, error) { var token *string logger := log.NewWithOptions(os.Stdout, log.Options{ ReportCaller: true, ReportTimestamp: true, TimeFormat: time.Kitchen, Prefix: r.URL.Path, }) t := r.Header.Get("token") if t != "" { token = &t } // TODO check that the token is still valid if token == nil { return &Context{ Logger: logger, Db: handler.Db, Writer: w, R: r, ShowMessage: true, }, nil } user, err := dbtypes.UserFromToken(x.Db, *token) if err != nil { return nil, errors.Join(err, LogoffError) } return &Context{token, user, logger, handler.Db, w, r, nil, true}, nil } func contextlessLogoff(w http.ResponseWriter) { w.WriteHeader(http.StatusUnauthorized) w.Write([]byte("\"Not Authorized\"")) } func (c *Context) Logoff() { contextlessLogoff(c.Writer) } func (c *Context) NotAuth() { c.Writer.WriteHeader(http.StatusUnauthorized) e := c.SendJSON("Not Authorized") if e != nil { c.Writer.WriteHeader(http.StatusInternalServerError) c.Writer.Write([]byte("You can not access this resource!")) } } func (x Handle) StaticFiles(pathTest string, fileType string, contentType string) { http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { path := r.URL.Path[len(pathTest):] if !strings.HasSuffix(path, fileType) { w.WriteHeader(http.StatusNotFound) w.Write([]byte("File not found")) return } t, err := template.ParseFiles("./views" + pathTest + path) if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Failed to load template")) return } w.Header().Set("Content-Type", contentType+"; charset=utf-8") t.Execute(w, nil) }) } func (x Handle) ReadFiles(pathTest string, baseFilePath string, fileType string, contentType string) { http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { user_path := r.URL.Path[len(pathTest):] // fmt.Printf("Requested path: %s\n", user_path) if !strings.HasSuffix(user_path, fileType) { w.WriteHeader(http.StatusNotFound) w.Write([]byte("File not found")) return } bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Failed to load file")) return } w.Header().Set("Content-Type", contentType) w.Write(bytes) }) } // TODO remove this func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { user_path := r.URL.Path[len(pathTest):] // fmt.Printf("Requested path: %s\n", user_path) found := false index := -1 for i, fileType := range fileTypes { if strings.HasSuffix(user_path, fileType) { found = true index = i break } } if !found { w.WriteHeader(http.StatusNotFound) w.Write([]byte("File not found")) return } bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Failed to load file")) return } w.Header().Set("Content-Type", contentTypes[index]) w.Write(bytes) }) } func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { http.HandleFunc("/api"+pathTest, func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1) user_path := r.URL.Path[len(pathTest):] found := false index := -1 for i, fileType := range fileTypes { if strings.HasSuffix(user_path, fileType) { found = true index = i break } } if !found { w.WriteHeader(http.StatusNotFound) w.Write([]byte("File not found")) return } bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) if err != nil { fmt.Println(err) w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("Failed to load file")) return } w.Header().Set("Content-Type", contentTypes[index]) w.Write(bytes) }) } func NewHandler(db *sql.DB) *Handle { var gets []HandleFunc var posts []HandleFunc var deletes []HandleFunc x := &Handle{db, gets, posts, deletes} http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { w.Header().Add("Access-Control-Allow-Origin", "*") w.Header().Add("Access-Control-Allow-Headers", "*") w.Header().Add("Access-Control-Allow-Methods", "*") // Decide answertype if !(r.Header.Get("content-type") == "application/json" || r.Header.Get("response-type") == "application/json") { w.WriteHeader(500) w.Write([]byte("Please set content-type to application/json or set response-type to application/json\n")) return } if !strings.HasPrefix(r.URL.Path, "/api") { w.WriteHeader(404) w.Write([]byte("Path not found")) return } r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1) //Login state context, err := x.createContext(x, r, w) if err != nil { contextlessLogoff(w) return } // context.Logger.Info("Parsing", "path", r.URL.Path) if r.Method == "GET" { x.handleGets(context) } else if r.Method == "POST" { x.handlePosts(context) } else if r.Method == "DELETE" { x.handleDeletes(context) } else if r.Method == "OPTIONS" { // do nothing } else { panic("TODO handle method: " + r.Method) } if context.ShowMessage { context.Logger.Info("Processed", "method", r.Method, "url", r.URL.Path) } }) return x } func (x Handle) Startup() { log.Info("Starting up!\n") port := os.Getenv("PORT") if port == "" { port = "8000" } log.Fatal(http.ListenAndServe(fmt.Sprintf(":%s", port), nil)) }