@@ -30,18 +30,20 @@ class RedisMessageDispatch:
3030 """
3131
3232 def __init__ (
33- self , redis_url : str = "redis://localhost:6379/0" , prefix : str = "mcp:pubsub:"
33+ self , redis_url : str = "redis://localhost:6379/0" , prefix : str = "mcp:pubsub:" ,
34+ session_ttl : int = 3600 # 1 hour default TTL for sessions
3435 ) -> None :
3536 """Initialize Redis message dispatch.
3637
3738 Args:
3839 redis_url: Redis connection string
3940 prefix: Key prefix for Redis channels to avoid collisions
41+ session_ttl: TTL in seconds for session keys (default: 1 hour)
4042 """
4143 self ._redis = redis .from_url (redis_url , decode_responses = True ) # type: ignore
4244 self ._pubsub = self ._redis .pubsub (ignore_subscribe_messages = True ) # type: ignore
4345 self ._prefix = prefix
44- self ._active_sessions_key = f" { prefix } active_sessions"
46+ self ._session_ttl = session_ttl
4547 # Maps session IDs to the callback and task group for that SSE session.
4648 self ._session_state : dict [UUID , tuple [MessageCallback , TaskGroup ]] = {}
4749 # Ensures only one polling task runs at a time for message handling
@@ -56,27 +58,50 @@ def _session_channel(self, session_id: UUID) -> str:
5658 """Get the Redis channel for a session."""
5759 return f"{ self ._prefix } session:{ session_id .hex } "
5860
61+ def _session_key (self , session_id : UUID ) -> str :
62+ """Get the Redis key for a session."""
63+ return f"{ self ._prefix } session_active:{ session_id .hex } "
64+
5965 @asynccontextmanager
6066 async def subscribe (self , session_id : UUID , callback : MessageCallback ):
6167 """Request-scoped context manager that subscribes to messages for a session."""
62- await self ._redis .sadd (self ._active_sessions_key , session_id .hex )
68+ session_key = self ._session_key (session_id )
69+ await self ._redis .setex (session_key , self ._session_ttl , "1" ) # type: ignore
70+
6371 channel = self ._session_channel (session_id )
6472 await self ._pubsub .subscribe (channel ) # type: ignore
6573
6674 logger .debug (f"Subscribing to Redis channel for session { session_id } " )
6775 async with anyio .create_task_group () as tg :
6876 self ._session_state [session_id ] = (callback , tg )
6977 tg .start_soon (self ._listen_for_messages )
78+ # Start heartbeat for this session
79+ tg .start_soon (self ._session_heartbeat , session_id )
7080 try :
7181 yield
7282 finally :
7383 with anyio .CancelScope (shield = True ):
7484 tg .cancel_scope .cancel ()
7585 await self ._pubsub .unsubscribe (channel ) # type: ignore
76- await self ._redis .srem ( self . _active_sessions_key , session_id . hex )
86+ await self ._redis .delete ( session_key ) # type: ignore
7787 del self ._session_state [session_id ]
7888 logger .debug (f"Unsubscribed from Redis channel: { session_id } " )
7989
90+ async def _session_heartbeat (self , session_id : UUID ) -> None :
91+ """Periodically refresh the TTL for a session."""
92+ session_key = self ._session_key (session_id )
93+ while True :
94+ await lowlevel .checkpoint ()
95+ try :
96+ # Refresh TTL at half the TTL interval to avoid expiration
97+ await anyio .sleep (self ._session_ttl / 2 )
98+ with anyio .CancelScope (shield = True ):
99+ await self ._redis .expire (session_key , self ._session_ttl ) # type: ignore
100+ except anyio .get_cancelled_exc_class ():
101+ break
102+ except Exception as e :
103+ logger .error (f"Error refreshing TTL for session { session_id } : { e } " )
104+
80105 def _extract_session_id (self , channel : str ) -> UUID | None :
81106 """Extract and validate session ID from channel."""
82107 expected_prefix = f"{ self ._prefix } session:"
@@ -167,6 +192,5 @@ async def publish_message(
167192
168193 async def session_exists (self , session_id : UUID ) -> bool :
169194 """Check if a session exists."""
170- return bool (
171- await self ._redis .sismember (self ._active_sessions_key , session_id .hex ) # type: ignore[attr-defined]
172- )
195+ session_key = self ._session_key (session_id )
196+ return bool (await self ._redis .exists (session_key )) # type: ignore
0 commit comments