1919 create_pool ,
2020 decode_hosts ,
2121)
22+ from typing import TYPE_CHECKING , Dict , List , Tuple , Union , Optional
23+
24+ if TYPE_CHECKING :
25+ from redis .asyncio .connection import ConnectionPool
26+ from redis .asyncio .client import Redis
27+ from .core import RedisChannelLayer
28+ from typing_extensions import Buffer
2229
2330logger = logging .getLogger (__name__ )
2431
@@ -32,23 +39,27 @@ class ChannelLock:
3239 """
3340
3441 def __init__ (self ):
35- self .locks = collections .defaultdict (asyncio .Lock )
36- self .wait_counts = collections .defaultdict (int )
42+ self .locks : "collections.defaultdict[str, asyncio.Lock]" = (
43+ collections .defaultdict (asyncio .Lock )
44+ )
45+ self .wait_counts : "collections.defaultdict[str, int]" = collections .defaultdict (
46+ int
47+ )
3748
38- async def acquire (self , channel ) :
49+ async def acquire (self , channel : str ) -> bool :
3950 """
4051 Acquire the lock for the given channel.
4152 """
4253 self .wait_counts [channel ] += 1
4354 return await self .locks [channel ].acquire ()
4455
45- def locked (self , channel ) :
56+ def locked (self , channel : str ) -> bool :
4657 """
4758 Return ``True`` if the lock for the given channel is acquired.
4859 """
4960 return self .locks [channel ].locked ()
5061
51- def release (self , channel ):
62+ def release (self , channel : str ):
5263 """
5364 Release the lock for the given channel.
5465 """
@@ -73,12 +84,12 @@ def put_nowait(self, item):
7384
7485
7586class RedisLoopLayer :
76- def __init__ (self , channel_layer ):
87+ def __init__ (self , channel_layer : "RedisChannelLayer" ):
7788 self ._lock = asyncio .Lock ()
7889 self .channel_layer = channel_layer
79- self ._connections = {}
90+ self ._connections : "Dict[int, Redis]" = {}
8091
81- def get_connection (self , index ) :
92+ def get_connection (self , index : int ) -> "Redis" :
8293 if index not in self ._connections :
8394 pool = self .channel_layer .create_pool (index )
8495 self ._connections [index ] = aioredis .Redis (connection_pool = pool )
@@ -134,7 +145,7 @@ def __init__(
134145 symmetric_encryption_keys = symmetric_encryption_keys ,
135146 )
136147 # Cached redis connection pools and the event loop they are from
137- self ._layers = {}
148+ self ._layers : "Dict[asyncio.AbstractEventLoop, RedisLoopLayer]" = {}
138149 # Normal channels choose a host index by cycling through the available hosts
139150 self ._receive_index_generator = itertools .cycle (range (len (self .hosts )))
140151 self ._send_index_generator = itertools .cycle (range (len (self .hosts )))
@@ -143,33 +154,33 @@ def __init__(
143154 # Number of coroutines trying to receive right now
144155 self .receive_count = 0
145156 # The receive lock
146- self .receive_lock = None
157+ self .receive_lock : "Optional[asyncio.Lock]" = None
147158 # Event loop they are trying to receive on
148- self .receive_event_loop = None
159+ self .receive_event_loop : "Optional[asyncio.AbstractEventLoop]" = None
149160 # Buffered messages by process-local channel name
150- self .receive_buffer = collections .defaultdict (
151- functools .partial (BoundedQueue , self .capacity )
161+ self .receive_buffer : " collections.defaultdict[str, BoundedQueue]" = (
162+ collections . defaultdict ( functools .partial (BoundedQueue , self .capacity ) )
152163 )
153164 # Detached channel cleanup tasks
154- self .receive_cleaners = []
165+ self .receive_cleaners : "List[asyncio.Task]" = []
155166 # Per-channel cleanup locks to prevent a receive starting and moving
156167 # a message back into the main queue before its cleanup has completed
157168 self .receive_clean_locks = ChannelLock ()
158169
159- def create_pool (self , index ) :
170+ def create_pool (self , index : int ) -> "ConnectionPool" :
160171 return create_pool (self .hosts [index ])
161172
162173 ### Channel layer API ###
163174
164175 extensions = ["groups" , "flush" ]
165176
166- async def send (self , channel , message ):
177+ async def send (self , channel : str , message ):
167178 """
168179 Send a message onto a (general or specific) channel.
169180 """
170181 # Typecheck
171182 assert isinstance (message , dict ), "message is not a dict"
172- assert self .valid_channel_name (channel ), "Channel name not valid"
183+ assert self .require_valid_channel_name (channel ), "Channel name not valid"
173184 # Make sure the message does not contain reserved keys
174185 assert "__asgi_channel__" not in message
175186 # If it's a process-local channel, strip off local part and stick full name in message
@@ -203,13 +214,15 @@ async def send(self, channel, message):
203214 await connection .zadd (channel_key , {self .serialize (message ): time .time ()})
204215 await connection .expire (channel_key , int (self .expiry ))
205216
206- def _backup_channel_name (self , channel ) :
217+ def _backup_channel_name (self , channel : str ) -> str :
207218 """
208219 Construct the key used as a backup queue for the given channel.
209220 """
210221 return channel + "$inflight"
211222
212- async def _brpop_with_clean (self , index , channel , timeout ):
223+ async def _brpop_with_clean (
224+ self , index : int , channel : str , timeout : "Union[int, float, bytes, str]"
225+ ):
213226 """
214227 Perform a Redis BRPOP and manage the backup processing queue.
215228 In case of cancellation, make sure the message is not lost.
@@ -240,23 +253,23 @@ async def _brpop_with_clean(self, index, channel, timeout):
240253
241254 return member
242255
243- async def _clean_receive_backup (self , index , channel ):
256+ async def _clean_receive_backup (self , index : int , channel : str ):
244257 """
245258 Pop the oldest message off the channel backup queue.
246259 The result isn't interesting as it was already processed.
247260 """
248261 connection = self .connection (index )
249262 await connection .zpopmin (self ._backup_channel_name (channel ))
250263
251- async def receive (self , channel ):
264+ async def receive (self , channel : str ):
252265 """
253266 Receive the first message that arrives on the channel.
254267 If more than one coroutine waits on the same channel, the first waiter
255268 will be given the message when it arrives.
256269 """
257270 # Make sure the channel name is valid then get the non-local part
258271 # and thus its index
259- assert self .valid_channel_name (channel )
272+ assert self .require_valid_channel_name (channel )
260273 if "!" in channel :
261274 real_channel = self .non_local_name (channel )
262275 assert real_channel .endswith (
@@ -372,12 +385,14 @@ async def receive(self, channel):
372385 # Do a plain direct receive
373386 return (await self .receive_single (channel ))[1 ]
374387
375- async def receive_single (self , channel ) :
388+ async def receive_single (self , channel : str ) -> "Tuple" :
376389 """
377390 Receives a single message off of the channel and returns it.
378391 """
379392 # Check channel name
380- assert self .valid_channel_name (channel , receive = True ), "Channel name invalid"
393+ assert self .require_valid_channel_name (
394+ channel , receive = True
395+ ), "Channel name invalid"
381396 # Work out the connection to use
382397 if "!" in channel :
383398 assert channel .endswith ("!" )
@@ -408,7 +423,7 @@ async def receive_single(self, channel):
408423 )
409424 self .receive_cleaners .append (cleaner )
410425
411- def _cleanup_done (cleaner ):
426+ def _cleanup_done (cleaner : "asyncio.Task" ):
412427 self .receive_cleaners .remove (cleaner )
413428 self .receive_clean_locks .release (channel_key )
414429
@@ -427,7 +442,7 @@ def _cleanup_done(cleaner):
427442 del message ["__asgi_channel__" ]
428443 return channel , message
429444
430- async def new_channel (self , prefix = "specific" ):
445+ async def new_channel (self , prefix : str = "specific" ) -> str :
431446 """
432447 Returns a new channel name that can be used by something in our
433448 process as a specific channel.
@@ -477,13 +492,13 @@ async def wait_received(self):
477492
478493 ### Groups extension ###
479494
480- async def group_add (self , group , channel ):
495+ async def group_add (self , group : str , channel : str ):
481496 """
482497 Adds the channel name to a group.
483498 """
484499 # Check the inputs
485- assert self .valid_group_name (group ), "Group name not valid"
486- assert self .valid_channel_name (channel ), "Channel name not valid"
500+ assert self .require_valid_group_name (group ), True
501+ assert self .require_valid_channel_name (channel ), True
487502 # Get a connection to the right shard
488503 group_key = self ._group_key (group )
489504 connection = self .connection (self .consistent_hash (group ))
@@ -493,22 +508,22 @@ async def group_add(self, group, channel):
493508 # it at this point is guaranteed to expire before that
494509 await connection .expire (group_key , self .group_expiry )
495510
496- async def group_discard (self , group , channel ):
511+ async def group_discard (self , group : str , channel : str ):
497512 """
498513 Removes the channel from the named group if it is in the group;
499514 does nothing otherwise (does not error)
500515 """
501- assert self .valid_group_name (group ), "Group name not valid"
502- assert self .valid_channel_name (channel ), "Channel name not valid"
516+ assert self .require_valid_group_name (group ), "Group name not valid"
517+ assert self .require_valid_channel_name (channel ), "Channel name not valid"
503518 key = self ._group_key (group )
504519 connection = self .connection (self .consistent_hash (group ))
505520 await connection .zrem (key , channel )
506521
507- async def group_send (self , group , message ):
522+ async def group_send (self , group : str , message ):
508523 """
509524 Sends a message to the entire group.
510525 """
511- assert self .valid_group_name (group ), "Group name not valid"
526+ assert self .require_valid_group_name (group ), "Group name not valid"
512527 # Retrieve list of all channel names
513528 key = self ._group_key (group )
514529 connection = self .connection (self .consistent_hash (group ))
@@ -573,7 +588,12 @@ async def group_send(self, group, message):
573588 channels_over_capacity = await connection .eval (
574589 group_send_lua , len (channel_redis_keys ), * channel_redis_keys , * args
575590 )
576- if channels_over_capacity > 0 :
591+ _channels_over_capacity = - 1
592+ try :
593+ _channels_over_capacity = float (channels_over_capacity )
594+ except Exception :
595+ pass
596+ if _channels_over_capacity > 0 :
577597 logger .info (
578598 "%s of %s channels over capacity in group %s" ,
579599 channels_over_capacity ,
@@ -631,37 +651,35 @@ def _map_channel_keys_to_connection(self, channel_names, message):
631651 channel_key_to_capacity ,
632652 )
633653
634- def _group_key (self , group ) :
654+ def _group_key (self , group : str ) -> bytes :
635655 """
636656 Common function to make the storage key for the group.
637657 """
638658 return f"{ self .prefix } :group:{ group } " .encode ("utf8" )
639659
640- ### Serialization ###
641-
642- def serialize (self , message ):
660+ def serialize (self , message ) -> bytes :
643661 """
644662 Serializes message to a byte string.
645663 """
646664 return self ._serializer .serialize (message )
647665
648- def deserialize (self , message ):
666+ def deserialize (self , message : bytes ):
649667 """
650668 Deserializes from a byte string.
651669 """
652670 return self ._serializer .deserialize (message )
653671
654672 ### Internal functions ###
655673
656- def consistent_hash (self , value ) :
674+ def consistent_hash (self , value : "Union[str, Buffer]" ) -> int :
657675 return _consistent_hash (value , self .ring_size )
658676
659677 def __str__ (self ):
660678 return f"{ self .__class__ .__name__ } (hosts={ self .hosts } )"
661679
662680 ### Connection handling ###
663681
664- def connection (self , index ) :
682+ def connection (self , index : int ) -> "Redis" :
665683 """
666684 Returns the correct connection for the index given.
667685 Lazily instantiates pools.
0 commit comments