Skip to content

Commit f631cd4

Browse files
committed
fix: add type annotations to MockInstrumenter for pyright
Added full type hints to MockInstrumenter class to resolve pyright type checking errors. This ensures the test helper class properly implements the Instrumenter protocol with correct types.
1 parent 00a63c8 commit f631cd4

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

tests/shared/test_instrumentation.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,23 @@
11
"""Tests for instrumentation interface."""
22

3+
from typing import Any
4+
35
import pytest
46

57
from mcp.shared.instrumentation import NoOpInstrumenter, get_default_instrumenter
8+
from mcp.types import RequestId
69

710

811
class MockInstrumenter:
912
"""Track calls to instrumentation hooks for testing."""
1013

11-
def __init__(self):
12-
self.calls = []
14+
def __init__(self) -> None:
15+
self.calls: list[dict[str, Any]] = []
1316

14-
def on_request_start(self, request_id, request_type, method=None, **metadata):
15-
call = {
17+
def on_request_start(
18+
self, request_id: RequestId, request_type: str, method: str | None = None, **metadata: Any
19+
) -> dict[str, Any]:
20+
call: dict[str, Any] = {
1621
"hook": "on_request_start",
1722
"request_id": request_id,
1823
"request_type": request_type,
@@ -23,7 +28,15 @@ def on_request_start(self, request_id, request_type, method=None, **metadata):
2328
# Return the call itself as a token for testing
2429
return call
2530

26-
def on_request_end(self, token, request_id, request_type, success, duration_seconds=None, **metadata):
31+
def on_request_end(
32+
self,
33+
token: Any,
34+
request_id: RequestId,
35+
request_type: str,
36+
success: bool,
37+
duration_seconds: float | None = None,
38+
**metadata: Any,
39+
) -> None:
2740
self.calls.append(
2841
{
2942
"hook": "on_request_end",
@@ -36,7 +49,9 @@ def on_request_end(self, token, request_id, request_type, success, duration_seco
3649
}
3750
)
3851

39-
def on_error(self, token, request_id, error, error_type, **metadata):
52+
def on_error(
53+
self, token: Any, request_id: RequestId | None, error: Exception, error_type: str, **metadata: Any
54+
) -> None:
4055
self.calls.append(
4156
{
4257
"hook": "on_error",
@@ -48,11 +63,11 @@ def on_error(self, token, request_id, error, error_type, **metadata):
4863
}
4964
)
5065

51-
def get_calls_by_hook(self, hook_name):
66+
def get_calls_by_hook(self, hook_name: str) -> list[dict[str, Any]]:
5267
"""Get all calls to a specific hook."""
5368
return [call for call in self.calls if call["hook"] == hook_name]
5469

55-
def get_calls_by_request_id(self, request_id):
70+
def get_calls_by_request_id(self, request_id: RequestId) -> list[dict[str, Any]]:
5671
"""Get all calls for a specific request_id."""
5772
return [call for call in self.calls if call.get("request_id") == request_id]
5873

0 commit comments

Comments
 (0)