@@ -37,7 +37,7 @@ class ChannelLock:
3737 to mitigate multi-event loop problems.
3838 """
3939
40- def __init__ (self ):
40+ def __init__ (self ) -> None :
4141 self .locks : collections .defaultdict [str , asyncio .Lock ] = (
4242 collections .defaultdict (asyncio .Lock )
4343 )
@@ -58,7 +58,7 @@ def locked(self, channel: str) -> bool:
5858 """
5959 return self .locks [channel ].locked ()
6060
61- def release (self , channel : str ):
61+ def release (self , channel : str ) -> None :
6262 """
6363 Release the lock for the given channel.
6464 """
@@ -69,8 +69,8 @@ def release(self, channel: str):
6969 del self .wait_counts [channel ]
7070
7171
72- class BoundedQueue (asyncio .Queue ):
73- def put_nowait (self , item ) :
72+ class BoundedQueue (asyncio .Queue [ typing . Any ] ):
73+ def put_nowait (self , item : typing . Any ) -> None :
7474 if self .full ():
7575 # see: https://github.com/django/channels_redis/issues/212
7676 # if we actually get into this code block, it likely means that
@@ -83,7 +83,7 @@ def put_nowait(self, item):
8383
8484
8585class RedisLoopLayer :
86- def __init__ (self , channel_layer : "RedisChannelLayer" ):
86+ def __init__ (self , channel_layer : "RedisChannelLayer" ) -> None :
8787 self ._lock = asyncio .Lock ()
8888 self .channel_layer = channel_layer
8989 self ._connections : typing .Dict [int , "Redis" ] = {}
@@ -95,7 +95,7 @@ def get_connection(self, index: int) -> "Redis":
9595
9696 return self ._connections [index ]
9797
98- async def flush (self ):
98+ async def flush (self ) -> None :
9999 async with self ._lock :
100100 for index in list (self ._connections ):
101101 connection = self ._connections .pop (index )
@@ -116,15 +116,15 @@ class RedisChannelLayer(BaseChannelLayer):
116116 def __init__ (
117117 self ,
118118 hosts = None ,
119- prefix = "asgi" ,
119+ prefix : str = "asgi" ,
120120 expiry = 60 ,
121- group_expiry = 86400 ,
121+ group_expiry : int = 86400 ,
122122 capacity = 100 ,
123123 channel_capacity = None ,
124124 symmetric_encryption_keys = None ,
125125 random_prefix_length = 12 ,
126126 serializer_format = "msgpack" ,
127- ):
127+ ) -> None :
128128 # Store basic information
129129 self .expiry = expiry
130130 self .group_expiry = group_expiry
@@ -161,7 +161,7 @@ def __init__(
161161 collections .defaultdict (functools .partial (BoundedQueue , self .capacity ))
162162 )
163163 # Detached channel cleanup tasks
164- self .receive_cleaners : typing .List [asyncio .Task ] = []
164+ self .receive_cleaners : typing .List [asyncio .Task [ typing . Any ] ] = []
165165 # Per-channel cleanup locks to prevent a receive starting and moving
166166 # a message back into the main queue before its cleanup has completed
167167 self .receive_clean_locks = ChannelLock ()
@@ -173,7 +173,7 @@ def create_pool(self, index: int) -> "ConnectionPool":
173173
174174 extensions = ["groups" , "flush" ]
175175
176- async def send (self , channel : str , message ) :
176+ async def send (self , channel : str , message : typing . Any ) -> None :
177177 """
178178 Send a message onto a (general or specific) channel.
179179 """
@@ -221,7 +221,7 @@ def _backup_channel_name(self, channel: str) -> str:
221221
222222 async def _brpop_with_clean (
223223 self , index : int , channel : str , timeout : typing .Union [int , float , bytes , str ]
224- ):
224+ ) -> typing . Any :
225225 """
226226 Perform a Redis BRPOP and manage the backup processing queue.
227227 In case of cancellation, make sure the message is not lost.
@@ -252,15 +252,15 @@ async def _brpop_with_clean(
252252
253253 return member
254254
255- async def _clean_receive_backup (self , index : int , channel : str ):
255+ async def _clean_receive_backup (self , index : int , channel : str ) -> None :
256256 """
257257 Pop the oldest message off the channel backup queue.
258258 The result isn't interesting as it was already processed.
259259 """
260260 connection = self .connection (index )
261261 await connection .zpopmin (self ._backup_channel_name (channel ))
262262
263- async def receive (self , channel : str ):
263+ async def receive (self , channel : str ) -> typing . Any :
264264 """
265265 Receive the first message that arrives on the channel.
266266 If more than one coroutine waits on the same channel, the first waiter
@@ -271,9 +271,9 @@ async def receive(self, channel: str):
271271 assert self .valid_channel_name (channel )
272272 if "!" in channel :
273273 real_channel = self .non_local_name (channel )
274- assert real_channel .endswith (
275- self . client_prefix + "! "
276- ), "Wrong client prefix"
274+ assert real_channel .endswith (self . client_prefix + "!" ), (
275+ "Wrong client prefix "
276+ )
277277 # Enter receiving section
278278 loop = asyncio .get_running_loop ()
279279 self .receive_count += 1
@@ -292,11 +292,11 @@ async def receive(self, channel: str):
292292 # Wait for our message to appear
293293 message = None
294294 while self .receive_buffer [channel ].empty ():
295- tasks = [
295+ _tasks = [
296296 self .receive_lock .acquire (),
297297 self .receive_buffer [channel ].get (),
298298 ]
299- tasks = [asyncio .ensure_future (task ) for task in tasks ]
299+ tasks = [asyncio .ensure_future (task ) for task in _tasks ]
300300 try :
301301 done , pending = await asyncio .wait (
302302 tasks , return_when = asyncio .FIRST_COMPLETED
@@ -384,7 +384,9 @@ async def receive(self, channel: str):
384384 # Do a plain direct receive
385385 return (await self .receive_single (channel ))[1 ]
386386
387- async def receive_single (self , channel : str ) -> typing .Tuple :
387+ async def receive_single (
388+ self , channel : str
389+ ) -> typing .Tuple [typing .Any , typing .Any ]:
388390 """
389391 Receives a single message off of the channel and returns it.
390392 """
@@ -420,7 +422,7 @@ async def receive_single(self, channel: str) -> typing.Tuple:
420422 )
421423 self .receive_cleaners .append (cleaner )
422424
423- def _cleanup_done (cleaner : asyncio .Task ) :
425+ def _cleanup_done (cleaner : asyncio .Task [ typing . Any ]) -> None :
424426 self .receive_cleaners .remove (cleaner )
425427 self .receive_clean_locks .release (channel_key )
426428
@@ -448,7 +450,7 @@ async def new_channel(self, prefix: str = "specific") -> str:
448450
449451 ### Flush extension ###
450452
451- async def flush (self ):
453+ async def flush (self ) -> None :
452454 """
453455 Deletes all messages and groups on all shards.
454456 """
@@ -470,7 +472,7 @@ async def flush(self):
470472 # Now clear the pools as well
471473 await self .close_pools ()
472474
473- async def close_pools (self ):
475+ async def close_pools (self ) -> None :
474476 """
475477 Close all connections in the event loop pools.
476478 """
@@ -480,7 +482,7 @@ async def close_pools(self):
480482 for layer in self ._layers .values ():
481483 await layer .flush ()
482484
483- async def wait_received (self ):
485+ async def wait_received (self ) -> None :
484486 """
485487 Wait for all channel cleanup functions to finish.
486488 """
@@ -489,13 +491,13 @@ async def wait_received(self):
489491
490492 ### Groups extension ###
491493
492- async def group_add (self , group : str , channel : str ):
494+ async def group_add (self , group : str , channel : str ) -> None :
493495 """
494496 Adds the channel name to a group.
495497 """
496498 # Check the inputs
497- assert self .valid_group_name (group ), True
498- assert self .valid_channel_name (channel ), True
499+ assert self .valid_group_name (group ), "Group name not valid"
500+ assert self .valid_channel_name (channel ), "Channel name not valid"
499501 # Get a connection to the right shard
500502 group_key = self ._group_key (group )
501503 connection = self .connection (self .consistent_hash (group ))
@@ -505,7 +507,7 @@ async def group_add(self, group: str, channel: str):
505507 # it at this point is guaranteed to expire before that
506508 await connection .expire (group_key , self .group_expiry )
507509
508- async def group_discard (self , group : str , channel : str ):
510+ async def group_discard (self , group : str , channel : str ) -> None :
509511 """
510512 Removes the channel from the named group if it is in the group;
511513 does nothing otherwise (does not error)
@@ -516,7 +518,7 @@ async def group_discard(self, group: str, channel: str):
516518 connection = self .connection (self .consistent_hash (group ))
517519 await connection .zrem (key , channel )
518520
519- async def group_send (self , group : str , message ) :
521+ async def group_send (self , group : str , message : typing . Any ) -> None :
520522 """
521523 Sends a message to the entire group.
522524 """
@@ -540,9 +542,9 @@ async def group_send(self, group: str, message):
540542 for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
541543 # Discard old messages based on expiry
542544 pipe = connection .pipeline ()
543- for key in channel_redis_keys :
545+ for _key in channel_redis_keys :
544546 pipe .zremrangebyscore (
545- key , min = 0 , max = int (time .time ()) - int (self .expiry )
547+ _key , min = 0 , max = int (time .time ()) - int (self .expiry )
546548 )
547549 await pipe .execute ()
548550
@@ -585,7 +587,7 @@ async def group_send(self, group: str, message):
585587 channels_over_capacity = await connection .eval (
586588 group_send_lua , len (channel_redis_keys ), * channel_redis_keys , * args
587589 )
588- _channels_over_capacity = - 1
590+ _channels_over_capacity = - 1.0
589591 try :
590592 _channels_over_capacity = float (channels_over_capacity )
591593 except Exception :
@@ -598,7 +600,13 @@ async def group_send(self, group: str, message):
598600 group ,
599601 )
600602
601- def _map_channel_keys_to_connection (self , channel_names , message ):
603+ def _map_channel_keys_to_connection (
604+ self , channel_names : typing .Iterable [str ], message : typing .Any
605+ ) -> typing .Tuple [
606+ typing .Dict [int , typing .List [str ]],
607+ typing .Dict [str , typing .Any ],
608+ typing .Dict [str , int ],
609+ ]:
602610 """
603611 For a list of channel names, GET
604612
@@ -611,19 +619,21 @@ def _map_channel_keys_to_connection(self, channel_names, message):
611619 """
612620
613621 # Connection dict keyed by index to list of redis keys mapped on that index
614- connection_to_channel_keys = collections .defaultdict (list )
622+ connection_to_channel_keys : typing .Dict [int , typing .List [str ]] = (
623+ collections .defaultdict (list )
624+ )
615625 # Message dict maps redis key to the message that needs to be send on that key
616- channel_key_to_message = dict ()
626+ channel_key_to_message : typing . Dict [ str , typing . Any ] = dict ()
617627 # Channel key mapped to its capacity
618- channel_key_to_capacity = dict ()
628+ channel_key_to_capacity : typing . Dict [ str , int ] = dict ()
619629
620630 # For each channel
621631 for channel in channel_names :
622632 channel_non_local_name = channel
623633 if "!" in channel :
624634 channel_non_local_name = self .non_local_name (channel )
625635 # Get its redis key
626- channel_key = self .prefix + channel_non_local_name
636+ channel_key : str = self .prefix + channel_non_local_name
627637 # Have we come across the same redis key?
628638 if channel_key not in channel_key_to_message :
629639 # If not, fill the corresponding dicts
@@ -654,13 +664,15 @@ def _group_key(self, group: str) -> bytes:
654664 """
655665 return f"{ self .prefix } :group:{ group } " .encode ("utf8" )
656666
657- def serialize (self , message ) -> bytes :
667+ ### Serialization ###
668+
669+ def serialize (self , message : typing .Any ) -> bytes :
658670 """
659671 Serializes message to a byte string.
660672 """
661673 return self ._serializer .serialize (message )
662674
663- def deserialize (self , message : bytes ):
675+ def deserialize (self , message : bytes ) -> typing . Any :
664676 """
665677 Deserializes from a byte string.
666678 """
@@ -671,7 +683,7 @@ def deserialize(self, message: bytes):
671683 def consistent_hash (self , value : typing .Union [str , "Buffer" ]) -> int :
672684 return _consistent_hash (value , self .ring_size )
673685
674- def __str__ (self ):
686+ def __str__ (self ) -> str :
675687 return f"{ self .__class__ .__name__ } (hosts={ self .hosts } )"
676688
677689 ### Connection handling ###
0 commit comments