diff --git a/lambdas/handlers/patch_document_review_handler.py b/lambdas/handlers/patch_document_review_handler.py index 2d0ca35bc7..56e27e6357 100644 --- a/lambdas/handlers/patch_document_review_handler.py +++ b/lambdas/handlers/patch_document_review_handler.py @@ -11,9 +11,11 @@ from utils.decorators.override_error_check import override_error_check from utils.decorators.set_audit_arg import set_request_context_for_logging from utils.decorators.validate_patient_id import validate_patient_id +from utils.exceptions import OdsErrorException from utils.lambda_exceptions import UpdateDocumentReviewException from utils.lambda_handler_utils import validate_review_path_parameters from utils.lambda_response import ApiGatewayResponse +from utils.ods_utils import extract_ods_code_from_request_context from utils.request_context import request_context logger = LoggingService(__name__) @@ -46,12 +48,10 @@ def lambda_handler(event, context): document_id, document_version = validate_review_path_parameters(event) - reviewer_ods_code = request_context.authorization.get( - "selected_organisation", {} - ).get("org_ods_code") + try: + reviewer_ods_code = extract_ods_code_from_request_context() - if not reviewer_ods_code: - logger.error("Missing ODS code in authorization token") + except OdsErrorException: raise UpdateDocumentReviewException( 401, LambdaError.DocumentReferenceUnauthorised ) diff --git a/lambdas/handlers/search_document_review_handler.py b/lambdas/handlers/search_document_review_handler.py index 5c248e14a6..8b49af5d68 100644 --- a/lambdas/handlers/search_document_review_handler.py +++ b/lambdas/handlers/search_document_review_handler.py @@ -15,7 +15,7 @@ from utils.exceptions import OdsErrorException from utils.lambda_exceptions import DocumentReviewLambdaException from utils.lambda_response import ApiGatewayResponse -from utils.request_context import request_context +from utils.ods_utils import extract_ods_code_from_request_context logger = LoggingService(__name__) @@ -56,7 +56,7 @@ def lambda_handler(event, context): logger.info("Feature flag not enabled, event will not be processed") raise DocumentReviewLambdaException(403, LambdaError.FeatureFlagDisabled) - ods_code = get_ods_code_from_request_context() + ods_code = extract_ods_code_from_request_context() params = parse_querystring_parameters(event) @@ -86,22 +86,6 @@ def lambda_handler(event, context): ).create_api_gateway_response() -def get_ods_code_from_request_context(): - logger.info("Getting ODS code from request context") - try: - ods_code = request_context.authorization.get("selected_organisation", {}).get( - "org_ods_code" - ) - if not ods_code: - raise OdsErrorException() - - return ods_code - - except AttributeError as e: - logger.error(e) - raise DocumentReviewLambdaException(401, LambdaError.DocumentReviewMissingODS) - - def parse_querystring_parameters(event): logger.info("Parsing query string parameters.") params = event.get("queryStringParameters", {}) diff --git a/lambdas/tests/unit/handlers/test_patch_document_review_handler.py b/lambdas/tests/unit/handlers/test_patch_document_review_handler.py index 721b65088c..cf8279bc63 100644 --- a/lambdas/tests/unit/handlers/test_patch_document_review_handler.py +++ b/lambdas/tests/unit/handlers/test_patch_document_review_handler.py @@ -4,6 +4,8 @@ from enums.document_review_status import DocumentReviewStatus from enums.lambda_error import LambdaError from handlers.patch_document_review_handler import lambda_handler +from tests.unit.conftest import TEST_CURRENT_GP_ODS +from utils.exceptions import OdsErrorException from utils.lambda_exceptions import UpdateDocumentReviewException from utils.lambda_response import ApiGatewayResponse @@ -130,25 +132,18 @@ def mocked_service(set_env, mocker): @pytest.fixture def mock_authorization(mocker): - mocked_context = mocker.MagicMock() - mocked_context.authorization = { - "selected_organisation": {"org_ods_code": TEST_REVIEWER_ODS_CODE}, - } - yield mocker.patch( - "handlers.patch_document_review_handler.request_context", mocked_context + return mocker.patch( + "handlers.patch_document_review_handler.extract_ods_code_from_request_context" ) @pytest.fixture def mock_missing_authorization(mocker): - mocked_context = mocker.MagicMock() - mocked_context.authorization = { - "selected_organisation": {"org_ods_code": None}, - } - yield mocker.patch( - "handlers.patch_document_review_handler.request_context", mocked_context + mock_auth = mocker.patch( + "handlers.patch_document_review_handler.extract_ods_code_from_request_context" ) - + mock_auth.side_effect = OdsErrorException() + yield mock_auth def test_lambda_handler_returns_200_when_document_review_approved( mocked_service, @@ -159,6 +154,7 @@ def test_lambda_handler_returns_200_when_document_review_approved( mock_upload_document_iteration_3_enabled, ): mocked_service.update_document_review.return_value = None + mock_authorization.return_value = TEST_CURRENT_GP_ODS expected = ApiGatewayResponse(200, "", "PATCH").create_api_gateway_response() @@ -186,6 +182,7 @@ def test_lambda_handler_returns_200_when_document_review_rejected( mock_upload_document_iteration_3_enabled, ): mocked_service.update_document_review.return_value = None + mock_authorization.return_value = TEST_CURRENT_GP_ODS expected = ApiGatewayResponse(200, "", "PATCH").create_api_gateway_response() @@ -209,6 +206,8 @@ def test_lambda_handler_returns_400_when_patient_id_missing( mock_authorization, mock_upload_document_iteration_3_enabled, ): + mock_authorization.return_value = TEST_CURRENT_GP_ODS + actual = lambda_handler(missing_patient_id_event, context) expected = ApiGatewayResponse( diff --git a/lambdas/tests/unit/handlers/test_search_document_review_handler.py b/lambdas/tests/unit/handlers/test_search_document_review_handler.py index dfc20af0d3..c4e17dff54 100644 --- a/lambdas/tests/unit/handlers/test_search_document_review_handler.py +++ b/lambdas/tests/unit/handlers/test_search_document_review_handler.py @@ -4,7 +4,6 @@ import pytest from enums.lambda_error import LambdaError from handlers.search_document_review_handler import ( - get_ods_code_from_request_context, lambda_handler, parse_querystring_parameters, ) @@ -13,6 +12,7 @@ from tests.unit.helpers.data.search_document_review.dynamo_response import ( MOCK_DOCUMENT_REVIEW_SEARCH_RESPONSE, ) +from utils.exceptions import OdsErrorException from utils.lambda_exceptions import DocumentReviewLambdaException from utils.lambda_response import ApiGatewayResponse @@ -91,43 +91,25 @@ def event_with_all_params(): @pytest.fixture() -def mocked_request_context_with_ods(mocker): - mocked_context = mocker.MagicMock() - mocked_context.authorization = { - "selected_organisation": {"org_ods_code": TEST_CURRENT_GP_ODS}, - } - yield mocker.patch( - "handlers.search_document_review_handler.request_context", mocked_context +def mocked_extract_ods_with_ods_code(mocker): + mock_extract = mocker.patch( + "handlers.search_document_review_handler.extract_ods_code_from_request_context" ) + mock_extract.return_value = TEST_CURRENT_GP_ODS + yield mock_extract @pytest.fixture() -def mocked_request_context_without_ods(mocker): - mocked_context = mocker.MagicMock() - mocked_context.authorization = { - "selected_organisation": {"org_ods_code": ""}, - } - yield mocker.patch( - "handlers.search_document_review_handler.request_context", mocked_context +def mocked_extract_ods_code_error(mocker): + mocked_extract = mocker.patch( + "handlers.search_document_review_handler.extract_ods_code_from_request_context", ) - - -def test_get_ods_code_from_request(mocked_request_context_with_ods): - - assert get_ods_code_from_request_context() == TEST_CURRENT_GP_ODS - - -def test_get_ods_code_from_request_throws_exception_no_auth(mocker): - mocker.patch("handlers.search_document_review_handler.request_context", {}) - - with pytest.raises(DocumentReviewLambdaException) as e: - get_ods_code_from_request_context() - - assert e.value.status_code == 401 + mocked_extract.side_effect = OdsErrorException() + yield mocked_extract def test_handler_returns_401_response_no_ods_code_in_request_context( - set_env, context, event, mock_service, mocked_request_context_without_ods + set_env, context, event, mock_service, mocked_extract_ods_code_error ): body = json.dumps( { @@ -178,7 +160,7 @@ def test_process_request_called_with_correct_arguments( context, set_env, event_with_all_params, - mocked_request_context_with_ods, + mocked_extract_ods_with_ods_code, ): lambda_handler(event_with_all_params, context) @@ -191,7 +173,7 @@ def test_process_request_called_with_correct_arguments( def test_handler_returns_empty_list_of_references_no_dynamo_results_no_limit_in_query_params( - mock_service, context, set_env, mocked_request_context_with_ods, event + mock_service, context, set_env, mocked_extract_ods_with_ods_code, event ): mock_service.process_request.return_value = ([], None) @@ -213,7 +195,7 @@ def test_handler_returns_empty_list_of_references_no_dynamo_results_no_limit_in_ def test_handler_returns_list_of_references_last_evaluated_key_more_results_available( - mock_service, context, set_env, mocked_request_context_with_ods, event_with_limit + mock_service, context, set_env, mocked_extract_ods_with_ods_code, event_with_limit ): references = [ @@ -248,7 +230,7 @@ def test_handler_returns_list_of_references_last_evaluated_key_more_results_avai def test_handler_returns_list_of_references_no_limit_passed( - mock_service, context, set_env, mocked_request_context_with_ods, event + mock_service, context, set_env, mocked_extract_ods_with_ods_code, event ): references = [ DocumentUploadReviewReference.model_validate(item).model_dump_camel_case( @@ -278,7 +260,7 @@ def test_handler_returns_list_of_references_no_limit_passed( def test_handler_returns_500_response_error_raised( - mock_service, context, set_env, mocked_request_context_with_ods, event_with_limit + mock_service, context, set_env, mocked_extract_ods_with_ods_code, event_with_limit ): mock_service.process_request.side_effect = DocumentReviewLambdaException( diff --git a/lambdas/tests/unit/utils/test_ods_utils.py b/lambdas/tests/unit/utils/test_ods_utils.py index 8fac994889..d0e5703c8f 100644 --- a/lambdas/tests/unit/utils/test_ods_utils.py +++ b/lambdas/tests/unit/utils/test_ods_utils.py @@ -5,7 +5,7 @@ from utils.ods_utils import ( extract_ods_code_from_request_context, extract_ods_role_code_with_r_prefix_from_role_codes_string, - is_ods_code_active, + is_ods_code_active )