1919 create_pool ,
2020 decode_hosts ,
2121)
22+ from typing import TYPE_CHECKING , Dict , List , Tuple , Union , Optional
23+
24+ if TYPE_CHECKING :
25+ from redis .asyncio .connection import ConnectionPool
26+ from redis .asyncio .client import Redis
27+ from .core import RedisChannelLayer
28+ from _typeshed import ReadableBuffer
29+ from redis .typing import TimeoutSecT
2230
2331logger = logging .getLogger (__name__ )
2432
@@ -32,23 +40,27 @@ class ChannelLock:
3240 """
3341
3442 def __init__ (self ):
35- self .locks = collections .defaultdict (asyncio .Lock )
36- self .wait_counts = collections .defaultdict (int )
43+ self .locks : "collections.defaultdict[str, asyncio.Lock]" = (
44+ collections .defaultdict (asyncio .Lock )
45+ )
46+ self .wait_counts : "collections.defaultdict[str, int]" = collections .defaultdict (
47+ int
48+ )
3749
38- async def acquire (self , channel ) :
50+ async def acquire (self , channel : str ) -> bool :
3951 """
4052 Acquire the lock for the given channel.
4153 """
4254 self .wait_counts [channel ] += 1
4355 return await self .locks [channel ].acquire ()
4456
45- def locked (self , channel ) :
57+ def locked (self , channel : str ) -> bool :
4658 """
4759 Return ``True`` if the lock for the given channel is acquired.
4860 """
4961 return self .locks [channel ].locked ()
5062
51- def release (self , channel ):
63+ def release (self , channel : str ):
5264 """
5365 Release the lock for the given channel.
5466 """
@@ -73,12 +85,12 @@ def put_nowait(self, item):
7385
7486
7587class RedisLoopLayer :
76- def __init__ (self , channel_layer ):
88+ def __init__ (self , channel_layer : "RedisChannelLayer" ):
7789 self ._lock = asyncio .Lock ()
7890 self .channel_layer = channel_layer
79- self ._connections = {}
91+ self ._connections : "Dict[int, Redis]" = {}
8092
81- def get_connection (self , index ) :
93+ def get_connection (self , index : int ) -> "Redis" :
8294 if index not in self ._connections :
8395 pool = self .channel_layer .create_pool (index )
8496 self ._connections [index ] = aioredis .Redis (connection_pool = pool )
@@ -134,7 +146,7 @@ def __init__(
134146 symmetric_encryption_keys = symmetric_encryption_keys ,
135147 )
136148 # Cached redis connection pools and the event loop they are from
137- self ._layers = {}
149+ self ._layers : "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
138150 # Normal channels choose a host index by cycling through the available hosts
139151 self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
140152 self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -143,33 +155,33 @@ def __init__(
143155 # Number of coroutines trying to receive right now
144156 self .receive_count = 0
145157 # The receive lock
146- self .receive_lock = None
158+ self .receive_lock : "Optional[asyncio.Lock]" = None
147159 # Event loop they are trying to receive on
148- self .receive_event_loop = None
160+ self .receive_event_loop : "Optional[asyncio.AbstractEventLoop]" = None
149161 # Buffered messages by process-local channel name
150- self .receive_buffer = collections .defaultdict (
151- functools .partial (BoundedQueue , self .capacity )
162+ self .receive_buffer : " collections.defaultdict[str, BoundedQueue]" = (
163+ collections . defaultdict ( functools .partial (BoundedQueue , self .capacity ) )
152164 )
153165 # Detached channel cleanup tasks
154- self .receive_cleaners = []
166+ self .receive_cleaners : "List[asyncio.Task]" = []
155167 # Per-channel cleanup locks to prevent a receive starting and moving
156168 # a message back into the main queue before its cleanup has completed
157169 self .receive_clean_locks = ChannelLock ()
158170
159- def create_pool (self , index ) :
171+ def create_pool (self , index : int ) -> "ConnectionPool" :
160172 return create_pool (self .hosts [index ])
161173
162174 ### Channel layer API ###
163175
164176 extensions = ["groups" , "flush" ]
165177
166- async def send (self , channel , message ):
178+ async def send (self , channel : str , message ):
167179 """
168180 Send a message onto a (general or specific) channel.
169181 """
170182 # Typecheck
171183 assert isinstance (message , dict ), "message is not a dict"
172- assert self .valid_channel_name (channel ), "Channel name not valid"
184+ assert self .require_valid_channel_name (channel ), "Channel name not valid"
173185 # Make sure the message does not contain reserved keys
174186 assert "__asgi_channel__" not in message
175187 # If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +215,13 @@ async def send(self, channel, message):
203215 await connection .zadd (channel_key , {self .serialize (message ): time .time ()})
204216 await connection .expire (channel_key , int (self .expiry ))
205217
206- def _backup_channel_name (self , channel ) :
218+ def _backup_channel_name (self , channel : str ) -> str :
207219 """
208220 Construct the key used as a backup queue for the given channel.
209221 """
210222 return channel + "$inflight"
211223
212- async def _brpop_with_clean (self , index , channel , timeout ):
224+ async def _brpop_with_clean (self , index : int , channel : str , timeout : "TimeoutSecT" ):
213225 """
214226 Perform a Redis BRPOP and manage the backup processing queue.
215227 In case of cancellation, make sure the message is not lost.
@@ -240,23 +252,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
240252
241253 return member
242254
243- async def _clean_receive_backup (self , index , channel ):
255+ async def _clean_receive_backup (self , index : int , channel : str ):
244256 """
245257 Pop the oldest message off the channel backup queue.
246258 The result isn't interesting as it was already processed.
247259 """
248260 connection = self .connection (index )
249261 await connection .zpopmin (self ._backup_channel_name (channel ))
250262
251- async def receive (self , channel ):
263+ async def receive (self , channel : str ):
252264 """
253265 Receive the first message that arrives on the channel.
254266 If more than one coroutine waits on the same channel, the first waiter
255267 will be given the message when it arrives.
256268 """
257269 # Make sure the channel name is valid then get the non-local part
258270 # and thus its index
259- assert self .valid_channel_name (channel )
271+ assert self .require_valid_channel_name (channel )
260272 if "!" in channel :
261273 real_channel = self .non_local_name (channel )
262274 assert real_channel .endswith (
@@ -372,12 +384,14 @@ async def receive(self, channel):
372384 # Do a plain direct receive
373385 return (await self .receive_single (channel ))[1 ]
374386
375- async def receive_single (self , channel ) :
387+ async def receive_single (self , channel : str ) -> "Tuple" :
376388 """
377389 Receives a single message off of the channel and returns it.
378390 """
379391 # Check channel name
380- assert self .valid_channel_name (channel , receive = True ), "Channel name invalid"
392+ assert self .require_valid_channel_name (
393+ channel , receive = True
394+ ), "Channel name invalid"
381395 # Work out the connection to use
382396 if "!" in channel :
383397 assert channel .endswith ("!" )
@@ -408,7 +422,7 @@ async def receive_single(self, channel):
408422 )
409423 self .receive_cleaners .append (cleaner )
410424
411- def _cleanup_done (cleaner ):
425+ def _cleanup_done (cleaner : "asyncio.Task" ):
412426 self .receive_cleaners .remove (cleaner )
413427 self .receive_clean_locks .release (channel_key )
414428
@@ -427,7 +441,7 @@ def _cleanup_done(cleaner):
427441 del message ["__asgi_channel__" ]
428442 return channel , message
429443
430- async def new_channel (self , prefix = "specific" ):
444+ async def new_channel (self , prefix : str = "specific" ) -> str :
431445 """
432446 Returns a new channel name that can be used by something in our
433447 process as a specific channel.
@@ -477,13 +491,13 @@ async def wait_received(self):
477491
478492 ### Groups extension ###
479493
480- async def group_add (self , group , channel ):
494+ async def group_add (self , group : str , channel : str ):
481495 """
482496 Adds the channel name to a group.
483497 """
484498 # Check the inputs
485- assert self .valid_group_name (group ), "Group name not valid"
486- assert self .valid_channel_name (channel ), "Channel name not valid"
499+ assert self .require_valid_group_name (group ), True
500+ assert self .require_valid_channel_name (channel ), True
487501 # Get a connection to the right shard
488502 group_key = self ._group_key (group )
489503 connection = self .connection (self .consistent_hash (group ))
@@ -493,22 +507,22 @@ async def group_add(self, group, channel):
493507 # it at this point is guaranteed to expire before that
494508 await connection .expire (group_key , self .group_expiry )
495509
496- async def group_discard (self , group , channel ):
510+ async def group_discard (self , group : str , channel : str ):
497511 """
498512 Removes the channel from the named group if it is in the group;
499513 does nothing otherwise (does not error)
500514 """
501- assert self .valid_group_name (group ), "Group name not valid"
502- assert self .valid_channel_name (channel ), "Channel name not valid"
515+ assert self .require_valid_group_name (group ), "Group name not valid"
516+ assert self .require_valid_channel_name (channel ), "Channel name not valid"
503517 key = self ._group_key (group )
504518 connection = self .connection (self .consistent_hash (group ))
505519 await connection .zrem (key , channel )
506520
507- async def group_send (self , group , message ):
521+ async def group_send (self , group : str , message ):
508522 """
509523 Sends a message to the entire group.
510524 """
511- assert self .valid_group_name (group ), "Group name not valid"
525+ assert self .require_valid_group_name (group ), "Group name not valid"
512526 # Retrieve list of all channel names
513527 key = self ._group_key (group )
514528 connection = self .connection (self .consistent_hash (group ))
@@ -573,7 +587,12 @@ async def group_send(self, group, message):
573587 channels_over_capacity = await connection .eval (
574588 group_send_lua , len (channel_redis_keys ), * channel_redis_keys , * args
575589 )
576- if channels_over_capacity > 0 :
590+ _channels_over_capacity = - 1
591+ try :
592+ _channels_over_capacity = float (channels_over_capacity )
593+ except Exception :
594+ pass
595+ if _channels_over_capacity > 0 :
577596 logger .info (
578597 "%s of %s channels over capacity in group %s" ,
579598 channels_over_capacity ,
@@ -631,37 +650,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
631650 channel_key_to_capacity ,
632651 )
633652
634- def _group_key (self , group ) :
653+ def _group_key (self , group : str ) -> bytes :
635654 """
636655 Common function to make the storage key for the group.
637656 """
638657 return f"{ self .prefix } :group:{ group } " .encode ("utf8" )
639658
640- ### Serialization ###
641-
642- def serialize (self , message ):
659+ def serialize (self , message ) -> bytes :
643660 """
644661 Serializes message to a byte string.
645662 """
646663 return self ._serializer .serialize (message )
647664
648- def deserialize (self , message ):
665+ def deserialize (self , message : bytes ):
649666 """
650667 Deserializes from a byte string.
651668 """
652669 return self ._serializer .deserialize (message )
653670
654671 ### Internal functions ###
655672
656- def consistent_hash (self , value ) :
673+ def consistent_hash (self , value : "Union[str, ReadableBuffer]" ) -> int :
657674 return _consistent_hash (value , self .ring_size )
658675
659676 def __str__ (self ):
660677 return f"{ self .__class__ .__name__ } (hosts={ self .hosts } )"
661678
662679 ### Connection handling ###
663680
664- def connection (self , index ) :
681+ def connection (self , index : int ) -> "Redis" :
665682 """
666683 Returns the correct connection for the index given.
667684 Lazily instantiates pools.
0 commit comments