Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions backend/controllers/add_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@ var GlobalJobQueue *JobQueue
// @Router /add-task [post]
func AddTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
email, uuid, encryptionSecret, err := GetSessionCredentials(r)
if err != nil {
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest)
return
}
defer r.Body.Close()
// fmt.Printf("Raw request body: %s\n", string(body))

var requestBody models.AddTaskRequestBody

Expand All @@ -40,9 +45,14 @@ func AddTaskHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, fmt.Sprintf("error decoding request body: %v", err), http.StatusBadRequest)
return
}
email := requestBody.Email
encryptionSecret := requestBody.EncryptionSecret
uuid := requestBody.UUID

if requestBody.Email != "" || requestBody.UUID != "" {
if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil {
http.Error(w, "Invalid credentials", http.StatusForbidden)
return
}
}

description := requestBody.Description
project := requestBody.Project
priority := requestBody.Priority
Expand All @@ -61,7 +71,6 @@ func AddTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}

// Validate dependencies
if err := utils.ValidateDependencies(depends, ""); err != nil {
http.Error(w, fmt.Sprintf("Invalid dependencies: %v", err), http.StatusBadRequest)
return
Expand Down
42 changes: 42 additions & 0 deletions backend/controllers/auth_helpers.go
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

implemented creds validation helpers

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package controllers

import (
"errors"
"net/http"
)

func ValidateUserCredentials(r *http.Request, requestEmail, requestUUID string) error {
userInfo, ok := r.Context().Value("user").(map[string]interface{})
if !ok {
return errors.New("user context not found")
}

sessionEmail, emailOk := userInfo["email"].(string)
sessionUUID, uuidOk := userInfo["uuid"].(string)

if !emailOk || !uuidOk {
return errors.New("invalid user session data")
}

if sessionEmail != requestEmail || sessionUUID != requestUUID {
return errors.New("credentials do not match authenticated user")
}
return nil
}

func GetSessionCredentials(r *http.Request) (email, uuid, encryptionSecret string, err error) {
userInfo, ok := r.Context().Value("user").(map[string]interface{})
if !ok {
return "", "", "", errors.New("user context not found")
}

email, emailOk := userInfo["email"].(string)
uuid, uuidOk := userInfo["uuid"].(string)
encryptionSecret, secretOk := userInfo["encryption_secret"].(string)

if !emailOk || !uuidOk || !secretOk {
return "", "", "", errors.New("incomplete user session data")
}

return email, uuid, encryptionSecret, nil
}
167 changes: 167 additions & 0 deletions backend/controllers/auth_helpers_test.go
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added tests for helper functions

Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package controllers

import (
"context"
"net/http/httptest"
"testing"
)

func TestValidateUserCredentials_MatchingCredentials(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": "test@example.com",
"uuid": "test-uuid-123",
"encryption_secret": "test-secret",
})
req = req.WithContext(ctx)

err := ValidateUserCredentials(req, "test@example.com", "test-uuid-123")
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
}

func TestValidateUserCredentials_MismatchedEmail(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": "test@example.com",
"uuid": "test-uuid-123",
"encryption_secret": "test-secret",
})
req = req.WithContext(ctx)

err := ValidateUserCredentials(req, "wrong@example.com", "test-uuid-123")
if err == nil {
t.Error("Expected error for mismatched email")
}
if err.Error() != "credentials do not match authenticated user" {
t.Errorf("Expected specific error message, got %v", err)
}
}

func TestValidateUserCredentials_MismatchedUUID(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": "test@example.com",
"uuid": "test-uuid-123",
"encryption_secret": "test-secret",
})
req = req.WithContext(ctx)

err := ValidateUserCredentials(req, "test@example.com", "wrong-uuid")
if err == nil {
t.Error("Expected error for mismatched UUID")
}
if err.Error() != "credentials do not match authenticated user" {
t.Errorf("Expected specific error message, got %v", err)
}
}

func TestValidateUserCredentials_NoContext(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)

err := ValidateUserCredentials(req, "test@example.com", "test-uuid-123")
if err == nil {
t.Error("Expected error for missing context")
}
if err.Error() != "user context not found" {
t.Errorf("Expected 'user context not found', got %v", err)
}
}

func TestValidateUserCredentials_InvalidSessionData(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": 12345,
"uuid": "test-uuid-123",
})
req = req.WithContext(ctx)

err := ValidateUserCredentials(req, "test@example.com", "test-uuid-123")
if err == nil {
t.Error("Expected error for invalid session data")
}
if err.Error() != "invalid user session data" {
t.Errorf("Expected 'invalid user session data', got %v", err)
}
}

func TestGetSessionCredentials_ValidSession(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": "test@example.com",
"uuid": "test-uuid-123",
"encryption_secret": "test-secret-456",
})
req = req.WithContext(ctx)

email, uuid, secret, err := GetSessionCredentials(req)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if email != "test@example.com" {
t.Errorf("Expected email test@example.com, got %s", email)
}
if uuid != "test-uuid-123" {
t.Errorf("Expected uuid test-uuid-123, got %s", uuid)
}
if secret != "test-secret-456" {
t.Errorf("Expected secret test-secret-456, got %s", secret)
}
}

func TestGetSessionCredentials_NoContext(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)

email, uuid, secret, err := GetSessionCredentials(req)
if err == nil {
t.Error("Expected error for missing context")
}
if err.Error() != "user context not found" {
t.Errorf("Expected 'user context not found', got %v", err)
}
if email != "" || uuid != "" || secret != "" {
t.Error("Expected empty strings for credentials")
}
}

func TestGetSessionCredentials_IncompleteData(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": "test@example.com",
"uuid": "test-uuid-123",
})
req = req.WithContext(ctx)

email, uuid, secret, err := GetSessionCredentials(req)
if err == nil {
t.Error("Expected error for incomplete session data")
}
if err.Error() != "incomplete user session data" {
t.Errorf("Expected 'incomplete user session data', got %v", err)
}
if email != "" || uuid != "" || secret != "" {
t.Error("Expected empty strings for credentials")
}
}

func TestGetSessionCredentials_InvalidDataTypes(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
ctx := context.WithValue(req.Context(), "user", map[string]interface{}{
"email": 12345,
"uuid": true,
"encryption_secret": []string{"invalid"},
})
req = req.WithContext(ctx)

email, uuid, secret, err := GetSessionCredentials(req)
if err == nil {
t.Error("Expected error for invalid data types")
}
if err.Error() != "incomplete user session data" {
t.Errorf("Expected 'incomplete user session data', got %v", err)
}
if email != "" || uuid != "" || secret != "" {
t.Error("Expected empty strings for credentials")
}
}
22 changes: 13 additions & 9 deletions backend/controllers/complete_task.go
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using session creds for bulk operations

Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@ import (
// @Router /complete-task [post]
func CompleteTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
email, uuid, encryptionSecret, err := GetSessionCredentials(r)
if err != nil {
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest)
return
}
defer r.Body.Close()

// fmt.Printf("Raw request body: %s\n", string(body))

var requestBody models.CompleteTaskRequestBody

err = json.Unmarshal(body, &requestBody)
Expand All @@ -39,20 +43,20 @@ func CompleteTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}

email := requestBody.Email
encryptionSecret := requestBody.EncryptionSecret
uuid := requestBody.UUID
if requestBody.Email != "" || requestBody.UUID != "" {
if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil {
http.Error(w, "Invalid credentials", http.StatusForbidden)
return
}
}

taskuuid := requestBody.TaskUUID

if taskuuid == "" {
http.Error(w, "taskuuid is required", http.StatusBadRequest)
return
}

// if err := tw.CompleteTaskInTaskwarrior(email, encryptionSecret, uuid, taskuuid); err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
logStore := models.GetLogStore()
job := Job{
Name: "Complete Task",
Expand Down
17 changes: 13 additions & 4 deletions backend/controllers/complete_tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ func BulkCompleteTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}

email, uuid, encryptionSecret, err := GetSessionCredentials(r)
if err != nil {
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest)
Expand All @@ -40,9 +46,13 @@ func BulkCompleteTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}

email := requestBody.Email
encryptionSecret := requestBody.EncryptionSecret
uuid := requestBody.UUID
if requestBody.Email != "" || requestBody.UUID != "" {
if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil {
http.Error(w, "Invalid credentials", http.StatusForbidden)
return
}
}

taskUUIDs := requestBody.TaskUUIDs

if len(taskUUIDs) == 0 {
Expand All @@ -52,7 +62,6 @@ func BulkCompleteTaskHandler(w http.ResponseWriter, r *http.Request) {

logStore := models.GetLogStore()

// Create a *single* job for all UUIDs
job := Job{
Name: "Bulk Complete Tasks",
Execute: func() error {
Expand Down
20 changes: 13 additions & 7 deletions backend/controllers/delete_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ import (
// @Router /delete-task [post]
func DeleteTaskHandler(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodPost {
email, uuid, encryptionSecret, err := GetSessionCredentials(r)
if err != nil {
http.Error(w, "Authentication required", http.StatusUnauthorized)
return
}

body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("error reading request body: %v", err), http.StatusBadRequest)
Expand All @@ -37,20 +43,20 @@ func DeleteTaskHandler(w http.ResponseWriter, r *http.Request) {
return
}

email := requestBody.Email
encryptionSecret := requestBody.EncryptionSecret
uuid := requestBody.UUID
if requestBody.Email != "" || requestBody.UUID != "" {
if err := ValidateUserCredentials(r, requestBody.Email, requestBody.UUID); err != nil {
http.Error(w, "Invalid credentials", http.StatusForbidden)
return
}
}

taskuuid := requestBody.TaskUUID

if taskuuid == "" {
http.Error(w, "taskuuid is required", http.StatusBadRequest)
return
}

// if err := tw.DeleteTaskInTaskwarrior(email, encryptionSecret, uuid, taskuuid); err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
logStore := models.GetLogStore()
job := Job{
Name: "Delete Task",
Expand Down
Loading
Loading