|
7 | 7 | import json |
8 | 8 | import multiprocessing |
9 | 9 | import socket |
| 10 | +import threading |
10 | 11 | import time |
11 | 12 | from collections.abc import Generator |
12 | 13 | from typing import Any |
| 14 | +from urllib.parse import ParseResult, urlparse |
13 | 15 |
|
14 | 16 | import anyio |
15 | 17 | import httpx |
@@ -329,6 +331,18 @@ def basic_server_port() -> int: |
329 | 331 | return s.getsockname()[1] |
330 | 332 |
|
331 | 333 |
|
| 334 | +@pytest.fixture |
| 335 | +def proxy_port() -> int: |
| 336 | + with socket.socket() as s: |
| 337 | + s.bind(("127.0.0.1", 0)) |
| 338 | + return s.getsockname()[1] |
| 339 | + |
| 340 | + |
| 341 | +@pytest.fixture |
| 342 | +def proxy_url(proxy_port: int) -> str: |
| 343 | + return f"http://127.0.0.1:{proxy_port}" |
| 344 | + |
| 345 | + |
332 | 346 | @pytest.fixture |
333 | 347 | def json_server_port() -> int: |
334 | 348 | """Find an available port for the JSON response server.""" |
@@ -1600,3 +1614,123 @@ async def bad_client(): |
1600 | 1614 | assert isinstance(result, InitializeResult) |
1601 | 1615 | tools = await session.list_tools() |
1602 | 1616 | assert tools.tools |
| 1617 | + |
| 1618 | + |
| 1619 | +@pytest.fixture |
| 1620 | +def proxy_server(basic_server_url: str, proxy_port: int) -> Generator[str, None, None]: |
| 1621 | + BUFFER_SIZE: int = 4096 |
| 1622 | + parsed: ParseResult = urlparse(basic_server_url) |
| 1623 | + server_host: str = parsed.hostname or "127.0.0.1" |
| 1624 | + server_port: int = parsed.port or 80 |
| 1625 | + |
| 1626 | + def run_proxy(stop_event: threading.Event) -> None: |
| 1627 | + def handle_client(client_socket: socket.socket) -> None: |
| 1628 | + server_socket: socket.socket | None = None |
| 1629 | + try: |
| 1630 | + request: bytes = client_socket.recv(BUFFER_SIZE) |
| 1631 | + if not request: |
| 1632 | + return |
| 1633 | + |
| 1634 | + first_line, rest = request.split(b"\r\n", 1) |
| 1635 | + parts: list[str] = first_line.decode().split(" ") |
| 1636 | + if len(parts) != 3: |
| 1637 | + return # malformed |
| 1638 | + method, url, version = parts |
| 1639 | + |
| 1640 | + parsed_url: ParseResult = urlparse(url) |
| 1641 | + if parsed_url.scheme and parsed_url.netloc: |
| 1642 | + # absolute-form (proxy request) |
| 1643 | + path: str = parsed_url.path or "/" |
| 1644 | + if parsed_url.query: |
| 1645 | + path += "?" + parsed_url.query |
| 1646 | + else: |
| 1647 | + path = url |
| 1648 | + |
| 1649 | + fixed_first_line: bytes = f"{method} {path} {version}".encode() |
| 1650 | + new_request: bytes = b"\r\n".join([fixed_first_line, rest]) |
| 1651 | + |
| 1652 | + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 1653 | + server_socket.connect((server_host, server_port)) |
| 1654 | + server_socket.sendall(new_request) |
| 1655 | + print(f"[PROXY] Forwarding {method} {path} -> {server_host}:{server_port}") |
| 1656 | + |
| 1657 | + def forward(src: socket.socket, dst: socket.socket, direction: str) -> None: |
| 1658 | + while not stop_event.is_set(): |
| 1659 | + try: |
| 1660 | + data: bytes = src.recv(BUFFER_SIZE) |
| 1661 | + if not data: |
| 1662 | + break |
| 1663 | + dst.sendall(data) |
| 1664 | + except (ConnectionResetError, OSError): |
| 1665 | + break |
| 1666 | + |
| 1667 | + t1 = threading.Thread( |
| 1668 | + target=forward, args=(client_socket, server_socket, "client->server"), daemon=True |
| 1669 | + ) |
| 1670 | + t2 = threading.Thread( |
| 1671 | + target=forward, args=(server_socket, client_socket, "server->client"), daemon=True |
| 1672 | + ) |
| 1673 | + t1.start() |
| 1674 | + t2.start() |
| 1675 | + t1.join() |
| 1676 | + t2.join() |
| 1677 | + finally: |
| 1678 | + try: |
| 1679 | + client_socket.close() |
| 1680 | + except Exception: |
| 1681 | + pass |
| 1682 | + if server_socket: |
| 1683 | + try: |
| 1684 | + server_socket.close() |
| 1685 | + except Exception: |
| 1686 | + pass |
| 1687 | + print("[PROXY] Closed sockets") |
| 1688 | + |
| 1689 | + proxy_socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) |
| 1690 | + proxy_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 1691 | + proxy_socket.bind(("127.0.0.1", proxy_port)) |
| 1692 | + proxy_socket.listen(5) |
| 1693 | + |
| 1694 | + print(f"[PROXY] Listening on 127.0.0.1:{proxy_port}, forwarding to {server_host}:{server_port}") |
| 1695 | + |
| 1696 | + while not stop_event.is_set(): |
| 1697 | + proxy_socket.settimeout(1.0) |
| 1698 | + try: |
| 1699 | + client_socket, addr = proxy_socket.accept() |
| 1700 | + print(f"[PROXY] Accepted connection from {addr}") |
| 1701 | + threading.Thread(target=handle_client, args=(client_socket,), daemon=True).start() |
| 1702 | + except TimeoutError: |
| 1703 | + continue |
| 1704 | + except OSError: |
| 1705 | + break |
| 1706 | + |
| 1707 | + proxy_socket.close() |
| 1708 | + print("[PROXY] Proxy stopped") |
| 1709 | + |
| 1710 | + stop_event: threading.Event = threading.Event() |
| 1711 | + thread = threading.Thread(target=run_proxy, args=(stop_event,), daemon=True) |
| 1712 | + thread.start() |
| 1713 | + |
| 1714 | + proxy_url: str = f"http://127.0.0.1:{proxy_port}" |
| 1715 | + |
| 1716 | + yield proxy_url |
| 1717 | + |
| 1718 | + stop_event.set() |
| 1719 | + thread.join(timeout=2) |
| 1720 | + print("[PROXY] Fixture teardown complete") |
| 1721 | + |
| 1722 | + |
| 1723 | +# Example test |
| 1724 | +@pytest.mark.anyio |
| 1725 | +async def test_streamable_client_proxy_config( |
| 1726 | + basic_server: None, proxy_server: str, proxy_url: str, basic_server_url: str |
| 1727 | +) -> None: |
| 1728 | + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( |
| 1729 | + read_stream, |
| 1730 | + write_stream, |
| 1731 | + _, |
| 1732 | + ): |
| 1733 | + async with ClientSession(read_stream, write_stream) as session: |
| 1734 | + result = await session.initialize() |
| 1735 | + assert isinstance(result, InitializeResult) |
| 1736 | + assert result.serverInfo.name == SERVER_NAME |
0 commit comments