Skip to content

Commit 9419ad0

Browse files
committed
Move to session heartbeat w/ TTL
1 parent 564561f commit 9419ad0

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/mcp/server/message_queue/redis.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,20 @@ class RedisMessageDispatch:
3030
"""
3131

3232
def __init__(
33-
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:"
33+
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:",
34+
session_ttl: int = 3600 # 1 hour default TTL for sessions
3435
) -> None:
3536
"""Initialize Redis message dispatch.
3637
3738
Args:
3839
redis_url: Redis connection string
3940
prefix: Key prefix for Redis channels to avoid collisions
41+
session_ttl: TTL in seconds for session keys (default: 1 hour)
4042
"""
4143
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
4244
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
4345
self._prefix = prefix
44-
self._active_sessions_key = f"{prefix}active_sessions"
46+
self._session_ttl = session_ttl
4547
# Maps session IDs to the callback and task group for that SSE session.
4648
self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {}
4749
# Ensures only one polling task runs at a time for message handling
@@ -56,27 +58,50 @@ def _session_channel(self, session_id: UUID) -> str:
5658
"""Get the Redis channel for a session."""
5759
return f"{self._prefix}session:{session_id.hex}"
5860

61+
def _session_key(self, session_id: UUID) -> str:
62+
"""Get the Redis key for a session."""
63+
return f"{self._prefix}session_active:{session_id.hex}"
64+
5965
@asynccontextmanager
6066
async def subscribe(self, session_id: UUID, callback: MessageCallback):
6167
"""Request-scoped context manager that subscribes to messages for a session."""
62-
await self._redis.sadd(self._active_sessions_key, session_id.hex)
68+
session_key = self._session_key(session_id)
69+
await self._redis.setex(session_key, self._session_ttl, "1") # type: ignore
70+
6371
channel = self._session_channel(session_id)
6472
await self._pubsub.subscribe(channel) # type: ignore
6573

6674
logger.debug(f"Subscribing to Redis channel for session {session_id}")
6775
async with anyio.create_task_group() as tg:
6876
self._session_state[session_id] = (callback, tg)
6977
tg.start_soon(self._listen_for_messages)
78+
# Start heartbeat for this session
79+
tg.start_soon(self._session_heartbeat, session_id)
7080
try:
7181
yield
7282
finally:
7383
with anyio.CancelScope(shield=True):
7484
tg.cancel_scope.cancel()
7585
await self._pubsub.unsubscribe(channel) # type: ignore
76-
await self._redis.srem(self._active_sessions_key, session_id.hex)
86+
await self._redis.delete(session_key) # type: ignore
7787
del self._session_state[session_id]
7888
logger.debug(f"Unsubscribed from Redis channel: {session_id}")
7989

90+
async def _session_heartbeat(self, session_id: UUID) -> None:
91+
"""Periodically refresh the TTL for a session."""
92+
session_key = self._session_key(session_id)
93+
while True:
94+
await lowlevel.checkpoint()
95+
try:
96+
# Refresh TTL at half the TTL interval to avoid expiration
97+
await anyio.sleep(self._session_ttl / 2)
98+
with anyio.CancelScope(shield=True):
99+
await self._redis.expire(session_key, self._session_ttl) # type: ignore
100+
except anyio.get_cancelled_exc_class():
101+
break
102+
except Exception as e:
103+
logger.error(f"Error refreshing TTL for session {session_id}: {e}")
104+
80105
def _extract_session_id(self, channel: str) -> UUID | None:
81106
"""Extract and validate session ID from channel."""
82107
expected_prefix = f"{self._prefix}session:"
@@ -167,6 +192,5 @@ async def publish_message(
167192

168193
async def session_exists(self, session_id: UUID) -> bool:
169194
"""Check if a session exists."""
170-
return bool(
171-
await self._redis.sismember(self._active_sessions_key, session_id.hex) # type: ignore[attr-defined]
172-
)
195+
session_key = self._session_key(session_id)
196+
return bool(await self._redis.exists(session_key)) # type: ignore

0 commit comments

Comments
 (0)