 from __future__ import annotations

 import json
+import logging
 import sqlite3
 import time
 from collections import deque
 import uvicorn
 from mcp import types
 from mcp.server.fastmcp import FastMCP
+from mcp.server.fastmcp.server import Context
 from mcp.server.session import ServerSession
 from mcp.shared._httpx_utils import create_mcp_http_client
 from mcp.shared.async_operations import (
     AsyncOperationBroker,
     AsyncOperationStore,
+    OperationEventQueue,
     PendingAsyncTask,
     ServerAsyncOperation,
     ServerAsyncOperationManager,
 )
-from mcp.shared.context import RequestContext
+from mcp.shared.context import RequestContext, SerializableRequestContext
 from mcp.types import AsyncOperationStatus, CallToolResult
+from pydantic import BaseModel, Field
+
+logger = logging.getLogger(__name__)


 class SQLiteAsyncOperationStore(AsyncOperationStore):
@@ -207,6 +213,78 @@ async def cleanup_expired(self) -> int:
         return cursor.rowcount


+class SQLiteOperationEventQueue(OperationEventQueue):
+    """SQLite-based implementation of OperationEventQueue for operation-specific event delivery."""
+
+    def __init__(self, db_path: str = "async_operations.db"):
+        self.db_path = db_path
+        self._init_db()
+
+    def _init_db(self):
+        """Initialize the SQLite database for operation event queuing."""
+        with sqlite3.connect(self.db_path) as conn:
+            conn.execute("""
+                CREATE TABLE IF NOT EXISTS operation_events (
+                    id INTEGER PRIMARY KEY AUTOINCREMENT,
+                    operation_token TEXT NOT NULL,
+                    message TEXT NOT NULL,
+                    created_at REAL NOT NULL
+                )
+            """)
+            conn.execute("""
+                CREATE INDEX IF NOT EXISTS idx_operation_events_token_created
+                ON operation_events(operation_token, created_at)
+            """)
+            conn.commit()
+
+    async def enqueue_event(self, operation_token: str, message: types.JSONRPCMessage) -> None:
+        """Enqueue an event for a specific operation token."""
+        message_json = json.dumps(message.model_dump())
+        created_at = time.time()
+
+        with sqlite3.connect(self.db_path) as conn:
+            conn.execute(
+                """
+                INSERT INTO operation_events (operation_token, message, created_at)
+                VALUES (?, ?, ?)
+                """,
+                (operation_token, message_json, created_at),
+            )
+            conn.commit()
+
+    async def dequeue_events(self, operation_token: str) -> list[types.JSONRPCMessage]:
+        """Dequeue all pending events for a specific operation token."""
+        with sqlite3.connect(self.db_path) as conn:
+            conn.row_factory = sqlite3.Row
+
+            # Get all events for this operation token
+            cursor = conn.execute(
+                """
+                SELECT id, message FROM operation_events
+                WHERE operation_token = ?
+                ORDER BY created_at
+                """,
+                (operation_token,),
+            )
+
+            events: list[types.JSONRPCMessage] = []
+            event_ids: list[int] = []
+
+            for row in cursor:
+                event_ids.append(row["id"])
+                message_data = json.loads(row["message"])
+                message = types.JSONRPCMessage.model_validate(message_data)
+                events.append(message)
+
+            # Delete the dequeued events
+            if event_ids:
+                placeholders = ",".join("?" * len(event_ids))
+                conn.execute(f"DELETE FROM operation_events WHERE id IN ({placeholders})", event_ids)
+                conn.commit()
+
+        return events
+
+
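A minimal usage sketch of the event queue above, run in the same module as the class (not part of the diff: the token, database filename, and "ping" payload are made up, and it assumes JSONRPCMessage round-trips through model_validate/model_dump, as the class itself relies on):

import asyncio


async def demo() -> None:
    queue = SQLiteOperationEventQueue("demo_events.db")  # hypothetical database file
    # Build a message from its JSON-RPC wire form rather than from the model constructors.
    message = types.JSONRPCMessage.model_validate({"jsonrpc": "2.0", "id": 1, "method": "ping"})
    await queue.enqueue_event("op-token-123", message)
    events = await queue.dequeue_events("op-token-123")  # returns the message and deletes it
    leftover = await queue.dequeue_events("op-token-123")  # already drained, so []
    print(len(events), len(leftover))


asyncio.run(demo())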
 class SQLiteAsyncOperationBroker(AsyncOperationBroker):
     """SQLite-based implementation of AsyncOperationBroker for persistent task queuing."""

@@ -234,23 +312,19 @@ def _load_persisted_tasks_sync(self):
             if op_row and op_row["status"] in ("completed", "failed", "canceled"):
                 continue

-            # Reconstruct serializable parts of RequestContext
-            from mcp.shared.context import SerializableRequestContext
-
-            serializable_context = None
-            if row["request_id"]:
-                serializable_context = SerializableRequestContext(
-                    request_id=row["request_id"],
-                    operation_token=row["operation_token"],
-                    meta=json.loads(row["meta"]) if row["meta"] else None,
-                    supports_async=bool(row["supports_async"]),
-                )
+            # Reconstruct context - the server will hydrate the session
+            request_context = SerializableRequestContext(
+                request_id=row["request_id"],
+                operation_token=row["operation_token"],
+                meta=json.loads(row["meta"]) if row["meta"] else None,
+                supports_async=bool(row["supports_async"]),
+            )

             task = PendingAsyncTask(
                 token=row["token"],
                 tool_name=row["tool_name"],
                 arguments=json.loads(row["arguments"]),
-                request_context=serializable_context,
+                request_context=request_context,
             )
             self._task_queue.append(task)

@@ -329,6 +403,10 @@ async def complete_task(self, token: str) -> None:
             conn.commit()


+class UserPreferences(BaseModel):
+    continue_processing: bool = Field(description="Should we continue with the operation?")
+
+
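For reference, the elicitation in fetch_website below sends the client a JSON schema derived from this model; a quick way to inspect that schema is shown here (a standalone sketch using only pydantic, not part of the diff):

import json

from pydantic import BaseModel, Field


class UserPreferences(BaseModel):
    continue_processing: bool = Field(description="Should we continue with the operation?")


# Prints roughly {"properties": {"continue_processing": {"description": ..., "type": "boolean"}},
#                 "required": ["continue_processing"], ...}, depending on the installed pydantic version.
print(json.dumps(UserPreferences.model_json_schema(), indent=2))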
 @click.command()
 @click.option("--port", default=8000, help="Port to listen on for HTTP")
 @click.option(
@@ -341,31 +419,54 @@ async def complete_task(self, token: str) -> None:
 def main(port: int, transport: str, db_path: str):
     """Run the SQLite async operations example server."""
     # Create components with specified database path
+    operation_event_queue = SQLiteOperationEventQueue(db_path)
     broker = SQLiteAsyncOperationBroker(db_path)
-    store = SQLiteAsyncOperationStore(db_path)  # No broker reference needed
-    manager = ServerAsyncOperationManager(store=store, broker=broker)
-    mcp = FastMCP("SQLite Async Operations Demo", async_operations=manager)
+    store = SQLiteAsyncOperationStore(db_path)
+    manager = ServerAsyncOperationManager(store=store, broker=broker, operation_request_queue=operation_event_queue)
+    mcp = FastMCP(
+        "SQLite Async Operations Demo",
+        operation_event_queue=operation_event_queue,
+        async_operations=manager,
+    )

     @mcp.tool(invocation_modes=["async"])
     async def fetch_website(
         url: str,
+        ctx: Context[ServerSession, None],
     ) -> list[types.ContentBlock]:
         headers = {"User-Agent": "MCP Test Server (github.com/modelcontextprotocol/python-sdk)"}
         async with create_mcp_http_client(headers=headers) as client:
+            logger.info("Entered fetch_website")
+
+            # Simulate a delay
             await anyio.sleep(10)
+
+            # Request approval from the user
+            logger.info("Sending elicitation to confirm")
+            result = await ctx.elicit(
+                message=f"Please confirm that you would like to fetch from {url}.",
+                schema=UserPreferences,
+            )
+            logger.info(f"Elicitation result: {result}")
+
+            if result.action != "accept" or not result.data.continue_processing:
+                return [types.TextContent(type="text", text="Operation cancelled by user")]
+
+            logger.info(f"Fetching {url}")
             response = await client.get(url)
             response.raise_for_status()
+            logger.info("Returning fetch result")
             return [types.TextContent(type="text", text=response.text)]

-    print(f"Starting server with SQLite database: {db_path}")
-    print("Pending tasks will be automatically restarted on server restart!")
+    logger.info(f"Starting server with SQLite database: {db_path}")
+    logger.info("Pending tasks will be automatically restarted on server restart!")

     if transport == "stdio":
         mcp.run(transport="stdio")
     elif transport == "streamable-http":
         app = mcp.streamable_http_app()
         server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error"))
-        print(f"Starting {transport} server on port {port}")
+        logger.info(f"Starting {transport} server on port {port}")
         server.run()
     else:
         raise ValueError(f"Invalid transport for test server: {transport}")
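One caveat when trying this locally: the diff swaps print for logger.info but, at least in the lines shown, never configures logging, so those messages stay hidden at the default WARNING level. A minimal sketch of what could be added near the top of main(), assuming logging is not configured elsewhere in the file:

import logging

# Surface the example's INFO-level messages; without some configuration like this,
# the logger.info lines added in the diff produce no visible output.
logging.basicConfig(level=logging.INFO)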