From 908c84b7663d50b7229e5fa34f7e9d24d2ce8b1e Mon Sep 17 00:00:00 2001 From: Ashok Kumar Ramakrishnan <83938949+ashok672@users.noreply.github.com> Date: Mon, 22 Dec 2025 18:03:57 -0800 Subject: [PATCH 1/6] Initial changes --- msal/oauth2cli/authcode.py | 62 +++++++++++---- msal/oauth2cli/oauth2.py | 11 ++- tests/test_authcode.py | 148 +++++++++++++++++++++++++++++++++++- tests/test_response_mode.py | 112 +++++++++++++++++++++++++++ 4 files changed, 313 insertions(+), 20 deletions(-) create mode 100644 tests/test_response_mode.py diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index ba266223..85e3f1e6 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -112,26 +112,56 @@ def do_GET(self): # For flexibility, we choose to not check self.path matching redirect_uri #assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP') qs = parse_qs(urlparse(self.path).query) - if qs.get('code') or qs.get("error"): # So, it is an auth response - auth_response = _qs2kv(qs) - logger.debug("Got auth response: %s", auth_response) - if self.server.auth_state and self.server.auth_state != auth_response.get("state"): - # OAuth2 successful and error responses contain state when it was used - # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 - self._send_full_response("State mismatch") # Possibly an attack - else: - template = (self.server.success_template - if "code" in qs else self.server.error_template) - if _is_html(template.template): - safe_data = _escape(auth_response) # Foiling an XSS attack - else: - safe_data = auth_response - self._send_full_response(template.safe_substitute(**safe_data)) - self.server.auth_response = auth_response # Set it now, after the response is likely sent + if qs.get('code') or qs.get("error"): # Auth response via query string is not allowed + logger.error("Received auth response via query string (GET request). " + "This is a security risk. Only form_post (POST) is supported.") + self._send_full_response( + "Authentication method not supported. " + "The application requires response_mode=form_post for security. " + "Please ensure your application registration uses form_post response mode.", + is_ok=False) else: self._send_full_response(self.server.welcome_page) # NOTE: Don't do self.server.shutdown() here. It'll halt the server. + def do_POST(self): + # Handle form_post response mode where auth code is sent via POST body + content_length = int(self.headers.get('Content-Length', 0)) + post_data = self.rfile.read(content_length).decode('utf-8') + try: + from urllib.parse import parse_qs as parse_qs_post + except ImportError: + from urlparse import parse_qs as parse_qs_post + + qs = parse_qs_post(post_data) + if qs.get('code') or qs.get('error'): # So, it is an auth response + auth_response = _qs2kv(qs) + logger.debug("Got auth response via POST: %s", auth_response) + self._process_auth_response(auth_response) + else: + self._send_full_response("Invalid POST request", is_ok=False) + + def _process_auth_response(self, auth_response): + """Process the auth response from either GET or POST request.""" + if self.server.auth_state and self.server.auth_state != auth_response.get("state"): + # OAuth2 successful and error responses contain state when it was used + # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1 + self._send_full_response("State mismatch") # Possibly an attack + else: + template = (self.server.success_template + if "code" in auth_response else self.server.error_template) + if _is_html(template.template): + safe_data = _escape(auth_response) # Foiling an XSS attack + else: + safe_data = dict(auth_response) # Make a copy to avoid mutating original + # Provide default values for common OAuth2 response fields + # to avoid showing literal placeholder text like "$error_description" + safe_data.setdefault("error", "") + safe_data.setdefault("error_description", "") + safe_data.setdefault("error_uri", "") + self._send_full_response(template.safe_substitute(**safe_data)) + self.server.auth_response = auth_response # Set it now, after the response is likely sent + def _send_full_response(self, body, is_ok=True): self.send_response(200 if is_ok else 400) content_type = 'text/html' if _is_html(body) else 'text/plain' diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index ef32ceaa..7875c4d7 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -176,7 +176,16 @@ def _build_auth_request_params(self, response_type, **kwargs): response_type = self._stringify(response_type) params = {'client_id': self.client_id, 'response_type': response_type} - params.update(kwargs) # Note: None values will override params + # Strictly enforce form_post for security - query string is not allowed + params['response_mode'] = 'form_post' + if 'response_mode' in kwargs and kwargs['response_mode'] != 'form_post': + import warnings + warnings.warn( + "response_mode='{}' is not supported for security reasons. " + "Using form_post instead. Query string transmission of authorization " + "codes is insecure and has been disabled.".format(kwargs['response_mode']), + UserWarning) + params.update({k: v for k, v in kwargs.items() if k != 'response_mode'}) # Exclude response_mode from kwargs params = {k: v for k, v in params.items() if v is not None} # clean up if params.get('scope'): params['scope'] = self._stringify(params['scope']) diff --git a/tests/test_authcode.py b/tests/test_authcode.py index 329cafd8..fd38c294 100644 --- a/tests/test_authcode.py +++ b/tests/test_authcode.py @@ -26,17 +26,159 @@ def test_no_two_concurrent_receivers_can_listen_on_same_port(self): pass def test_template_should_escape_input(self): + """Test that POST request with HTML in error is properly escaped""" with AuthCodeReceiver() as receiver: receiver._scheduled_actions = [( # Injection happens here when the port is known 1, # Delay it until the receiver is activated by get_auth_response() lambda: self.assertEqual( "<tag>foo</tag>", - requests.get("http://localhost:{}?error=foo".format( - receiver.get_port())).text, - "Unsafe data in HTML should be escaped", + requests.post( + "http://localhost:{}".format(receiver.get_port()), + data={"error": "foo"}, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ).text, ))] receiver.get_auth_response( # Starts server and hang until timeout timeout=3, error_template="$error", ) + def test_get_request_with_auth_code_is_rejected(self): + """Test that GET request with auth code is rejected for security""" + with AuthCodeReceiver() as receiver: + test_code = "test_auth_code_12345" + test_state = "test_state_67890" + receiver._scheduled_actions = [( + 1, + lambda: self._verify_get_rejection( + receiver.get_port(), + code=test_code, + state=test_state + ) + )] + result = receiver.get_auth_response( + timeout=3, + state=test_state, + ) + # Should not receive auth response via GET + self.assertIsNone(result) + + def _verify_get_rejection(self, port, **params): + """Helper to verify GET requests with auth codes are rejected""" + try: + from urllib.parse import urlencode + except ImportError: + from urllib import urlencode + + response = requests.get( + "http://localhost:{}?{}".format(port, urlencode(params)) + ) + # Verify error message about unsupported method + self.assertIn("not supported", response.text.lower()) + self.assertEqual(response.status_code, 400) + + def test_post_request_with_auth_code(self): + """Test that POST request with auth code is handled correctly (form_post response mode)""" + with AuthCodeReceiver() as receiver: + test_code = "test_auth_code_12345" + test_state = "test_state_67890" + receiver._scheduled_actions = [( + 1, + lambda: self._send_post_auth_response( + receiver.get_port(), + code=test_code, + state=test_state + ) + )] + result = receiver.get_auth_response( + timeout=3, + state=test_state, + success_template="Success: Got code $code", + ) + self.assertIsNotNone(result) + self.assertEqual(result.get("code"), test_code) + self.assertEqual(result.get("state"), test_state) + + def test_post_request_with_error(self): + """Test that POST request with error is handled correctly""" + with AuthCodeReceiver() as receiver: + test_error = "access_denied" + test_error_description = "User denied access" + receiver._scheduled_actions = [( + 1, + lambda: self._send_post_auth_response( + receiver.get_port(), + error=test_error, + error_description=test_error_description + ) + )] + result = receiver.get_auth_response( + timeout=3, + error_template="Error: $error - $error_description", + ) + self.assertIsNotNone(result) + self.assertEqual(result.get("error"), test_error) + self.assertEqual(result.get("error_description"), test_error_description) + + def test_post_request_state_mismatch(self): + """Test that POST request with mismatched state is rejected""" + with AuthCodeReceiver() as receiver: + test_code = "test_auth_code_12345" + wrong_state = "wrong_state" + expected_state = "expected_state" + receiver._scheduled_actions = [( + 1, + lambda: self._send_post_auth_response( + receiver.get_port(), + code=test_code, + state=wrong_state + ) + )] + result = receiver.get_auth_response( + timeout=3, + state=expected_state, + ) + # When state mismatches, the response should not be set + self.assertIsNone(result) + + def test_post_request_escapes_html(self): + """Test that POST request with HTML in error is properly escaped""" + with AuthCodeReceiver() as receiver: + malicious_error = "" + receiver._scheduled_actions = [( + 1, + lambda: self._verify_post_escaping(receiver.get_port(), malicious_error) + )] + receiver.get_auth_response( + timeout=3, + error_template="$error", + ) + + def _send_post_auth_response(self, port, **params): + """Helper to send POST request with auth response""" + try: + from urllib.parse import urlencode + except ImportError: + from urllib import urlencode + + response = requests.post( + "http://localhost:{}".format(port), + data=params, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + return response + + def _verify_post_escaping(self, port, malicious_error): + """Helper to verify HTML escaping in POST requests""" + response = self._send_post_auth_response(port, error=malicious_error) + # Verify that the malicious script is escaped + self.assertIn("<script>", response.text) + self.assertNotIn("