From 908c84b7663d50b7229e5fa34f7e9d24d2ce8b1e Mon Sep 17 00:00:00 2001
From: Ashok Kumar Ramakrishnan <83938949+ashok672@users.noreply.github.com>
Date: Mon, 22 Dec 2025 18:03:57 -0800
Subject: [PATCH 1/6] Initial changes
---
msal/oauth2cli/authcode.py | 62 +++++++++++----
msal/oauth2cli/oauth2.py | 11 ++-
tests/test_authcode.py | 148 +++++++++++++++++++++++++++++++++++-
tests/test_response_mode.py | 112 +++++++++++++++++++++++++++
4 files changed, 313 insertions(+), 20 deletions(-)
create mode 100644 tests/test_response_mode.py
diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py
index ba266223..85e3f1e6 100644
--- a/msal/oauth2cli/authcode.py
+++ b/msal/oauth2cli/authcode.py
@@ -112,26 +112,56 @@ def do_GET(self):
# For flexibility, we choose to not check self.path matching redirect_uri
#assert self.path.startswith('/THE_PATH_REGISTERED_BY_THE_APP')
qs = parse_qs(urlparse(self.path).query)
- if qs.get('code') or qs.get("error"): # So, it is an auth response
- auth_response = _qs2kv(qs)
- logger.debug("Got auth response: %s", auth_response)
- if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
- # OAuth2 successful and error responses contain state when it was used
- # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
- self._send_full_response("State mismatch") # Possibly an attack
- else:
- template = (self.server.success_template
- if "code" in qs else self.server.error_template)
- if _is_html(template.template):
- safe_data = _escape(auth_response) # Foiling an XSS attack
- else:
- safe_data = auth_response
- self._send_full_response(template.safe_substitute(**safe_data))
- self.server.auth_response = auth_response # Set it now, after the response is likely sent
+ if qs.get('code') or qs.get("error"): # Auth response via query string is not allowed
+ logger.error("Received auth response via query string (GET request). "
+ "This is a security risk. Only form_post (POST) is supported.")
+ self._send_full_response(
+ "Authentication method not supported. "
+ "The application requires response_mode=form_post for security. "
+ "Please ensure your application registration uses form_post response mode.",
+ is_ok=False)
else:
self._send_full_response(self.server.welcome_page)
# NOTE: Don't do self.server.shutdown() here. It'll halt the server.
+ def do_POST(self):
+ # Handle form_post response mode where auth code is sent via POST body
+ content_length = int(self.headers.get('Content-Length', 0))
+ post_data = self.rfile.read(content_length).decode('utf-8')
+ try:
+ from urllib.parse import parse_qs as parse_qs_post
+ except ImportError:
+ from urlparse import parse_qs as parse_qs_post
+
+ qs = parse_qs_post(post_data)
+ if qs.get('code') or qs.get('error'): # So, it is an auth response
+ auth_response = _qs2kv(qs)
+ logger.debug("Got auth response via POST: %s", auth_response)
+ self._process_auth_response(auth_response)
+ else:
+ self._send_full_response("Invalid POST request", is_ok=False)
+
+ def _process_auth_response(self, auth_response):
+ """Process the auth response from either GET or POST request."""
+ if self.server.auth_state and self.server.auth_state != auth_response.get("state"):
+ # OAuth2 successful and error responses contain state when it was used
+ # https://www.rfc-editor.org/rfc/rfc6749#section-4.2.2.1
+ self._send_full_response("State mismatch") # Possibly an attack
+ else:
+ template = (self.server.success_template
+ if "code" in auth_response else self.server.error_template)
+ if _is_html(template.template):
+ safe_data = _escape(auth_response) # Foiling an XSS attack
+ else:
+ safe_data = dict(auth_response) # Make a copy to avoid mutating original
+ # Provide default values for common OAuth2 response fields
+ # to avoid showing literal placeholder text like "$error_description"
+ safe_data.setdefault("error", "")
+ safe_data.setdefault("error_description", "")
+ safe_data.setdefault("error_uri", "")
+ self._send_full_response(template.safe_substitute(**safe_data))
+ self.server.auth_response = auth_response # Set it now, after the response is likely sent
+
def _send_full_response(self, body, is_ok=True):
self.send_response(200 if is_ok else 400)
content_type = 'text/html' if _is_html(body) else 'text/plain'
diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py
index ef32ceaa..7875c4d7 100644
--- a/msal/oauth2cli/oauth2.py
+++ b/msal/oauth2cli/oauth2.py
@@ -176,7 +176,16 @@ def _build_auth_request_params(self, response_type, **kwargs):
response_type = self._stringify(response_type)
params = {'client_id': self.client_id, 'response_type': response_type}
- params.update(kwargs) # Note: None values will override params
+ # Strictly enforce form_post for security - query string is not allowed
+ params['response_mode'] = 'form_post'
+ if 'response_mode' in kwargs and kwargs['response_mode'] != 'form_post':
+ import warnings
+ warnings.warn(
+ "response_mode='{}' is not supported for security reasons. "
+ "Using form_post instead. Query string transmission of authorization "
+ "codes is insecure and has been disabled.".format(kwargs['response_mode']),
+ UserWarning)
+ params.update({k: v for k, v in kwargs.items() if k != 'response_mode'}) # Exclude response_mode from kwargs
params = {k: v for k, v in params.items() if v is not None} # clean up
if params.get('scope'):
params['scope'] = self._stringify(params['scope'])
diff --git a/tests/test_authcode.py b/tests/test_authcode.py
index 329cafd8..fd38c294 100644
--- a/tests/test_authcode.py
+++ b/tests/test_authcode.py
@@ -26,17 +26,159 @@ def test_no_two_concurrent_receivers_can_listen_on_same_port(self):
pass
def test_template_should_escape_input(self):
+ """Test that POST request with HTML in error is properly escaped"""
with AuthCodeReceiver() as receiver:
receiver._scheduled_actions = [( # Injection happens here when the port is known
1, # Delay it until the receiver is activated by get_auth_response()
lambda: self.assertEqual(
"<tag>foo</tag>",
- requests.get("http://localhost:{}?error=foo".format(
- receiver.get_port())).text,
- "Unsafe data in HTML should be escaped",
+ requests.post(
+ "http://localhost:{}".format(receiver.get_port()),
+ data={"error": "foo"},
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
+ ).text,
))]
receiver.get_auth_response( # Starts server and hang until timeout
timeout=3,
error_template="$error",
)
+ def test_get_request_with_auth_code_is_rejected(self):
+ """Test that GET request with auth code is rejected for security"""
+ with AuthCodeReceiver() as receiver:
+ test_code = "test_auth_code_12345"
+ test_state = "test_state_67890"
+ receiver._scheduled_actions = [(
+ 1,
+ lambda: self._verify_get_rejection(
+ receiver.get_port(),
+ code=test_code,
+ state=test_state
+ )
+ )]
+ result = receiver.get_auth_response(
+ timeout=3,
+ state=test_state,
+ )
+ # Should not receive auth response via GET
+ self.assertIsNone(result)
+
+ def _verify_get_rejection(self, port, **params):
+ """Helper to verify GET requests with auth codes are rejected"""
+ try:
+ from urllib.parse import urlencode
+ except ImportError:
+ from urllib import urlencode
+
+ response = requests.get(
+ "http://localhost:{}?{}".format(port, urlencode(params))
+ )
+ # Verify error message about unsupported method
+ self.assertIn("not supported", response.text.lower())
+ self.assertEqual(response.status_code, 400)
+
+ def test_post_request_with_auth_code(self):
+ """Test that POST request with auth code is handled correctly (form_post response mode)"""
+ with AuthCodeReceiver() as receiver:
+ test_code = "test_auth_code_12345"
+ test_state = "test_state_67890"
+ receiver._scheduled_actions = [(
+ 1,
+ lambda: self._send_post_auth_response(
+ receiver.get_port(),
+ code=test_code,
+ state=test_state
+ )
+ )]
+ result = receiver.get_auth_response(
+ timeout=3,
+ state=test_state,
+ success_template="Success: Got code $code",
+ )
+ self.assertIsNotNone(result)
+ self.assertEqual(result.get("code"), test_code)
+ self.assertEqual(result.get("state"), test_state)
+
+ def test_post_request_with_error(self):
+ """Test that POST request with error is handled correctly"""
+ with AuthCodeReceiver() as receiver:
+ test_error = "access_denied"
+ test_error_description = "User denied access"
+ receiver._scheduled_actions = [(
+ 1,
+ lambda: self._send_post_auth_response(
+ receiver.get_port(),
+ error=test_error,
+ error_description=test_error_description
+ )
+ )]
+ result = receiver.get_auth_response(
+ timeout=3,
+ error_template="Error: $error - $error_description",
+ )
+ self.assertIsNotNone(result)
+ self.assertEqual(result.get("error"), test_error)
+ self.assertEqual(result.get("error_description"), test_error_description)
+
+ def test_post_request_state_mismatch(self):
+ """Test that POST request with mismatched state is rejected"""
+ with AuthCodeReceiver() as receiver:
+ test_code = "test_auth_code_12345"
+ wrong_state = "wrong_state"
+ expected_state = "expected_state"
+ receiver._scheduled_actions = [(
+ 1,
+ lambda: self._send_post_auth_response(
+ receiver.get_port(),
+ code=test_code,
+ state=wrong_state
+ )
+ )]
+ result = receiver.get_auth_response(
+ timeout=3,
+ state=expected_state,
+ )
+ # When state mismatches, the response should not be set
+ self.assertIsNone(result)
+
+ def test_post_request_escapes_html(self):
+ """Test that POST request with HTML in error is properly escaped"""
+ with AuthCodeReceiver() as receiver:
+ malicious_error = ""
+ receiver._scheduled_actions = [(
+ 1,
+ lambda: self._verify_post_escaping(receiver.get_port(), malicious_error)
+ )]
+ receiver.get_auth_response(
+ timeout=3,
+ error_template="$error",
+ )
+
+ def _send_post_auth_response(self, port, **params):
+ """Helper to send POST request with auth response"""
+ try:
+ from urllib.parse import urlencode
+ except ImportError:
+ from urllib import urlencode
+
+ response = requests.post(
+ "http://localhost:{}".format(port),
+ data=params,
+ headers={'Content-Type': 'application/x-www-form-urlencoded'}
+ )
+ return response
+
+ def _verify_post_escaping(self, port, malicious_error):
+ """Helper to verify HTML escaping in POST requests"""
+ response = self._send_post_auth_response(port, error=malicious_error)
+ # Verify that the malicious script is escaped
+ self.assertIn("<script>", response.text)
+ self.assertNotIn("