diff --git a/db.go b/db.go new file mode 100644 index 0000000..94fb070 --- /dev/null +++ b/db.go @@ -0,0 +1,6 @@ +package main + +import ( +) + + diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..6ad0d96 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,10 @@ +version: '3.1' + +services: + db: + image: docker.andr3h3nriqu3s.com/services/postgres + restart: always + environment: + POSTGRES_PASSWORD: verysafepassword + ports: + - "5432:5432" diff --git a/go.mod b/go.mod index dcfee58..e5e3279 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module andr3h3nriqu3s.com/m go 1.20 + +require ( + github.com/lib/pq v1.10.9 // indirect + golang.org/x/crypto v0.13.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c205580 --- /dev/null +++ b/go.sum @@ -0,0 +1,4 @@ +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= diff --git a/handler.go b/handler.go index 089b1c4..2a95461 100644 --- a/handler.go +++ b/handler.go @@ -1,11 +1,15 @@ package main import ( + "database/sql" + "errors" "fmt" "html/template" + "io" "log" "net/http" "strings" + "time" ) func baseLoadTemplate(base string, path string) (*template.Template, any) { @@ -74,6 +78,8 @@ func LoadHtml(writer http.ResponseWriter, path string, data interface{}) { } } +type AnyMap = map[string]interface{} + type Error struct { code int msg *string @@ -90,10 +96,10 @@ const ( func LoadBasedOnAnswer(ans AnswerType, w http.ResponseWriter, path string, data map[string]interface{}) { if ans == NORMAL { - LoadView(w, path, nil) + LoadView(w, path, data) return } else if ans == HTML { - LoadHtml(w, path, nil) + LoadHtml(w, path, data) return } else if ans == HTMLFULL { if data == nil { @@ -116,26 +122,43 @@ func LoadBasedOnAnswer(ans AnswerType, w http.ResponseWriter, path string, data type HandleFunc struct { path string mode AnswerType - fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error + fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error } type Handler interface { New() Startup() - Get(fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) - Post(fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) + Get(fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) + Post(fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) } type Handle struct { + db *sql.DB gets []HandleFunc posts []HandleFunc } -func handleError(err *Error, answerType AnswerType, w http.ResponseWriter) { +func decodeBody(r *http.Request) (string, *Error) { + body, err := io.ReadAll(r.Body) + if err == nil { + return "", &Error{code: http.StatusBadRequest} + } + + return string(body[:]), nil +} + +func handleError(err *Error, w http.ResponseWriter, context *Context) { + + data := context.toMap() + if err != nil { w.WriteHeader(err.code) - if err.code == 404 { - LoadBasedOnAnswer(answerType, w, "404.html", nil) + if err.code == http.StatusNotFound { + LoadBasedOnAnswer(context.Mode, w, "404.html", data) + return + } + if err.code == http.StatusBadRequest { + LoadBasedOnAnswer(context.Mode, w, "400.html", data) return } if err.msg != nil { @@ -144,7 +167,7 @@ func handleError(err *Error, answerType AnswerType, w http.ResponseWriter) { } } -func (x *Handle) Get(path string, fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) { +func (x *Handle) Get(path string, fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) { nhandler := HandleFunc{ fn: fn, @@ -155,7 +178,7 @@ func (x *Handle) Get(path string, fn func(mode AnswerType, w http.ResponseWriter x.gets = append(x.gets, nhandler) } -func (x *Handle) GetHTML(path string, fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) { +func (x *Handle) GetHTML(path string, fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) { nhandler := HandleFunc{ fn: fn, @@ -166,7 +189,7 @@ func (x *Handle) GetHTML(path string, fn func(mode AnswerType, w http.ResponseWr x.gets = append(x.gets, nhandler) } -func (x *Handle) GetJSON(path string, fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) { +func (x *Handle) GetJSON(path string, fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) { nhandler := HandleFunc{ fn: fn, @@ -177,7 +200,7 @@ func (x *Handle) GetJSON(path string, fn func(mode AnswerType, w http.ResponseWr x.gets = append(x.gets, nhandler) } -func (x *Handle) Post(path string, fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) { +func (x *Handle) Post(path string, fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) { nhandler := HandleFunc{ fn: fn, @@ -188,7 +211,7 @@ func (x *Handle) Post(path string, fn func(mode AnswerType, w http.ResponseWrite x.posts = append(x.posts, nhandler) } -func (x *Handle) PostHTML(path string, fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) { +func (x *Handle) PostHTML(path string, fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) { nhandler := HandleFunc{ fn: fn, @@ -199,7 +222,7 @@ func (x *Handle) PostHTML(path string, fn func(mode AnswerType, w http.ResponseW x.posts = append(x.posts, nhandler) } -func (x *Handle) PostJSON(path string, fn func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error) { +func (x *Handle) PostJSON(path string, fn func(w http.ResponseWriter, r *http.Request, c *Context) *Error) { nhandler := HandleFunc{ fn: fn, @@ -210,43 +233,130 @@ func (x *Handle) PostJSON(path string, fn func(mode AnswerType, w http.ResponseW x.posts = append(x.posts, nhandler) } -func (x *Handle) handleGets(ans AnswerType, w http.ResponseWriter, r *http.Request) { +func (x *Handle) handleGets(w http.ResponseWriter, r *http.Request, context *Context) { for _, s := range x.gets { - fmt.Printf("target: %s, paths: %s\n", s.path, r.URL.Path) - if s.path == r.URL.Path && ans&s.mode != 0 { - s.fn(ans, w, r) + if s.path == r.URL.Path && context.Mode&s.mode != 0 { + handleError(s.fn(w, r, context), w, context) return } } - w.WriteHeader(http.StatusNotFound) - LoadBasedOnAnswer(ans, w, "404.html", nil) + if context.Mode != HTMLFULL { + w.WriteHeader(http.StatusNotFound) + } + LoadBasedOnAnswer(context.Mode, w, "404.html", map[string]interface{}{ + "context": context, + }) } -func (x *Handle) handlePosts(ans AnswerType, w http.ResponseWriter, r *http.Request) { +func (x *Handle) handlePosts(w http.ResponseWriter, r *http.Request, context *Context) { for _, s := range x.posts { - if s.path == r.URL.Path && ans&s.mode != 0 { - s.fn(ans, w, r) + if s.path == r.URL.Path && context.Mode&s.mode != 0 { + handleError(s.fn(w, r, context), w, context) return } } - w.WriteHeader(http.StatusNotFound) - LoadBasedOnAnswer(ans, w, "404.html", nil) + if context.Mode != HTMLFULL { + w.WriteHeader(http.StatusNotFound) + } + LoadBasedOnAnswer(context.Mode, w, "404.html", map[string]interface{}{ + "context": context, + }) } -func AnswerTemplate(path string, data interface{}) func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error { - return func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error { - LoadBasedOnAnswer(mode, w, path, nil) +func AnswerTemplate(path string, data AnyMap) func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + return func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if data == nil { + LoadBasedOnAnswer(c.Mode, w, path, c.toMap()) + } else { + LoadBasedOnAnswer(c.Mode, w, path, c.addMap(data)) + } return nil } } -func NewHandler() *Handle { +type Context struct { + Token *string + User *User + Mode AnswerType +} - x := &Handle{} +func (c Context) addMap(m AnyMap) AnyMap { + m["Context"] = c; + return m; +} + +func (c *Context) toMap() map[string]interface{} { + return map[string]interface{}{ + "Context": c, + } +} + +func (c *Context) requireAuth(w http.ResponseWriter, r *http.Request) bool { + if c.User == nil { + return true; + } + return false; +} + +var LogoffError = errors.New("Invalid token!") + +func (x Handle) createContext(mode AnswerType, r *http.Request) (*Context, error) { + + var token *string + + for _, r := range r.Cookies() { + if r.Name == "auth" { + token = &r.Value + } + } + + if token == nil { + return &Context{ + Mode: mode, + }, nil + } + + user, err := userFromToken(x.db, *token) + if err != nil { + return nil, errors.Join(err, LogoffError) + } + + return &Context{token, user, mode}, nil +} + +func logoff(mode AnswerType, w http.ResponseWriter, r *http.Request) { + // Delete cookie + cookie := &http.Cookie{ + Name: "auth", + Value: "", + Expires: time.Unix(0, 0), + } + http.SetCookie(w, cookie) + + // Setup response + w.Header().Set("Location", "/login") + if mode == JSON { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte("\"Bye Bye\"")); + return + } + if mode & (HTMLFULL | HTML) != 0 { + w.WriteHeader(http.StatusUnauthorized); + w.Write([]byte("Bye Bye")); + } else { + w.WriteHeader(http.StatusSeeOther); + } +} + +func NewHandler(db *sql.DB) *Handle { + + var gets []HandleFunc + var posts []HandleFunc + x := &Handle{ db, gets, posts, } http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // Decide answertype ans := NORMAL - if r.Header.Get("Request-Type") == "htmlfull" { ans = HTMLFULL } @@ -255,12 +365,19 @@ func NewHandler() *Handle { } //TODO JSON + //Login state + context, err := x.createContext(ans, r) + if err != nil { + logoff(ans, w, r) + return + } + if r.Method == "GET" { - x.handleGets(ans, w, r) + x.handleGets(w, r, context) return } if r.Method == "POST" { - x.handlePosts(ans, w, r) + x.handlePosts(w, r, context) return } panic("TODO handle: " + r.Method) diff --git a/main.go b/main.go index 62438dc..92808d9 100644 --- a/main.go +++ b/main.go @@ -2,25 +2,36 @@ package main import ( "fmt" - "net/http" + "database/sql" + + _ "github.com/lib/pq" +) + +const ( + host = "localhost" + port = 5432 + user = "postgres" + password = "verysafepassword" + dbname = "aistuff" ) func main() { + psqlInfo := fmt.Sprintf("host=%s port=%d user=%s "+ + "password=%s dbname=%s sslmode=disable", + host, port, user, password, dbname) + + db, err := sql.Open("postgres", psqlInfo) + if err != nil { + panic(err) + } + defer db.Close() fmt.Println("Starting server on :8000!") - handle := NewHandler() + handle := NewHandler(db) handle.GetHTML("/", AnswerTemplate("index.html", nil)) - handle.GetHTML("/login", AnswerTemplate("login.html", nil)) - handle.Post("/login", func(mode AnswerType, w http.ResponseWriter, r *http.Request) *Error { - if mode == JSON { - return &Error{code: 404} - } - w.Header().Set("Location", "/") - w.WriteHeader(http.StatusSeeOther) - return nil - }) + usersEndpints(db, handle) handle.Startup() } diff --git a/sql/base.sql b/sql/base.sql new file mode 100644 index 0000000..bc269ff --- /dev/null +++ b/sql/base.sql @@ -0,0 +1 @@ +CREATE DATABASE aistuff; diff --git a/sql/user.sql b/sql/user.sql new file mode 100644 index 0000000..8b37804 --- /dev/null +++ b/sql/user.sql @@ -0,0 +1,22 @@ +-- drop table if exists tokens; +-- drop table if exists users; +create table if not exists users ( + id uuid primary key default gen_random_uuid(), + user_type integer default 0, + username varchar (120) not null, + email varchar (120) not null, + salt char (8) not null, + password char (60) not null, + created_on timestamp default current_timestamp, + updated_at timestamp default current_timestamp, + lastlogin_at timestamp default current_timestamp +); + +--drop table if exists tokens; +create table if not exists tokens ( + token varchar (120) primary key, + user_id uuid references users (id) on delete cascade, + time_to_live integer default 86400, + emit_day timestamp default current_timestamp +); + diff --git a/users.go b/users.go new file mode 100644 index 0000000..cfa0213 --- /dev/null +++ b/users.go @@ -0,0 +1,246 @@ +package main + +import ( + "crypto/rand" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "time" + + "golang.org/x/crypto/bcrypt" +) + +type User struct { + id string + username string + email string + user_type int +} + +var ErrUserNotFound = errors.New("User Not found") + +func userFromToken(db *sql.DB, token string) (*User, error) { + row, err := db.Query("select users.id, users.username, users.email, users.user_type from users inner join tokens on tokens.user_id = users.id where tokens.token = $1;", token) + + if err != nil { + return nil, err + } + + var id string + var username string + var email string + var user_type int + + if !row.Next() { + return nil, ErrUserNotFound + } + + err = row.Scan(&id, &username, &email, &user_type) + if err != nil { + return nil, err + } + + return &User{id, username, email, user_type}, nil +} + +func generateSalt() string { + salt := make([]byte, 4) + _, err := io.ReadFull(rand.Reader, salt) + if err != nil { + panic("TODO handle this better") + } + return hex.EncodeToString(salt) +} + +func hashPassword(password string, salt string) (string, error) { + bytes_salt, err := hex.DecodeString(salt) + if err != nil { + return "", err + } + bytes, err := bcrypt.GenerateFromPassword(append([]byte(password), bytes_salt...), 14) + return string(bytes), err +} + +func genToken() string { + token := make([]byte, 60) + _, err := io.ReadFull(rand.Reader, token) + if err != nil { + panic("TODO handle this better") + } + return hex.EncodeToString(token) +} + +func generateToken(db *sql.DB, email string, password string) (string, bool) { + + row, err := db.Query("select id, salt, password from users where email = $1;", email) + if err != nil || !row.Next() { + return "", false + } + + var db_id string + var db_salt string + var db_password string + err = row.Scan(&db_id, &db_salt, &db_password) + if err != nil { + return "", false + } + + bytes_salt, err := hex.DecodeString(db_salt) + if err != nil { + panic("TODO handle better! Somethign is wrong with salt being stored in the database") + } + + if bcrypt.CompareHashAndPassword([]byte(db_password), append([]byte(password), bytes_salt...)) != nil { + return "", false + } + + token := genToken() + + _, err = db.Exec("insert into tokens (user_id, token) values ($1, $2);", db_id, token) + if err != nil { + return "", false + } + + return token, true +} + +func usersEndpints(db *sql.DB, handle *Handle) { + handle.GetHTML("/login", AnswerTemplate("login.html", nil)) + handle.Post("/login", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if c.Mode == JSON { + fmt.Println("Handle JSON") + return &Error{code: 404} + } + + r.ParseForm() + f := r.Form + + if checkEmpty(f, "email") || checkEmpty(f, "password") { + LoadBasedOnAnswer(c.Mode, w, "login.html", c.addMap(AnyMap{ + "Submited": true, + })) + return nil + } + + email := f.Get("email") + password := f.Get("password") + + // TODO Give this to the generateToken function + expiration := time.Now().Add(24 * time.Hour) + token, login := generateToken(db, email, password) + if !login { + LoadBasedOnAnswer(c.Mode, w, "login.html", c.addMap(AnyMap{ + "Submited": true, + "NoUserOrPassword": true, + "Email": email, + })) + return nil + } + + cookie := &http.Cookie{Name: "auth", Value: token, HttpOnly: false, Expires: expiration} + http.SetCookie(w, cookie) + + w.Header().Set("Location", "/") + w.WriteHeader(http.StatusSeeOther) + return nil + }) + + handle.GetHTML("/register", AnswerTemplate("register.html", nil)) + handle.Post("/register", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if c.Mode == JSON { + return &Error{code: http.StatusNotFound} + } + + r.ParseForm() + f := r.Form + + if checkEmpty(f, "email") || checkEmpty(f, "password") || checkEmpty(f, "username") { + LoadBasedOnAnswer(c.Mode, w, "register.html", AnyMap{ + "Submited": true, + }) + return nil + } + + email := f.Get("email") + username := f.Get("username") + password := f.Get("password") + + rows, err := db.Query("select username, email from users where username=$1 or email=$2;", username, email) + if err != nil { + panic("TODO handle this") + } + defer rows.Close() + + if rows.Next() { + var db_username string + var db_email string + err = rows.Scan(&db_username, &db_email) + if err != nil { + panic("TODO handle this better") + } + LoadBasedOnAnswer(c.Mode, w, "register.html", AnyMap{ + "Submited": true, + "Email": email, + "Username": username, + "EmailError": db_email == email, + "UserError": db_username == username, + }) + return nil + } + + if len([]byte(password)) > 68 { + LoadBasedOnAnswer(c.Mode, w, "register.html", AnyMap{ + "Submited": true, + "Email": email, + "Username": username, + "PasswordToLong": true, + }) + return nil + } + + salt := generateSalt() + hash_password, err := hashPassword(password, salt) + if err != nil { + return &Error{ + code: http.StatusInternalServerError, + } + } + + _, err = db.Exec("insert into users (username, email, salt, password) values ($1, $2, $3, $4);", username, email, salt, hash_password) + + if err != nil { + return &Error{ + code: http.StatusInternalServerError, + } + } + + // TODO Give this to the generateToken function + expiration := time.Now().Add(24 * time.Hour) + token, login := generateToken(db, email, password) + + if !login { + msg := "Login failed" + return &Error{ + code: http.StatusInternalServerError, + msg: &msg, + } + } + + cookie := &http.Cookie{Name: "auth", Value: token, HttpOnly: false, Expires: expiration} + http.SetCookie(w, cookie) + w.Header().Set("Location", "/") + w.WriteHeader(http.StatusSeeOther) + return nil + }) + + handle.Get("/logout", func(w http.ResponseWriter, r *http.Request, c *Context) *Error { + if c.Mode == JSON { + panic("TODO handle json") + } + logoff(c.Mode, w, r) + return nil + }) +} diff --git a/utils.go b/utils.go new file mode 100644 index 0000000..6ebba69 --- /dev/null +++ b/utils.go @@ -0,0 +1,7 @@ +package main + +import "net/url" + +func checkEmpty(f url.Values, path string) bool { + return !f.Has(path) || f.Get(path) == "" +} diff --git a/views/js/main.js b/views/js/main.js index 3f72e55..5524eef 100644 --- a/views/js/main.js +++ b/views/js/main.js @@ -5,6 +5,10 @@ function load() { }); } } - window.onload = load; htmx.on('htmx:afterSwap', load); +htmx.on('htmx:beforeSwap', (env) => { + if (env.detail.xhr.status === 401) { + window.location = "/login" + } +}); diff --git a/views/login.html b/views/login.html index 84a88c7..2d7e615 100644 --- a/views/login.html +++ b/views/login.html @@ -6,18 +6,27 @@