11"""Tests for instrumentation interface."""
22
3+ from typing import Any
4+
35import pytest
46
57from mcp .shared .instrumentation import NoOpInstrumenter , get_default_instrumenter
8+ from mcp .types import RequestId
69
710
811class MockInstrumenter :
912 """Track calls to instrumentation hooks for testing."""
1013
11- def __init__ (self ):
12- self .calls = []
14+ def __init__ (self ) -> None :
15+ self .calls : list [ dict [ str , Any ]] = []
1316
14- def on_request_start (self , request_id , request_type , method = None , ** metadata ):
15- call = {
17+ def on_request_start (
18+ self , request_id : RequestId , request_type : str , method : str | None = None , ** metadata : Any
19+ ) -> dict [str , Any ]:
20+ call : dict [str , Any ] = {
1621 "hook" : "on_request_start" ,
1722 "request_id" : request_id ,
1823 "request_type" : request_type ,
@@ -23,7 +28,15 @@ def on_request_start(self, request_id, request_type, method=None, **metadata):
2328 # Return the call itself as a token for testing
2429 return call
2530
26- def on_request_end (self , token , request_id , request_type , success , duration_seconds = None , ** metadata ):
31+ def on_request_end (
32+ self ,
33+ token : Any ,
34+ request_id : RequestId ,
35+ request_type : str ,
36+ success : bool ,
37+ duration_seconds : float | None = None ,
38+ ** metadata : Any ,
39+ ) -> None :
2740 self .calls .append (
2841 {
2942 "hook" : "on_request_end" ,
@@ -36,7 +49,9 @@ def on_request_end(self, token, request_id, request_type, success, duration_seco
3649 }
3750 )
3851
39- def on_error (self , token , request_id , error , error_type , ** metadata ):
52+ def on_error (
53+ self , token : Any , request_id : RequestId | None , error : Exception , error_type : str , ** metadata : Any
54+ ) -> None :
4055 self .calls .append (
4156 {
4257 "hook" : "on_error" ,
@@ -48,11 +63,11 @@ def on_error(self, token, request_id, error, error_type, **metadata):
4863 }
4964 )
5065
51- def get_calls_by_hook (self , hook_name ) :
66+ def get_calls_by_hook (self , hook_name : str ) -> list [ dict [ str , Any ]] :
5267 """Get all calls to a specific hook."""
5368 return [call for call in self .calls if call ["hook" ] == hook_name ]
5469
55- def get_calls_by_request_id (self , request_id ) :
70+ def get_calls_by_request_id (self , request_id : RequestId ) -> list [ dict [ str , Any ]] :
5671 """Get all calls for a specific request_id."""
5772 return [call for call in self .calls if call .get ("request_id" ) == request_id ]
5873
0 commit comments