@@ -178,24 +178,37 @@ def _setup_encryption(self, symmetric_encryption_keys):
178178
179179 async def send (self , channel , message ):
180180 """
181- Send a message onto a (general or specific) channel.
181+ Send one or multiple messages onto a (general or specific) channel.
182+ The `message` can be a single dict or an iterable of dicts.
182183 """
184+ messages = self ._parse_messages (message )
185+
183186 # Typecheck
184- assert isinstance (message , dict ), "message is not a dict"
185187 assert self .valid_channel_name (channel ), "Channel name not valid"
186- # Make sure the message does not contain reserved keys
187- assert "__asgi_channel__" not in message
188+
188189 # If it's a process-local channel, strip off local part and stick full name in message
189190 channel_non_local_name = channel
190- if "!" in channel :
191- message = dict (message .items ())
192- message ["__asgi_channel__" ] = channel
191+ process_local = "!" in channel
192+ if process_local :
193193 channel_non_local_name = self .non_local_name (channel )
194+
195+ now = time .time ()
196+ mapping = {}
197+ for message in messages :
198+ assert isinstance (message , dict ), "message is not a dict"
199+ # Make sure the message does not contain reserved keys
200+ assert "__asgi_channel__" not in message
201+ if process_local :
202+ message = dict (message .items ())
203+ message ["__asgi_channel__" ] = channel
204+
205+ mapping [self .serialize (message )] = now
206+
194207 # Write out message into expiring key (avoids big items in list)
195208 channel_key = self .prefix + channel_non_local_name
196209 # Pick a connection to the right server - consistent for specific
197210 # channels, random for general channels
198- if "!" in channel :
211+ if process_local :
199212 index = self .consistent_hash (channel )
200213 else :
201214 index = next (self ._send_index_generator )
@@ -207,15 +220,23 @@ async def send(self, channel, message):
207220
208221 # Check the length of the list before send
209222 # This can allow the list to leak slightly over capacity, but that's fine.
210- if await connection .zcount (channel_key , "-inf" , "+inf" ) >= self . get_capacity (
211- channel
212- ):
223+ current_length = await connection .zcount (channel_key , "-inf" , "+inf" )
224+
225+ if current_length + len ( messages ) > self . get_capacity ( channel ):
213226 raise ChannelFull ()
214227
215228 # Push onto the list then set it to expire in case it's not consumed
216- await connection .zadd (channel_key , { self . serialize ( message ): time . time ()} )
229+ await connection .zadd (channel_key , mapping )
217230 await connection .expire (channel_key , int (self .expiry ))
218231
232+ def _parse_messages (self , message ):
233+ """
234+ Convert a passed message arg to a tuple of messages.
235+ """
236+ if not isinstance (message , dict ) and hasattr (message , "__iter__" ):
237+ return tuple (message )
238+ return (message ,)
239+
219240 def _backup_channel_name (self , channel ):
220241 """
221242 Construct the key used as a backup queue for the given channel.
@@ -519,8 +540,11 @@ async def group_discard(self, group, channel):
519540
520541 async def group_send (self , group , message ):
521542 """
522- Sends a message to the entire group.
543+ Sends one or multiple messages to the entire group.
544+ The `message` can be a single dict or an iterable of dicts.
523545 """
546+ messages = self ._parse_messages (message )
547+
524548 assert self .valid_group_name (group ), "Group name not valid"
525549 # Retrieve list of all channel names
526550 key = self ._group_key (group )
@@ -536,7 +560,7 @@ async def group_send(self, group, message):
536560 connection_to_channel_keys ,
537561 channel_keys_to_message ,
538562 channel_keys_to_capacity ,
539- ) = self ._map_channel_keys_to_connection (channel_names , message )
563+ ) = self ._map_channel_keys_to_connection (channel_names , messages )
540564
541565 for connection_index , channel_redis_keys in connection_to_channel_keys .items ():
542566 # Discard old messages based on expiry
@@ -548,17 +572,23 @@ async def group_send(self, group, message):
548572 await pipe .execute ()
549573
550574 # Create a LUA script specific for this connection.
551- # Make sure to use the message specific to this channel, it is
552- # stored in channel_to_message dict and contains the
575+ # Make sure to use the message list specific to this channel, it is
576+ # stored in channel_to_message dict and each message contains the
553577 # __asgi_channel__ key.
554578
555579 group_send_lua = """
556580 local over_capacity = 0
581+ local num_messages = tonumber(ARGV[#ARGV - 2])
557582 local current_time = ARGV[#ARGV - 1]
558583 local expiry = ARGV[#ARGV]
559584 for i=1,#KEYS do
560- if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
561- redis.call('ZADD', KEYS[i], current_time, ARGV[i])
585+ if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
586+ local messages = {}
587+ for j=num_messages * (i - 1) + 1, num_messages * i do
588+ table.insert(messages, current_time)
589+ table.insert(messages, ARGV[j])
590+ end
591+ redis.call('ZADD', KEYS[i], unpack(messages))
562592 redis.call('EXPIRE', KEYS[i], expiry)
563593 else
564594 over_capacity = over_capacity + 1
@@ -568,18 +598,18 @@ async def group_send(self, group, message):
568598 """
569599
570600 # We need to filter the messages to keep those related to the connection
571- args = [
572- channel_keys_to_message [ channel_key ]
573- for channel_key in channel_redis_keys
574- ]
601+ args = []
602+
603+ for channel_key in channel_redis_keys :
604+ args += channel_keys_to_message [ channel_key ]
575605
576606 # We need to send the capacity for each channel
577607 args += [
578608 channel_keys_to_capacity [channel_key ]
579609 for channel_key in channel_redis_keys
580610 ]
581611
582- args += [time .time (), self .expiry ]
612+ args += [len ( messages ), time .time (), self .expiry ]
583613
584614 # channel_keys does not contain a single redis key more than once
585615 connection = self .connection (connection_index )
@@ -594,7 +624,7 @@ async def group_send(self, group, message):
594624 group ,
595625 )
596626
597- def _map_channel_keys_to_connection (self , channel_names , message ):
627+ def _map_channel_keys_to_connection (self , channel_names , messages ):
598628 """
599629 For a list of channel names, GET
600630
@@ -609,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
609639 # Connection dict keyed by index to list of redis keys mapped on that index
610640 connection_to_channel_keys = collections .defaultdict (list )
611641 # Message dict maps redis key to the message that needs to be send on that key
612- channel_key_to_message = dict ( )
642+ channel_key_to_message = collections . defaultdict ( list )
613643 # Channel key mapped to its capacity
614644 channel_key_to_capacity = dict ()
615645
@@ -623,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
623653 # Have we come across the same redis key?
624654 if channel_key not in channel_key_to_message :
625655 # If not, fill the corresponding dicts
626- message = dict (message .items ())
627- message ["__asgi_channel__" ] = [channel ]
628- channel_key_to_message [channel_key ] = message
656+ for message in messages :
657+ message = dict (message .items ())
658+ message ["__asgi_channel__" ] = [channel ]
659+ channel_key_to_message [channel_key ].append (message )
629660 channel_key_to_capacity [channel_key ] = self .get_capacity (channel )
630661 idx = self .consistent_hash (channel_non_local_name )
631662 connection_to_channel_keys [idx ].append (channel_key )
632663 else :
633664 # Yes, Append the channel in message dict
634- channel_key_to_message [channel_key ]["__asgi_channel__" ].append (channel )
665+ for message in channel_key_to_message [channel_key ]:
666+ message ["__asgi_channel__" ].append (channel )
635667
636668 # Now that we know what message needs to be send on a redis key we serialize it
637669 for key , value in channel_key_to_message .items ():
638670 # Serialize the message stored for each redis key
639- channel_key_to_message [key ] = self .serialize (value )
671+ for idx , message in enumerate (value ):
672+ channel_key_to_message [key ][idx ] = self .serialize (message )
640673
641674 return (
642675 connection_to_channel_keys ,
0 commit comments