From fbf7eb92716cd30be5f0bf367fad33647b77f5c4 Mon Sep 17 00:00:00 2001 From: Andre Henriques Date: Sat, 13 Apr 2024 23:55:01 +0100 Subject: [PATCH] chore: did some clean up --- logic/db_types/user.go | 15 +- logic/utils/handler.go | 189 ++++------------------- main.go | 4 - users.go | 147 ++++++------------ webpage/src/routes/register/+page.svelte | 138 ++++++++--------- 5 files changed, 148 insertions(+), 345 deletions(-) diff --git a/logic/db_types/user.go b/logic/db_types/user.go index f29e62a..08838b4 100644 --- a/logic/db_types/user.go +++ b/logic/db_types/user.go @@ -7,15 +7,16 @@ import ( type UserType int -const ( - User_Normal UserType = iota + 1 - User_Admin +const ( + User_Not_Auth UserType = iota + User_Normal + User_Admin ) type User struct { - Id string - Username string - Email string + Id string + Username string + Email string UserType int } @@ -26,7 +27,7 @@ func UserFromToken(db *sql.DB, token string) (*User, error) { if err != nil { return nil, err } - defer rows.Close() + defer rows.Close() var id string var username string diff --git a/logic/utils/handler.go b/logic/utils/handler.go index 5cfe090..09b1c69 100644 --- a/logic/utils/handler.go +++ b/logic/utils/handler.go @@ -4,8 +4,6 @@ import ( "database/sql" "errors" "fmt" - "html/template" - "io" "net/http" "os" "path" @@ -47,15 +45,6 @@ type Handle struct { validate *validator.Validate } -func decodeBody(r *http.Request) (string, *Error) { - body, err := io.ReadAll(r.Body) - if err == nil { - return "", &Error{Code: http.StatusBadRequest} - } - - return string(body[:]), nil -} - func handleError(err *Error, c *Context) { if err != nil { c.Logger.Warn("Responded with error", "code", err.Code, "data", err.data) @@ -73,10 +62,20 @@ func handleError(err *Error, c *Context) { } } +// This group of functions defines some endpoints func (x *Handle) Get(path string, fn func(c *Context) *Error) { x.gets = append(x.gets, HandleFunc{path, fn}) } +func (x *Handle) GetAuth(path string, authLevel int, fn func(c *Context) *Error) { + inner_fn := func(c *Context) *Error { + if !c.CheckAuthLevel(authLevel) { + return nil + } + return fn(c) + } + x.gets = append(x.gets, HandleFunc{path, inner_fn}) +} func (x *Handle) Post(path string, fn func(c *Context) *Error) { x.posts = append(x.posts, HandleFunc{path, fn}) } @@ -91,9 +90,9 @@ func (x *Handle) PostAuth(path string, authLevel int, fn func(c *Context) *Error x.posts = append(x.posts, HandleFunc{path, inner_fn}) } -func PostAuthJson[T interface{}](x *Handle, path string, authLevel int, fn func(c *Context, obj *T) *Error) { +func PostAuthJson[T interface{}](x *Handle, path string, authLevel dbtypes.UserType, fn func(c *Context, obj *T) *Error) { inner_fn := func(c *Context) *Error { - if !c.CheckAuthLevel(authLevel) { + if !c.CheckAuthLevel(int(authLevel)) { return nil } @@ -141,7 +140,8 @@ func DeleteAuthJson[T interface{}](x *Handle, path string, authLevel int, fn fun x.deletes = append(x.deletes, HandleFunc{path, inner_fn}) } -func (x *Handle) handleGets(context *Context) { +// This function handles loop of a list of possible handler functions +func handleLoop(array []HandleFunc, context *Context) { defer func() { if r := recover(); r != nil { context.Logger.Error("Something went very wrong", "Error", r) @@ -149,60 +149,30 @@ func (x *Handle) handleGets(context *Context) { } }() - for _, s := range x.gets { + for _, s := range array { if s.path == context.R.URL.Path { handleError(s.fn(context), context) return } } - context.ShowMessage = false - handleError(&Error{404, "Endpoint not found"}, context) -} -func (x *Handle) handlePosts(context *Context) { - defer func() { - if r := recover(); r != nil { - context.Logger.Error("Something went very wrong", "Error", r) - handleError(&Error{500, "500"}, context) - } - }() - - for _, s := range x.posts { - if s.path == context.R.URL.Path { - handleError(s.fn(context), context) - return - } - } - context.ShowMessage = false - handleError(&Error{404, "Endpoint not found"}, context) -} - -func (x *Handle) handleDeletes(context *Context) { - defer func() { - if r := recover(); r != nil { - context.Logger.Error("Something went very wrong", "Error", r) - handleError(&Error{500, "500"}, context) - } - }() - - for _, s := range x.deletes { - if s.path == context.R.URL.Path { - handleError(s.fn(context), context) - return - } - } context.ShowMessage = false handleError(&Error{404, "Endpoint not found"}, context) } func (c *Context) CheckAuthLevel(authLevel int) bool { if authLevel > 0 { - if c.requireAuth() { - c.Logoff() + if c.User == nil { + contextlessLogoff(c.Writer) return false } if c.User.UserType < authLevel { - c.NotAuth() + c.Writer.WriteHeader(http.StatusUnauthorized) + e := c.SendJSON("Not Authorized") + if e != nil { + c.Writer.WriteHeader(http.StatusInternalServerError) + c.Writer.Write([]byte("You can not access this resource!")) + } return false } } @@ -221,6 +191,7 @@ type Context struct { Handle *Handle } +// This is required for this to integrate simealy with my orl func (c Context) GetDb() *sql.DB { return c.Db } @@ -237,7 +208,6 @@ func (c Context) Prepare(str string) (*sql.Stmt, error) { if c.Tx == nil { return c.Db.Prepare(str) } - return c.Tx.Prepare(str) } @@ -395,6 +365,7 @@ func (c Context) ErrorCode(err error, code int, data any) *Error { return &Error{code, data} } +// Deprecated: Use the E500M instead func (c Context) Error500(err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, nil) } @@ -403,13 +374,6 @@ func (c Context) E500M(msg string, err error) *Error { return c.ErrorCode(err, http.StatusInternalServerError, msg) } -func (c *Context) requireAuth() bool { - if c.User == nil { - return true - } - return false -} - var LogoffError = errors.New("Invalid token!") func (x Handle) createContext(handler *Handle, r *http.Request, w http.ResponseWriter) (*Context, error) { @@ -453,102 +417,6 @@ func contextlessLogoff(w http.ResponseWriter) { w.Write([]byte("\"Not Authorized\"")) } -func (c *Context) Logoff() { contextlessLogoff(c.Writer) } - -func (c *Context) NotAuth() { - c.Writer.WriteHeader(http.StatusUnauthorized) - e := c.SendJSON("Not Authorized") - if e != nil { - c.Writer.WriteHeader(http.StatusInternalServerError) - c.Writer.Write([]byte("You can not access this resource!")) - } -} - -func (x Handle) StaticFiles(pathTest string, fileType string, contentType string) { - http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { - path := r.URL.Path[len(pathTest):] - - if !strings.HasSuffix(path, fileType) { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte("File not found")) - return - } - - t, err := template.ParseFiles("./views" + pathTest + path) - - if err != nil { - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Failed to load template")) - return - } - - w.Header().Set("Content-Type", contentType+"; charset=utf-8") - t.Execute(w, nil) - }) -} - -func (x Handle) ReadFiles(pathTest string, baseFilePath string, fileType string, contentType string) { - http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { - user_path := r.URL.Path[len(pathTest):] - - // fmt.Printf("Requested path: %s\n", user_path) - - if !strings.HasSuffix(user_path, fileType) { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte("File not found")) - return - } - - bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) - if err != nil { - fmt.Println(err) - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Failed to load file")) - return - } - - w.Header().Set("Content-Type", contentType) - w.Write(bytes) - }) -} - -// TODO remove this -func (x Handle) ReadTypesFiles(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { - http.HandleFunc(pathTest, func(w http.ResponseWriter, r *http.Request) { - user_path := r.URL.Path[len(pathTest):] - - // fmt.Printf("Requested path: %s\n", user_path) - - found := false - index := -1 - - for i, fileType := range fileTypes { - if strings.HasSuffix(user_path, fileType) { - found = true - index = i - break - } - } - - if !found { - w.WriteHeader(http.StatusNotFound) - w.Write([]byte("File not found")) - return - } - - bytes, err := os.ReadFile(path.Join(baseFilePath, pathTest, user_path)) - if err != nil { - fmt.Println(err) - w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("Failed to load file")) - return - } - - w.Header().Set("Content-Type", contentTypes[index]) - w.Write(bytes) - }) -} - func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileTypes []string, contentTypes []string) { http.HandleFunc("/api"+pathTest, func(w http.ResponseWriter, r *http.Request) { r.URL.Path = strings.Replace(r.URL.Path, "/api", "", 1) @@ -586,7 +454,6 @@ func (x Handle) ReadTypesFilesApi(pathTest string, baseFilePath string, fileType } func NewHandler(db *sql.DB, config Config) *Handle { - var gets []HandleFunc var posts []HandleFunc var deletes []HandleFunc @@ -624,11 +491,11 @@ func NewHandler(db *sql.DB, config Config) *Handle { // context.Logger.Info("Parsing", "path", r.URL.Path) if r.Method == "GET" { - x.handleGets(context) + handleLoop(x.gets, context) } else if r.Method == "POST" { - x.handlePosts(context) + handleLoop(x.posts, context) } else if r.Method == "DELETE" { - x.handleDeletes(context) + handleLoop(x.deletes, context) } else if r.Method == "OPTIONS" { // do nothing } else { diff --git a/main.go b/main.go index b2a908d..7c5490d 100644 --- a/main.go +++ b/main.go @@ -51,10 +51,6 @@ func main() { } // TODO Handle this in other way - handle.StaticFiles("/styles/", ".css", "text/css") - handle.StaticFiles("/js/", ".js", "text/javascript") - handle.ReadFiles("/imgs/", "views", ".png", "image/png;") - handle.ReadTypesFiles("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"}) handle.ReadTypesFilesApi("/savedData/", ".", []string{".png", ".jpeg"}, []string{"image/png", "image/jpeg"}) usersEndpints(db, handle) diff --git a/users.go b/users.go index 4861e4f..4feef5a 100644 --- a/users.go +++ b/users.go @@ -81,18 +81,12 @@ func generateToken(db *sql.DB, email string, password string, name string) (stri } func usersEndpints(db *sql.DB, handle *Handle) { - handle.Post("/login", func(c *Context) *Error { - type UserLogin struct { - Email string `json:"email"` - Password string `json:"password"` - } - - var dat UserLogin - - if err := c.ToJSON(&dat); err != nil { - return err - } + type UserLogin struct { + Email string `json:"email"` + Password string `json:"password"` + } + PostAuthJson(handle, "/login", dbtypes.User_Not_Auth, func(c *Context, dat *UserLogin) *Error { // TODO Give this to the generateToken function token, login := generateToken(db, dat.Email, dat.Password, "Logged in user") if !login { @@ -101,7 +95,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { user, err := dbtypes.UserFromToken(c.Db, token) if err != nil { - return c.Error500(err) + return c.E500M("Failed to get user from token", err) } type UserReturn struct { @@ -123,43 +117,29 @@ func usersEndpints(db *sql.DB, handle *Handle) { return c.SendJSON(userReturn) }) - handle.Post("/register", func(c *Context) *Error { - type UserLogin struct { - Username string `json:"username"` - Email string `json:"email"` - Password string `json:"password"` + type UserRegister struct { + Username string `json:"username" validate:"required"` + Email string `json:"email" validate:"required"` + Password string `json:"password" validate:"required"` + } + PostAuthJson(handle, "/register", dbtypes.User_Not_Auth, func(c *Context, dat *UserRegister) *Error { + + var prevUser struct { + Username string + Email string } - - var dat UserLogin - - if err := c.ToJSON(&dat); err != nil { - return err - } - - if len(dat.Username) == 0 || len(dat.Password) == 0 || len(dat.Email) == 0 { - return c.SendJSONStatus(http.StatusBadRequest, "Please provide a valid json") - } - - rows, err := db.Query("select username, email from users where username=$1 or email=$2;", dat.Username, dat.Email) - if err != nil { - return c.Error500(err) - } - defer rows.Close() - - if rows.Next() { - var db_username string - var db_email string - err = rows.Scan(&db_username, &db_email) - if err != nil { - return c.Error500(err) - } - if db_email == dat.Email { + err := GetDBOnce(c, &prevUser, "users where username=$1 or email=$2;", dat.Username, dat.Email) + if err == NotFoundError { + // Do nothing the user does not exist and it's ok to create a new one + } else if err != nil { + return c.E500M("Falied to get user data", err) + } else { + if prevUser.Email == dat.Email { return c.SendJSONStatus(http.StatusBadRequest, "Email already in use!") } - if db_username == dat.Username { + if prevUser.Username == dat.Username { return c.SendJSONStatus(http.StatusBadRequest, "Username already in use!") } - panic("Unrechable") } if len([]byte(dat.Password)) > 68 { @@ -169,12 +149,12 @@ func usersEndpints(db *sql.DB, handle *Handle) { salt := generateSalt() hash_password, err := hashPassword(dat.Password, salt) if err != nil { - return c.Error500(err) + return c.E500M("Falied to store password", err) } _, err = db.Exec("insert into users (username, email, salt, password) values ($1, $2, $3, $4);", dat.Username, dat.Email, salt, hash_password) if err != nil { - return c.Error500(err) + return c.E500M("Falied to create user", err) } // TODO Give this to the generateToken function @@ -185,7 +165,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { user, err := dbtypes.UserFromToken(c.Db, token) if err != nil { - return c.Error500(err) + return c.E500M("Falied to create user", err) } type UserReturn struct { @@ -208,14 +188,10 @@ func usersEndpints(db *sql.DB, handle *Handle) { }) // TODO allow admin users to update this data - handle.Get("/user/info", func(c *Context) *Error { - if !c.CheckAuthLevel(1) { - return nil - } - + handle.GetAuth("/user/info", int(dbtypes.User_Normal), func(c *Context) *Error { user, err := dbtypes.UserFromToken(c.Db, *c.Token) if err != nil { - return c.Error500(err) + return c.E500M("Falied to get user data", err) } type UserReturn struct { @@ -236,22 +212,11 @@ func usersEndpints(db *sql.DB, handle *Handle) { }) // Handles updating users - handle.Post("/user/info", func(c *Context) *Error { - if !c.CheckAuthLevel(int(dbtypes.User_Normal)) { - return nil - } - - type UserData struct { - Id string `json:"id"` - Email string `json:"email"` - } - - var dat UserData - - if err := c.ToJSON(&dat); err != nil { - return err - } - + type UpdateUserData struct { + Id string `json:"id"` + Email string `json:"email"` + } + PostAuthJson(handle, "/user/info", dbtypes.User_Normal, func(c *Context, dat *UpdateUserData) *Error { if dat.Id != c.User.Id && c.User.UserType != int(dbtypes.User_Admin) { return c.SendJSONStatus(403, "You need to be an admin to update another users account") } @@ -265,17 +230,14 @@ func usersEndpints(db *sql.DB, handle *Handle) { if err == NotFoundError { return c.JsonBadRequest("User does not exist") } else if err != nil { - return c.Error500(err) + return c.E500M("Falied to get data for user", err) } } - var data struct { - Id string - } - + var data JustId err := utils.GetDBOnce(c, &data, "users where email=$1", dat.Email) if err != nil && err != NotFoundError { - return c.Error500(err) + return c.E500M("Falied to get data for user", err) } if err != NotFoundError { @@ -288,7 +250,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { _, err = c.Db.Exec("update users set email=$2 where id=$1", dat.Id, dat.Email) if err != nil { - return c.Error500(err) + return c.E500M("Failed to update data", err) } var user struct { @@ -300,7 +262,7 @@ func usersEndpints(db *sql.DB, handle *Handle) { err = utils.GetDBOnce(c, &user, "users where id=$1", dat.Id) if err != nil { - return c.Error500(err) + return c.E500M("Failed to get user data", err) } toReturnUser := dbtypes.User{ @@ -313,25 +275,12 @@ func usersEndpints(db *sql.DB, handle *Handle) { return c.SendJSON(toReturnUser) }) - handle.Post("/user/info/password", func(c *Context) *Error { - if !c.CheckAuthLevel(1) { - return nil - } - - var dat struct { - Old_Password string `json:"old_password"` - Password string `json:"password"` - Password2 string `json:"password2"` - } - - if err := c.ToJSON(&dat); err != nil { - return err - } - - if dat.Password == "" { - return c.JsonBadRequest("Password can not be empty") - } - + type PasswordUpdate struct { + Old_Password string `json:"old_password" validate:"required"` + Password string `json:"password" validate:"required"` + Password2 string `json:"password2" validate:"required"` + } + PostAuthJson(handle, "/user/info/password", dbtypes.User_Normal, func(c *Context, dat *PasswordUpdate) *Error { if dat.Password != dat.Password2 { return c.JsonBadRequest("New passwords did not match") } @@ -345,12 +294,12 @@ func usersEndpints(db *sql.DB, handle *Handle) { salt := generateSalt() hash_password, err := hashPassword(dat.Password, salt) if err != nil { - return c.Error500(err) + return c.E500M("Failed to parse the password", err) } _, err = db.Exec("update users set salt=$1, password=$2 where id=$3", salt, hash_password, c.User.Id) if err != nil { - return c.Error500(err) + return c.E500M("Failed to update password", err) } return c.SendJSON(c.User.Id) @@ -405,6 +354,4 @@ func usersEndpints(db *sql.DB, handle *Handle) { return c.SendJSON("Ok") }) - - // TODO create function to remove token } diff --git a/webpage/src/routes/register/+page.svelte b/webpage/src/routes/register/+page.svelte index c720189..726bf6f 100644 --- a/webpage/src/routes/register/+page.svelte +++ b/webpage/src/routes/register/+page.svelte @@ -1,104 +1,96 @@ - - Register - + Register
-
-

- Register -

-
-
- - - -
-
- - - -
-
- - - -
- {#if errorMessage} -
- {errorMessage} -
- {/if} - -
- - Login - -
-
+ + {#if errorMessage} +
+ {errorMessage} +
+ {/if} + +
+ Login + +