diff --git a/backend/controllers/add_task.go b/backend/controllers/add_task.go index e44aa727..376fbd6b 100644 --- a/backend/controllers/add_task.go +++ b/backend/controllers/add_task.go @@ -25,13 +25,18 @@ var GlobalJobQueue *JobQueue // @Router /add-task [post] func AddTaskHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) return } defer r.Body.Close() - // fmt.Printf("Raw request body: %s\n", string(body)) var requestBody models.AddTaskRequestBody @@ -40,9 +45,14 @@ func AddTaskHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("error decoding request body: %v", err), http.StatusBadRequest) return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + description := requestBody.Description project := requestBody.Project priority := requestBody.Priority @@ -61,7 +71,6 @@ func AddTaskHandler(w http.ResponseWriter, r *http.Request) { return } - // Validate dependencies if err := utils.ValidateDependencies(depends, ""); err != nil { http.Error(w, fmt.Sprintf("Invalid dependencies: %v", err), http.StatusBadRequest) return diff --git a/backend/controllers/auth_helpers.go b/backend/controllers/auth_helpers.go new file mode 100644 index 00000000..6074bd86 --- /dev/null +++ b/backend/controllers/auth_helpers.go @@ -0,0 +1,42 @@ +package controllers + +import ( + "errors" + "net/http" +) + +func ValidateUserCredentials(r *http.Request, requestEmail, requestUUID string) error { + userInfo, ok := r.Context().Value("user").(map[string]interface{}) + if !ok { + return errors.New("user context not found") + } + + sessionEmail, emailOk := userInfo["email"].(string) + sessionUUID, uuidOk := userInfo["uuid"].(string) + + if !emailOk || !uuidOk { + return errors.New("invalid user session data") + } + + if sessionEmail != requestEmail || sessionUUID != requestUUID { + return errors.New("credentials do not match authenticated user") + } + return nil +} + +func GetSessionCredentials(r *http.Request) (email, uuid, encryptionSecret string, err error) { + userInfo, ok := r.Context().Value("user").(map[string]interface{}) + if !ok { + return "", "", "", errors.New("user context not found") + } + + email, emailOk := userInfo["email"].(string) + uuid, uuidOk := userInfo["uuid"].(string) + encryptionSecret, secretOk := userInfo["encryption_secret"].(string) + + if !emailOk || !uuidOk || !secretOk { + return "", "", "", errors.New("incomplete user session data") + } + + return email, uuid, encryptionSecret, nil +} diff --git a/backend/controllers/auth_helpers_test.go b/backend/controllers/auth_helpers_test.go new file mode 100644 index 00000000..4e62b03b --- /dev/null +++ b/backend/controllers/auth_helpers_test.go @@ -0,0 +1,167 @@ +package controllers + +import ( + "context" + "net/http/httptest" + "testing" +) + +func TestValidateUserCredentials_MatchingCredentials(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid-123", + "encryption_secret": "test-secret", + }) + req = req.WithContext(ctx) + + err := ValidateUserCredentials(req, "test@example.com", "test-uuid-123") + if err != nil { + t.Errorf("Expected no error, got %v", err) + } +} + +func TestValidateUserCredentials_MismatchedEmail(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid-123", + "encryption_secret": "test-secret", + }) + req = req.WithContext(ctx) + + err := ValidateUserCredentials(req, "wrong@example.com", "test-uuid-123") + if err == nil { + t.Error("Expected error for mismatched email") + } + if err.Error() != "credentials do not match authenticated user" { + t.Errorf("Expected specific error message, got %v", err) + } +} + +func TestValidateUserCredentials_MismatchedUUID(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid-123", + "encryption_secret": "test-secret", + }) + req = req.WithContext(ctx) + + err := ValidateUserCredentials(req, "test@example.com", "wrong-uuid") + if err == nil { + t.Error("Expected error for mismatched UUID") + } + if err.Error() != "credentials do not match authenticated user" { + t.Errorf("Expected specific error message, got %v", err) + } +} + +func TestValidateUserCredentials_NoContext(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + + err := ValidateUserCredentials(req, "test@example.com", "test-uuid-123") + if err == nil { + t.Error("Expected error for missing context") + } + if err.Error() != "user context not found" { + t.Errorf("Expected 'user context not found', got %v", err) + } +} + +func TestValidateUserCredentials_InvalidSessionData(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": 12345, + "uuid": "test-uuid-123", + }) + req = req.WithContext(ctx) + + err := ValidateUserCredentials(req, "test@example.com", "test-uuid-123") + if err == nil { + t.Error("Expected error for invalid session data") + } + if err.Error() != "invalid user session data" { + t.Errorf("Expected 'invalid user session data', got %v", err) + } +} + +func TestGetSessionCredentials_ValidSession(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid-123", + "encryption_secret": "test-secret-456", + }) + req = req.WithContext(ctx) + + email, uuid, secret, err := GetSessionCredentials(req) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if email != "test@example.com" { + t.Errorf("Expected email test@example.com, got %s", email) + } + if uuid != "test-uuid-123" { + t.Errorf("Expected uuid test-uuid-123, got %s", uuid) + } + if secret != "test-secret-456" { + t.Errorf("Expected secret test-secret-456, got %s", secret) + } +} + +func TestGetSessionCredentials_NoContext(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + + email, uuid, secret, err := GetSessionCredentials(req) + if err == nil { + t.Error("Expected error for missing context") + } + if err.Error() != "user context not found" { + t.Errorf("Expected 'user context not found', got %v", err) + } + if email != "" || uuid != "" || secret != "" { + t.Error("Expected empty strings for credentials") + } +} + +func TestGetSessionCredentials_IncompleteData(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid-123", + }) + req = req.WithContext(ctx) + + email, uuid, secret, err := GetSessionCredentials(req) + if err == nil { + t.Error("Expected error for incomplete session data") + } + if err.Error() != "incomplete user session data" { + t.Errorf("Expected 'incomplete user session data', got %v", err) + } + if email != "" || uuid != "" || secret != "" { + t.Error("Expected empty strings for credentials") + } +} + +func TestGetSessionCredentials_InvalidDataTypes(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + ctx := context.WithValue(req.Context(), "user", map[string]interface{}{ + "email": 12345, + "uuid": true, + "encryption_secret": []string{"invalid"}, + }) + req = req.WithContext(ctx) + + email, uuid, secret, err := GetSessionCredentials(req) + if err == nil { + t.Error("Expected error for invalid data types") + } + if err.Error() != "incomplete user session data" { + t.Errorf("Expected 'incomplete user session data', got %v", err) + } + if email != "" || uuid != "" || secret != "" { + t.Error("Expected empty strings for credentials") + } +} diff --git a/backend/controllers/complete_task.go b/backend/controllers/complete_task.go index 63ddffb1..30c4c35a 100644 --- a/backend/controllers/complete_task.go +++ b/backend/controllers/complete_task.go @@ -22,6 +22,12 @@ import ( // @Router /complete-task [post] func CompleteTaskHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) @@ -29,8 +35,6 @@ func CompleteTaskHandler(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - // fmt.Printf("Raw request body: %s\n", string(body)) - var requestBody models.CompleteTaskRequestBody err = json.Unmarshal(body, &requestBody) @@ -39,9 +43,13 @@ func CompleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + taskuuid := requestBody.TaskUUID if taskuuid == "" { @@ -49,10 +57,6 @@ func CompleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } - // if err := tw.CompleteTaskInTaskwarrior(email, encryptionSecret, uuid, taskuuid); err != nil { - // http.Error(w, err.Error(), http.StatusInternalServerError) - // return - // } logStore := models.GetLogStore() job := Job{ Name: "Complete Task", diff --git a/backend/controllers/complete_tasks.go b/backend/controllers/complete_tasks.go index b971f50b..0af7a58e 100644 --- a/backend/controllers/complete_tasks.go +++ b/backend/controllers/complete_tasks.go @@ -26,6 +26,12 @@ func BulkCompleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) @@ -40,9 +46,13 @@ func BulkCompleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + taskUUIDs := requestBody.TaskUUIDs if len(taskUUIDs) == 0 { @@ -52,7 +62,6 @@ func BulkCompleteTaskHandler(w http.ResponseWriter, r *http.Request) { logStore := models.GetLogStore() - // Create a *single* job for all UUIDs job := Job{ Name: "Bulk Complete Tasks", Execute: func() error { diff --git a/backend/controllers/delete_task.go b/backend/controllers/delete_task.go index 5af79555..6b5ea36c 100644 --- a/backend/controllers/delete_task.go +++ b/backend/controllers/delete_task.go @@ -22,6 +22,12 @@ import ( // @Router /delete-task [post] func DeleteTaskHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) @@ -37,9 +43,13 @@ func DeleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + taskuuid := requestBody.TaskUUID if taskuuid == "" { @@ -47,10 +57,6 @@ func DeleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } - // if err := tw.DeleteTaskInTaskwarrior(email, encryptionSecret, uuid, taskuuid); err != nil { - // http.Error(w, err.Error(), http.StatusInternalServerError) - // return - // } logStore := models.GetLogStore() job := Job{ Name: "Delete Task", diff --git a/backend/controllers/delete_tasks.go b/backend/controllers/delete_tasks.go index f1f51cea..db64f835 100644 --- a/backend/controllers/delete_tasks.go +++ b/backend/controllers/delete_tasks.go @@ -26,6 +26,12 @@ func BulkDeleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) @@ -40,9 +46,13 @@ func BulkDeleteTaskHandler(w http.ResponseWriter, r *http.Request) { return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + taskUUIDs := requestBody.TaskUUIDs if len(taskUUIDs) == 0 { diff --git a/backend/controllers/edit_task.go b/backend/controllers/edit_task.go index f22cff37..c5b0f081 100644 --- a/backend/controllers/edit_task.go +++ b/backend/controllers/edit_task.go @@ -23,6 +23,12 @@ import ( // @Router /edit-task [post] func EditTaskHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) @@ -30,8 +36,6 @@ func EditTaskHandler(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - // fmt.Printf("Raw request body: %s\n", string(body)) - var requestBody models.EditTaskRequestBody err = json.Unmarshal(body, &requestBody) @@ -40,9 +44,13 @@ func EditTaskHandler(w http.ResponseWriter, r *http.Request) { return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + taskUUID := requestBody.TaskUUID description := requestBody.Description tags := requestBody.Tags @@ -61,7 +69,6 @@ func EditTaskHandler(w http.ResponseWriter, r *http.Request) { return } - // Validate dependencies if err := utils.ValidateDependencies(depends, uuid); err != nil { http.Error(w, fmt.Sprintf("Invalid dependencies: %v", err), http.StatusBadRequest) return diff --git a/backend/controllers/get_tasks.go b/backend/controllers/get_tasks.go index 09ef598c..cd20efb7 100644 --- a/backend/controllers/get_tasks.go +++ b/backend/controllers/get_tasks.go @@ -21,17 +21,25 @@ import ( // @Failure 500 {string} string "Failed to fetch tasks at backend" // @Router /tasks [get] func TasksHandler(w http.ResponseWriter, r *http.Request) { - email := r.URL.Query().Get("email") - encryptionSecret := r.URL.Query().Get("encryptionSecret") - UUID := r.URL.Query().Get("UUID") - origin := os.Getenv("CONTAINER_ORIGIN") - if email == "" || encryptionSecret == "" || UUID == "" { - http.Error(w, "Missing required parameters", http.StatusBadRequest) + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) return } + queryEmail := r.URL.Query().Get("email") + queryUUID := r.URL.Query().Get("UUID") + + if queryEmail != "" || queryUUID != "" { + if err := ValidateUserCredentials(r, queryEmail, queryUUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + if r.Method == http.MethodGet { - tasks, _ := tw.FetchTasksFromTaskwarrior(email, encryptionSecret, origin, UUID) + origin := os.Getenv("CONTAINER_ORIGIN") + tasks, _ := tw.FetchTasksFromTaskwarrior(email, encryptionSecret, origin, uuid) if tasks == nil { http.Error(w, "Failed to fetch tasks at backend", http.StatusInternalServerError) return diff --git a/backend/controllers/modify_task.go b/backend/controllers/modify_task.go index e3f9645a..2c7e6516 100644 --- a/backend/controllers/modify_task.go +++ b/backend/controllers/modify_task.go @@ -23,6 +23,12 @@ import ( // @Router /modify-task [post] func ModifyTaskHandler(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost { + email, uuid, encryptionSecret, err := GetSessionCredentials(r) + if err != nil { + http.Error(w, "Authentication required", http.StatusUnauthorized) + return + } + body, err := io.ReadAll(r.Body) if err != nil { http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest) @@ -30,8 +36,6 @@ func ModifyTaskHandler(w http.ResponseWriter, r *http.Request) { } defer r.Body.Close() - // fmt.Printf("Raw request body: %s\n", string(body)) - var requestBody models.ModifyTaskRequestBody err = json.Unmarshal(body, &requestBody) @@ -39,9 +43,14 @@ func ModifyTaskHandler(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("error decoding request body: %v", err), http.StatusBadRequest) return } - email := requestBody.Email - encryptionSecret := requestBody.EncryptionSecret - uuid := requestBody.UUID + + if requestBody.Email != "" || requestBody.UUID != "" { + if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil { + http.Error(w, "Invalid credentials", http.StatusForbidden) + return + } + } + taskUUID := requestBody.TaskUUID description := requestBody.Description project := requestBody.Project @@ -60,17 +69,11 @@ func ModifyTaskHandler(w http.ResponseWriter, r *http.Request) { return } - // Validate dependencies if err := utils.ValidateDependencies(depends, uuid); err != nil { http.Error(w, fmt.Sprintf("Invalid dependencies: %v", err), http.StatusBadRequest) return } - // if err := tw.ModifyTaskInTaskwarrior(uuid, description, project, priority, status, due, email, encryptionSecret, taskID); err != nil { - // http.Error(w, err.Error(), http.StatusInternalServerError) - // return - // } - logStore := models.GetLogStore() job := Job{ Name: "Modify Task", diff --git a/backend/main.go b/backend/main.go index 05fb244a..8d7832ca 100644 --- a/backend/main.go +++ b/backend/main.go @@ -88,19 +88,21 @@ func main() { limiter := middleware.NewRateLimiter(30*time.Second, 50) rateLimitedHandler := middleware.RateLimitMiddleware(limiter) + authMiddleware := middleware.RequireAuth(store) + mux.Handle("/auth/oauth", rateLimitedHandler(http.HandlerFunc(app.OAuthHandler))) mux.Handle("/auth/callback", rateLimitedHandler(http.HandlerFunc(app.OAuthCallbackHandler))) mux.Handle("/api/user", rateLimitedHandler(http.HandlerFunc(app.UserInfoHandler))) mux.Handle("/auth/logout", rateLimitedHandler(http.HandlerFunc(app.LogoutHandler))) - mux.Handle("/tasks", rateLimitedHandler(http.HandlerFunc(controllers.TasksHandler))) - mux.Handle("/add-task", rateLimitedHandler(http.HandlerFunc(controllers.AddTaskHandler))) - mux.Handle("/edit-task", rateLimitedHandler(http.HandlerFunc(controllers.EditTaskHandler))) - mux.Handle("/modify-task", rateLimitedHandler(http.HandlerFunc(controllers.ModifyTaskHandler))) - mux.Handle("/complete-task", rateLimitedHandler(http.HandlerFunc(controllers.CompleteTaskHandler))) - mux.Handle("/delete-task", rateLimitedHandler(http.HandlerFunc(controllers.DeleteTaskHandler))) + mux.Handle("/tasks", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.TasksHandler)))) + mux.Handle("/add-task", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.AddTaskHandler)))) + mux.Handle("/edit-task", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.EditTaskHandler)))) + mux.Handle("/modify-task", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.ModifyTaskHandler)))) + mux.Handle("/complete-task", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.CompleteTaskHandler)))) + mux.Handle("/delete-task", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.DeleteTaskHandler)))) mux.Handle("/sync/logs", rateLimitedHandler(http.HandlerFunc(controllers.SyncLogsHandler))) - mux.Handle("/complete-tasks", rateLimitedHandler(http.HandlerFunc(controllers.BulkCompleteTaskHandler))) - mux.Handle("/delete-tasks", rateLimitedHandler(http.HandlerFunc(controllers.BulkDeleteTaskHandler))) + mux.Handle("/complete-tasks", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.BulkCompleteTaskHandler)))) + mux.Handle("/delete-tasks", rateLimitedHandler(authMiddleware(http.HandlerFunc(controllers.BulkDeleteTaskHandler)))) mux.HandleFunc("/health", controllers.HealthCheckHandler) diff --git a/backend/middleware/auth.go b/backend/middleware/auth.go new file mode 100644 index 00000000..e9de1af9 --- /dev/null +++ b/backend/middleware/auth.go @@ -0,0 +1,28 @@ +package middleware + +import ( + "context" + "net/http" + + "github.com/gorilla/sessions" +) + +func RequireAuth(sessionStore *sessions.CookieStore) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, _ := sessionStore.Get(r, "session-name") + userInfo, ok := session.Values["user"].(map[string]interface{}) + if !ok || userInfo == nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + ctx := context.WithValue(r.Context(), "user", userInfo) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +func GetUserFromContext(r *http.Request) (map[string]interface{}, bool) { + userInfo, ok := r.Context().Value("user").(map[string]interface{}) + return userInfo, ok +} diff --git a/backend/middleware/auth_test.go b/backend/middleware/auth_test.go new file mode 100644 index 00000000..7fe112e1 --- /dev/null +++ b/backend/middleware/auth_test.go @@ -0,0 +1,159 @@ +package middleware + +import ( + "encoding/gob" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/sessions" +) + +func init() { + gob.Register(map[string]interface{}{}) +} + +func TestRequireAuth_ValidSession(t *testing.T) { + store := sessions.NewCookieStore([]byte("test-secret-key-32-bytes-long!")) + middleware := RequireAuth(store) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userInfo, ok := GetUserFromContext(r) + if !ok { + t.Fatal("Expected user info in context") + } + if userInfo["email"] != "test@example.com" { + t.Errorf("Expected email test@example.com, got %v", userInfo["email"]) + } + w.WriteHeader(http.StatusOK) + })) + + reqSetup := httptest.NewRequest("GET", "/test", nil) + rrSetup := httptest.NewRecorder() + + session, _ := store.Get(reqSetup, "session-name") + session.Values["user"] = map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid", + "encryption_secret": "test-secret", + } + err := session.Save(reqSetup, rrSetup) + if err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + req := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range rrSetup.Result().Cookies() { + req.AddCookie(cookie) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } +} + +func TestRequireAuth_NoSession(t *testing.T) { + store := sessions.NewCookieStore([]byte("test-secret-key-32-bytes-long!")) + middleware := RequireAuth(store) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Handler should not be called without valid session") + })) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", rr.Code) + } +} + +func TestRequireAuth_InvalidSession(t *testing.T) { + store := sessions.NewCookieStore([]byte("test-secret-key-32-bytes-long!")) + middleware := RequireAuth(store) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Handler should not be called with invalid session") + })) + + reqSetup := httptest.NewRequest("GET", "/test", nil) + rrSetup := httptest.NewRecorder() + + session, _ := store.Get(reqSetup, "session-name") + session.Values["user"] = "invalid-data" + session.Save(reqSetup, rrSetup) + + req := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range rrSetup.Result().Cookies() { + req.AddCookie(cookie) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Errorf("Expected status 401, got %d", rr.Code) + } +} + +func TestGetUserFromContext_ValidContext(t *testing.T) { + store := sessions.NewCookieStore([]byte("test-secret-key-32-bytes-long!")) + middleware := RequireAuth(store) + + handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userInfo, ok := GetUserFromContext(r) + if !ok { + t.Fatal("Expected user info in context") + } + if userInfo["email"] != "test@example.com" { + t.Errorf("Expected email test@example.com, got %v", userInfo["email"]) + } + if userInfo["uuid"] != "test-uuid" { + t.Errorf("Expected uuid test-uuid, got %v", userInfo["uuid"]) + } + w.WriteHeader(http.StatusOK) + })) + + reqSetup := httptest.NewRequest("GET", "/test", nil) + rrSetup := httptest.NewRecorder() + + session, _ := store.Get(reqSetup, "session-name") + session.Values["user"] = map[string]interface{}{ + "email": "test@example.com", + "uuid": "test-uuid", + "encryption_secret": "test-secret", + } + err := session.Save(reqSetup, rrSetup) + if err != nil { + t.Fatalf("Failed to save session: %v", err) + } + + req := httptest.NewRequest("GET", "/test", nil) + for _, cookie := range rrSetup.Result().Cookies() { + req.AddCookie(cookie) + } + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("Expected status 200, got %d", rr.Code) + } +} + +func TestGetUserFromContext_NoContext(t *testing.T) { + req := httptest.NewRequest("GET", "/test", nil) + + userInfo, ok := GetUserFromContext(req) + if ok { + t.Error("Expected no user info in context") + } + if userInfo != nil { + t.Error("Expected nil user info") + } +}