diff --git a/dev/README.md b/dev/README.md new file mode 100644 index 00000000..e8c10764 --- /dev/null +++ b/dev/README.md @@ -0,0 +1,6 @@ +This directory contains tools to aid the developers of the Microsoft 365 Agents SDK for Python. + +### `benchmark` + +This folder contains benchmarking utilities built in Python to send concurrent requests +to an agent. \ No newline at end of file diff --git a/dev/benchmark/README.md b/dev/benchmark/README.md new file mode 100644 index 00000000..c64b118d --- /dev/null +++ b/dev/benchmark/README.md @@ -0,0 +1,107 @@ +A simple benchmarking tool. + +## Benchmark Python Environment Manual Setup (Windows) + +Currently a version of this tool that spawns async workers/coroutines instead of +concurrent threads is not supported, so if you use a "normal" (non free-threaded) version +of Python, you will be running with the global interpreter lock (GIL). + +Note: This may or may not incur significant changes in performance over using +free-threaded concurrent tests or async workers, depending on the test scenario. + +Install any Python version >= 3.9. Check with: + +```bash +python --version +``` + +Then, set up and activate the virtual environment with: + +```bash +python -m venv venv +. ./venv/Scripts/activate +pip install -r requirements.txt +``` + +To activate the virtual environment, use: + +```bash +. ./venv/Scripts/activate +``` + +To deactivate it, you may use: + +```bash +deactivate +``` + +## Benchmark Python Environment Setup (Windows) - Free Threaded Python + +Traditionally, most Python versions have a global interpreter lock (GIL) which prevents +more than 1 thread to run at the same time. With 3.13, there are free-threaded versions +of Python which allow one to bypass this constraint. This section walks through how +to do that on Windows. Use PowerShell. + +Based on: https://docs.python.org/3/using/windows.html# + +Go to `Microsoft Store` and install `Python Install Manager` and follow the instructions +presented. You may have to make certain changes to alias used by your machine (that +should be guided by the installation process). + +Based on: https://docs.python.org/3/whatsnew/3.13.html#free-threaded-cpython + +In PowerShell, install the free-threaded version of Python of your choice. In this guide +we will install `3.14t`: + +```bash +py install 3.14t +``` + +Then, set up and activate the virtual environment with: + +```bash +python3.14t -m venv venv +. ./venv/Scripts/activate +pip install -r requirements.txt +``` + +To activate the virtual environment, use: + +```bash +. ./venv/Scripts/activate +``` + +To deactivate it, you may use: + +```bash +deactivate +``` + +## Benchmark Configuration + +If you open the `env.template` file, you will see three environmental variables to define: + +```bash +TENANT_ID= +APP_ID= +APP_SECRET= +``` + +For `APP_ID` use the app Id of your ABS resource. For `APP_SECRET` set it to a secret +for the App Registration resource tied to your ABS resource. Finally, the `TENANT_ID` +variable should be set to the tenant Id of your ABS resource. + +These settings are used to generate valid tokens that are sent and validated by the +agent you are trying to run. + +## Usage + +Running these tests requires you to have the agent running in a separate process. You +may open a separate PowerShell window or VSCode window and run your agent there. + +To run the basic payload sending stress test (our only implemented test so far), use: + +```bash +. ./venv/Scripts/activate # activate the virtual environment if you haven't already +python -m src.main --num_workers=... +``` diff --git a/dev/benchmark/env.template b/dev/benchmark/env.template new file mode 100644 index 00000000..ea7473b2 --- /dev/null +++ b/dev/benchmark/env.template @@ -0,0 +1,3 @@ +TENANT_ID= +APP_ID= +APP_SECRET= \ No newline at end of file diff --git a/dev/benchmark/payload.json b/dev/benchmark/payload.json new file mode 100644 index 00000000..28ac2a8c --- /dev/null +++ b/dev/benchmark/payload.json @@ -0,0 +1,19 @@ +{ + "channelId": "msteams", + "serviceUrl": "http://localhost:49231/_connector", + "delivery_mode": "expectReplies", + "recipient": { + "id": "00000000-0000-0000-0000-00000000000011", + "name": "Test Bot" + }, + "conversation": { + "id": "personal-chat-id", + "conversationType": "personal", + "tenantId": "00000000-0000-0000-0000-0000000000001" + }, + "from": { + "id": "user-id-0", + "aadObjectId": "00000000-0000-0000-0000-0000000000020" + }, + "type": "message" +} \ No newline at end of file diff --git a/dev/benchmark/requirements.txt b/dev/benchmark/requirements.txt new file mode 100644 index 00000000..ea3bd96d --- /dev/null +++ b/dev/benchmark/requirements.txt @@ -0,0 +1,4 @@ +microsoft-agents-activity +microsoft-agents-hosting-core +click +azure-identity \ No newline at end of file diff --git a/dev/benchmark/src/__init__.py b/dev/benchmark/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dev/benchmark/src/aggregated_results.py b/dev/benchmark/src/aggregated_results.py new file mode 100644 index 00000000..b1edaa5e --- /dev/null +++ b/dev/benchmark/src/aggregated_results.py @@ -0,0 +1,51 @@ +from .executor import ExecutionResult + + +class AggregatedResults: + """Class to analyze execution time results.""" + + def __init__(self, results: list[ExecutionResult]): + self._results = results + + self.average = sum(r.duration for r in results) / len(results) if results else 0 + self.min = min((r.duration for r in results), default=0) + self.max = max((r.duration for r in results), default=0) + self.success_count = sum(1 for r in results if r.success) + self.failure_count = len(results) - self.success_count + self.total_time = sum(r.duration for r in results) + + def display(self, start_time: float, end_time: float): + """Display aggregated results.""" + print() + print("---- Aggregated Results ----") + print() + print(f"Average Time: {self.average:.4f} seconds") + print(f"Min Time: {self.min:.4f} seconds") + print(f"Max Time: {self.max:.4f} seconds") + print() + print(f"Success Rate: {self.success_count} / {len(self._results)}") + print() + print(f"Total Time: {end_time - start_time} seconds") + print("----------------------------") + print() + + def display_timeline(self): + """Display timeline of individual execution results.""" + print() + print("---- Execution Timeline ----") + print( + "Each '.' represents 1 second of successful execution. So a line like '...' is a success that took 3 seconds (rounded up), 'x' represents a failure." + ) + print() + for result in sorted(self._results, key=lambda r: r.exe_id): + c = "." if result.success else "x" + if c == ".": + duration = int(round(result.duration)) + for _ in range(1 + duration): + print(c, end="") + print() + else: + print(c) + + print("----------------------------") + print() diff --git a/dev/benchmark/src/config.py b/dev/benchmark/src/config.py new file mode 100644 index 00000000..403fbafc --- /dev/null +++ b/dev/benchmark/src/config.py @@ -0,0 +1,23 @@ +import os +from dotenv import load_dotenv + +load_dotenv() + + +class BenchmarkConfig: + """Configuration class for benchmark settings.""" + + TENANT_ID: str = "" + APP_ID: str = "" + APP_SECRET: str = "" + AGENT_API_URL: str = "" + + @classmethod + def load_from_env(cls) -> None: + """Loads configuration values from environment variables.""" + cls.TENANT_ID = os.environ.get("TENANT_ID", "") + cls.APP_ID = os.environ.get("APP_ID", "") + cls.APP_SECRET = os.environ.get("APP_SECRET", "") + cls.AGENT_URL = os.environ.get( + "AGENT_API_URL", "http://localhost:3978/api/messages" + ) diff --git a/dev/benchmark/src/executor/__init__.py b/dev/benchmark/src/executor/__init__.py new file mode 100644 index 00000000..b01cfb1c --- /dev/null +++ b/dev/benchmark/src/executor/__init__.py @@ -0,0 +1,11 @@ +from .coroutine_executor import CoroutineExecutor +from .execution_result import ExecutionResult +from .executor import Executor +from .thread_executor import ThreadExecutor + +__all__ = [ + "CoroutineExecutor", + "ExecutionResult", + "Executor", + "ThreadExecutor", +] diff --git a/dev/benchmark/src/executor/coroutine_executor.py b/dev/benchmark/src/executor/coroutine_executor.py new file mode 100644 index 00000000..5d03ff19 --- /dev/null +++ b/dev/benchmark/src/executor/coroutine_executor.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +from typing import Callable, Awaitable, Any + +from .executor import Executor +from .execution_result import ExecutionResult + + +class CoroutineExecutor(Executor): + """An executor that runs asynchronous functions using asyncio.""" + + def run( + self, func: Callable[[], Awaitable[Any]], num_workers: int = 1 + ) -> list[ExecutionResult]: + """Run the given asynchronous function using the specified number of coroutines. + + :param func: An asynchronous function to be executed. + :param num_workers: The number of coroutines to use. + """ + + async def gather(): + return await asyncio.gather( + *[self.run_func(i, func) for i in range(num_workers)] + ) + + return asyncio.run(gather()) diff --git a/dev/benchmark/src/executor/execution_result.py b/dev/benchmark/src/executor/execution_result.py new file mode 100644 index 00000000..ae72cabb --- /dev/null +++ b/dev/benchmark/src/executor/execution_result.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from typing import Any, Optional +from dataclasses import dataclass + + +@dataclass +class ExecutionResult: + """Class to represent the result of an execution.""" + + exe_id: int + + start_time: float + end_time: float + + result: Any = None + error: Optional[Exception] = None + + @property + def success(self) -> bool: + """Indicate whether the execution was successful.""" + return self.error is None + + @property + def duration(self) -> float: + """Calculate the duration of the execution, in seconds.""" + return self.end_time - self.start_time diff --git a/dev/benchmark/src/executor/executor.py b/dev/benchmark/src/executor/executor.py new file mode 100644 index 00000000..688c1cfb --- /dev/null +++ b/dev/benchmark/src/executor/executor.py @@ -0,0 +1,49 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from datetime import datetime, timezone +from abc import ABC, abstractmethod +from typing import Callable, Awaitable, Any + +from .execution_result import ExecutionResult + + +class Executor(ABC): + """Protocol for executing asynchronous functions concurrently.""" + + async def run_func( + self, exe_id: int, func: Callable[[], Awaitable[Any]] + ) -> ExecutionResult: + """Run the given asynchronous function. + + :param exe_id: An identifier for the execution instance. + :param func: An asynchronous function to be executed. + """ + + start_time = datetime.now(timezone.utc).timestamp() + try: + result = await func() + return ExecutionResult( + exe_id=exe_id, + result=result, + start_time=start_time, + end_time=datetime.now(timezone.utc).timestamp(), + ) + except Exception as e: # pylint: disable=broad-except + return ExecutionResult( + exe_id=exe_id, + error=e, + start_time=start_time, + end_time=datetime.now(timezone.utc).timestamp(), + ) + + @abstractmethod + def run( + self, func: Callable[[], Awaitable[Any]], num_workers: int = 1 + ) -> list[ExecutionResult]: + """Run the given asynchronous function using the specified number of workers. + + :param func: An asynchronous function to be executed. + :param num_workers: The number of concurrent workers to use. + """ + raise NotImplementedError("This method should be implemented by subclasses.") diff --git a/dev/benchmark/src/executor/thread_executor.py b/dev/benchmark/src/executor/thread_executor.py new file mode 100644 index 00000000..ee3ce532 --- /dev/null +++ b/dev/benchmark/src/executor/thread_executor.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import logging +import asyncio +from typing import Callable, Awaitable, Any +from concurrent.futures import ThreadPoolExecutor + +from .executor import Executor +from .execution_result import ExecutionResult + +logger = logging.getLogger(__name__) + + +class ThreadExecutor(Executor): + """An executor that runs asynchronous functions using multiple threads.""" + + def run( + self, func: Callable[[], Awaitable[Any]], num_workers: int = 1 + ) -> list[ExecutionResult]: + """Run the given asynchronous function using the specified number of threads. + + :param func: An asynchronous function to be executed. + :param num_workers: The number of concurrent threads to use. + """ + + def _func(exe_id: int) -> ExecutionResult: + return asyncio.run(self.run_func(exe_id, func)) + + results: list[ExecutionResult] = [] + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(_func, i) for i in range(num_workers)] + for future in futures: + results.append(future.result()) + + return results diff --git a/dev/benchmark/src/generate_token.py b/dev/benchmark/src/generate_token.py new file mode 100644 index 00000000..19c0e93e --- /dev/null +++ b/dev/benchmark/src/generate_token.py @@ -0,0 +1,34 @@ +import requests +from .config import BenchmarkConfig + +URL = "https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/token" + + +def generate_token(app_id: str, app_secret: str) -> str: + """Generate a token using the provided app credentials.""" + + url = URL.format(tenant_id=BenchmarkConfig.TENANT_ID) + + res = requests.post( + url, + headers={ + "Content-Type": "application/x-www-form-urlencoded", + }, + data={ + "grant_type": "client_credentials", + "client_id": app_id, + "client_secret": app_secret, + "scope": f"{app_id}/.default", + }, + timeout=10, + ) + return res.json().get("access_token") + + +def generate_token_from_env() -> str: + """Generates a token using environment variables.""" + app_id = BenchmarkConfig.APP_ID + app_secret = BenchmarkConfig.APP_SECRET + if not app_id or not app_secret: + raise ValueError("APP_ID and APP_SECRET must be set in the BenchmarkConfig.") + return generate_token(app_id, app_secret) diff --git a/dev/benchmark/src/main.py b/dev/benchmark/src/main.py new file mode 100644 index 00000000..af9e177e --- /dev/null +++ b/dev/benchmark/src/main.py @@ -0,0 +1,49 @@ +import json +import logging +from datetime import datetime, timezone + +import click + +from .payload_sender import create_payload_sender +from .executor import Executor, CoroutineExecutor, ThreadExecutor +from .aggregated_results import AggregatedResults +from .config import BenchmarkConfig + +LOG_FORMAT = "%(asctime)s: %(message)s" +logging.basicConfig(format=LOG_FORMAT, level=logging.INFO, datefmt="%H:%M:%S") + +BenchmarkConfig.load_from_env() + + +@click.command() +@click.option( + "--payload_path", "-p", default="./payload.json", help="Path to the payload file." +) +@click.option("--num_workers", "-n", default=1, help="Number of workers to use.") +@click.option( + "--async_mode", + "-a", + is_flag=True, + help="Run coroutine workers rather than thread workers.", +) +def main(payload_path: str, num_workers: int, async_mode: bool): + """Main function to run the benchmark.""" + + with open(payload_path, "r", encoding="utf-8") as f: + payload = json.load(f) + + func = create_payload_sender(payload) + + executor: Executor = CoroutineExecutor() if async_mode else ThreadExecutor() + + start_time = datetime.now(timezone.utc).timestamp() + results = executor.run(func, num_workers=num_workers) + end_time = datetime.now(timezone.utc).timestamp() + + agg = AggregatedResults(results) + agg.display(start_time, end_time) + agg.display_timeline() + + +if __name__ == "__main__": + main() # pylint: disable=no-value-for-parameter diff --git a/dev/benchmark/src/payload_sender.py b/dev/benchmark/src/payload_sender.py new file mode 100644 index 00000000..a27f87c0 --- /dev/null +++ b/dev/benchmark/src/payload_sender.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import asyncio +import requests +from typing import Callable, Awaitable, Any + +from .config import BenchmarkConfig +from .generate_token import generate_token_from_env + + +def create_payload_sender( + payload: dict[str, Any], timeout: int = 60 +) -> Callable[..., Awaitable[Any]]: + """Create a payload sender function that sends the given payload to the configured endpoint. + + :param payload: The payload to be sent. + :param timeout: The timeout for the request in seconds. + :return: A callable that sends the payload when invoked. + """ + + token = generate_token_from_env() + endpoint = BenchmarkConfig.AGENT_URL + headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} + + async def payload_sender() -> Any: + response = await asyncio.to_thread( + requests.post, endpoint, headers=headers, json=payload, timeout=timeout + ) + return response.content + + return payload_sender diff --git a/dev_dependencies.txt b/dev_dependencies.txt index f1e775ef..86b3d4c8 100644 --- a/dev_dependencies.txt +++ b/dev_dependencies.txt @@ -1,4 +1,5 @@ pytest pytest-asyncio pytest-mock -pre-commit \ No newline at end of file +pre-commit +click \ No newline at end of file diff --git a/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py b/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py index aed16156..f9486cc4 100644 --- a/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py +++ b/libraries/microsoft-agents-authentication-msal/microsoft_agents/authentication/msal/msal_auth.py @@ -3,6 +3,7 @@ from __future__ import annotations +import asyncio import logging import jwt from typing import Optional @@ -39,12 +40,29 @@ def __str__(self): return f"Agentic blueprint id: {agentic_blueprint_id}" +async def _async_acquire_token_for_client(msal_auth_client, *args, **kwargs): + """MSAL in Python does not support async, so we use asyncio.to_thread to run it in + a separate thread and avoid blocking the event loop + """ + return await asyncio.to_thread( + lambda: msal_auth_client.acquire_token_for_client(*args, **kwargs) + ) + + class MsalAuth(AccessTokenProviderBase): _client_credential_cache = None def __init__(self, msal_configuration: AgentAuthConfiguration): + """Initializes the MsalAuth class with the given configuration. + + :param msal_configuration: The MSAL authentication configuration. Assumed to + not be mutated after being passed in. + :type msal_configuration: AgentAuthConfiguration + """ + self._msal_configuration = msal_configuration + self._msal_auth_client = None logger.debug( f"Initializing MsalAuth with configuration: {self._msal_configuration}" ) @@ -60,17 +78,17 @@ async def get_access_token( raise ValueError("Invalid instance URL") local_scopes = self._resolve_scopes_list(instance_uri, scopes) - msal_auth_client = self._create_client_application() + self._create_client_application() - if isinstance(msal_auth_client, ManagedIdentityClient): + if isinstance(self._msal_auth_client, ManagedIdentityClient): logger.info("Acquiring token using Managed Identity Client.") - auth_result_payload = msal_auth_client.acquire_token_for_client( - resource=resource_url + auth_result_payload = await _async_acquire_token_for_client( + self._msal_auth_client, resource=resource_url ) - elif isinstance(msal_auth_client, ConfidentialClientApplication): + elif isinstance(self._msal_auth_client, ConfidentialClientApplication): logger.info("Acquiring token using Confidential Client Application.") - auth_result_payload = msal_auth_client.acquire_token_for_client( - scopes=local_scopes + auth_result_payload = await _async_acquire_token_for_client( + self._msal_auth_client, scopes=local_scopes ) else: auth_result_payload = None @@ -79,6 +97,7 @@ async def get_access_token( if not res: logger.error("Failed to acquire token for resource %s", auth_result_payload) raise ValueError(f"Failed to acquire token. {str(auth_result_payload)}") + return res async def acquire_token_on_behalf_of( @@ -91,19 +110,23 @@ async def acquire_token_on_behalf_of( :return: The access token as a string. """ - msal_auth_client = self._create_client_application() - if isinstance(msal_auth_client, ManagedIdentityClient): + self._create_client_application() + if isinstance(self._msal_auth_client, ManagedIdentityClient): logger.error( "Attempted on-behalf-of flow with Managed Identity authentication." ) raise NotImplementedError( "On-behalf-of flow is not supported with Managed Identity authentication." ) - elif isinstance(msal_auth_client, ConfidentialClientApplication): + elif isinstance(self._msal_auth_client, ConfidentialClientApplication): # TODO: Handling token error / acquisition failed - token = msal_auth_client.acquire_token_on_behalf_of( - user_assertion=user_assertion, scopes=scopes + # MSAL in Python does not support async, so we use asyncio.to_thread to run it in + # a separate thread and avoid blocking the event loop + token = await asyncio.to_thread( + lambda: self._msal_auth_client.acquire_token_on_behalf_of( + scopes=scopes, user_assertion=user_assertion + ) ) if "access_token" not in token: @@ -115,19 +138,19 @@ async def acquire_token_on_behalf_of( return token["access_token"] logger.error( - f"On-behalf-of flow is not supported with the current authentication type: {msal_auth_client.__class__.__name__}" + f"On-behalf-of flow is not supported with the current authentication type: {self._msal_auth_client.__class__.__name__}" ) raise NotImplementedError( - f"On-behalf-of flow is not supported with the current authentication type: {msal_auth_client.__class__.__name__}" + f"On-behalf-of flow is not supported with the current authentication type: {self._msal_auth_client.__class__.__name__}" ) - def _create_client_application( - self, - ) -> ManagedIdentityClient | ConfidentialClientApplication: - msal_auth_client = None + def _create_client_application(self) -> None: + + if self._msal_auth_client: + return if self._msal_configuration.AUTH_TYPE == AuthTypes.user_managed_identity: - msal_auth_client = ManagedIdentityClient( + self._msal_auth_client = ManagedIdentityClient( UserAssignedManagedIdentity( client_id=self._msal_configuration.CLIENT_ID ), @@ -135,7 +158,7 @@ def _create_client_application( ) elif self._msal_configuration.AUTH_TYPE == AuthTypes.system_managed_identity: - msal_auth_client = ManagedIdentityClient( + self._msal_auth_client = ManagedIdentityClient( SystemAssignedManagedIdentity(), http_client=Session(), ) @@ -176,14 +199,12 @@ def _create_client_application( ) raise NotImplementedError("Authentication type not supported") - msal_auth_client = ConfidentialClientApplication( + self._msal_auth_client = ConfidentialClientApplication( client_id=self._msal_configuration.CLIENT_ID, authority=authority, client_credential=self._client_credential_cache, ) - return msal_auth_client - @staticmethod def _uri_validator(url_str: str) -> tuple[bool, Optional[URI]]: try: @@ -228,12 +249,13 @@ async def get_agentic_application_token( "Attempting to get agentic application token from agent_app_instance_id %s", agent_app_instance_id, ) - msal_auth_client = self._create_client_application() + self._create_client_application() - if isinstance(msal_auth_client, ConfidentialClientApplication): + if isinstance(self._msal_auth_client, ConfidentialClientApplication): # https://github.dev/AzureAD/microsoft-authentication-library-for-dotnet - auth_result_payload = msal_auth_client.acquire_token_for_client( + auth_result_payload = await _async_acquire_token_for_client( + self._msal_auth_client, ["api://AzureAdTokenExchange/.default"], data={"fmi_path": agent_app_instance_id}, ) @@ -284,8 +306,8 @@ async def get_agentic_instance_token( client_credential={"client_assertion": agent_token_result}, ) - agentic_instance_token = instance_app.acquire_token_for_client( - ["api://AzureAdTokenExchange/.default"] + agentic_instance_token = await _async_acquire_token_for_client( + instance_app, ["api://AzureAdTokenExchange/.default"] ) if not agentic_instance_token: @@ -363,7 +385,10 @@ async def get_agentic_user_token( agent_app_instance_id, agentic_user_id, ) - auth_result_payload = instance_app.acquire_token_for_client( + # MSAL in Python does not support async, so we use asyncio.to_thread to run it in + # a separate thread and avoid blocking the event loop + auth_result_payload = await _async_acquire_token_for_client( + instance_app, scopes, data={ "user_id": agentic_user_id, diff --git a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py index d28618cd..d3a2384d 100644 --- a/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py +++ b/libraries/microsoft-agents-hosting-aiohttp/microsoft_agents/hosting/aiohttp/jwt_authorization_middleware.py @@ -13,11 +13,12 @@ async def jwt_authorization_middleware(request: Request, handler): auth_config: AgentAuthConfiguration = request.app["agent_configuration"] token_validator = JwtTokenValidator(auth_config) auth_header = request.headers.get("Authorization") + if auth_header: # Extract the token from the Authorization header token = auth_header.split(" ")[1] try: - claims = token_validator.validate_token(token) + claims = await token_validator.validate_token(token) request["claims_identity"] = claims except ValueError as e: print(f"JWT validation error: {e}") @@ -44,7 +45,7 @@ async def wrapper(request): # Extract the token from the Authorization header token = auth_header.split(" ")[1] try: - claims = token_validator.validate_token(token) + claims = await token_validator.validate_token(token) request["claims_identity"] = claims except ValueError as e: print(f"JWT validation error: {e}") diff --git a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py index 399e101f..9069e81d 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py +++ b/libraries/microsoft-agents-hosting-core/microsoft_agents/hosting/core/authorization/jwt_token_validator.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import asyncio import logging import jwt @@ -16,10 +17,10 @@ class JwtTokenValidator: def __init__(self, configuration: AgentAuthConfiguration): self.configuration = configuration - def validate_token(self, token: str) -> ClaimsIdentity: + async def validate_token(self, token: str) -> ClaimsIdentity: logger.debug("Validating JWT token.") - key = self._get_public_key_or_secret(token) + key = await self._get_public_key_or_secret(token) decoded_token = jwt.decode( token, key=key, @@ -39,7 +40,7 @@ def get_anonymous_claims(self) -> ClaimsIdentity: logger.debug("Returning anonymous claims identity.") return ClaimsIdentity({}, False, authentication_type="Anonymous") - def _get_public_key_or_secret(self, token: str) -> PyJWK: + async def _get_public_key_or_secret(self, token: str) -> PyJWK: header = get_unverified_header(token) unverified_payload: dict = decode(token, options={"verify_signature": False}) @@ -50,5 +51,6 @@ def _get_public_key_or_secret(self, token: str) -> PyJWK: ) jwks_client = PyJWKClient(jwksUri) - key = jwks_client.get_signing_key(header["kid"]) + key = await asyncio.to_thread(jwks_client.get_signing_key, header["kid"]) + return key diff --git a/tests/_common/testing_objects/mocks/mock_msal_auth.py b/tests/_common/testing_objects/mocks/mock_msal_auth.py index f9a046b7..c9e9eb09 100644 --- a/tests/_common/testing_objects/mocks/mock_msal_auth.py +++ b/tests/_common/testing_objects/mocks/mock_msal_auth.py @@ -25,7 +25,8 @@ def __init__( ) self.mock_client = mock_client - self._create_client_application = mocker.Mock(return_value=self.mock_client) + def _create_client_application(self) -> None: + self._msal_auth_client = self.mock_client def agentic_mock_class_MsalAuth( diff --git a/tests/authentication_msal/test_msal_auth.py b/tests/authentication_msal/test_msal_auth.py index 7198d190..4368c63e 100644 --- a/tests/authentication_msal/test_msal_auth.py +++ b/tests/authentication_msal/test_msal_auth.py @@ -52,9 +52,6 @@ async def test_acquire_token_on_behalf_of_managed_identity(self, mocker): @pytest.mark.asyncio async def test_acquire_token_on_behalf_of_confidential(self, mocker): mock_auth = MockMsalAuth(mocker, ConfidentialClientApplication) - mock_auth._create_client_application = mocker.Mock( - return_value=mock_auth.mock_client - ) token = await mock_auth.acquire_token_on_behalf_of( scopes=["test-scope"], user_assertion="test-assertion"