From 31e470954a50a0d9f1ce8f1a7927e4a60c5ec7f9 Mon Sep 17 00:00:00 2001 From: Yaroslav Petrov Date: Sat, 15 Nov 2025 15:25:01 +0000 Subject: [PATCH 1/3] Add safe root model handling for path searching in json codec --- pyproject.toml | 2 +- src/asyncapi_python/contrib/codec/json.py | 32 ++- tests/core/codec/test_json_codec.py | 232 ++++++++++++++++++ .../test_parameterized_subscriptions.py | 141 ++++++++++- uv.lock | 2 +- 5 files changed, 397 insertions(+), 12 deletions(-) create mode 100644 tests/core/codec/test_json_codec.py diff --git a/pyproject.toml b/pyproject.toml index 88a99fa..bd4b11f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "asyncapi-python" -version = "0.3.0rc8" +version = "0.3.0rc9" license = { text = "Apache-2.0" } description = "Easily generate type-safe and async Python applications from AsyncAPI 3 specifications." authors = [{ name = "Yaroslav Petrov", email = "yaroslav.v.petrov@gmail.com" }] diff --git a/src/asyncapi_python/contrib/codec/json.py b/src/asyncapi_python/contrib/codec/json.py index ce5aded..437101c 100644 --- a/src/asyncapi_python/contrib/codec/json.py +++ b/src/asyncapi_python/contrib/codec/json.py @@ -1,9 +1,9 @@ import json from enum import Enum from types import ModuleType -from typing import ClassVar, Type +from typing import Any, ClassVar, Type -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, RootModel, ValidationError from asyncapi_python.kernel.codec import Codec, CodecFactory from asyncapi_python.kernel.document.message import Message @@ -31,13 +31,16 @@ def decode(self, payload: bytes) -> BaseModel: def extract_field(self, payload: BaseModel, location: str) -> str: """Extract field from Pydantic model using location path. + Handles both regular BaseModel and RootModel wrappers. RootModel instances + are automatically unwrapped (recursively) to access the underlying data. + Examples: "$message.payload#/userId" → payload.userId → "123" "$message.payload#/user/id" → payload.user.id → "456" "$message.payload#/items" → payload.items → "[1, 2, 3]" Args: - payload: Pydantic BaseModel instance + payload: Pydantic BaseModel instance (may be RootModel wrapper) location: Location expression like "$message.payload#/userId" Returns: @@ -56,18 +59,29 @@ def extract_field(self, payload: BaseModel, location: str) -> str: parts = [p for p in path.split("/") if p] try: - value = payload + value: Any = payload for part in parts: - value = getattr(value, part) + # Recursively unwrap any RootModel wrappers before accessing attributes + while isinstance(value, RootModel): + value = value.root # type: ignore[assignment, misc] + value = getattr(value, part) # type: ignore[arg-type] + + # Unwrap final value if it's a RootModel + while isinstance(value, RootModel): + value = value.root # type: ignore[assignment, misc] # Convert to string - if isinstance(value, (str, int, float, bool)): - return str(value) - elif isinstance(value, Enum): + # Check Enum FIRST (before str/int/etc) because str/int Enums are also instances of str/int + if isinstance(value, Enum): # Handle Enum types - extract the value attribute return str(value.value) + elif isinstance(value, (str, int, float, bool)): + return str(value) + elif isinstance(value, BaseModel): + # Pydantic models: dump to dict then JSON serialize + return json.dumps(value.model_dump()) else: - # Complex types: JSON serialize + # Other complex types: JSON serialize directly return json.dumps(value) except AttributeError as e: diff --git a/tests/core/codec/test_json_codec.py b/tests/core/codec/test_json_codec.py new file mode 100644 index 0000000..4e52310 --- /dev/null +++ b/tests/core/codec/test_json_codec.py @@ -0,0 +1,232 @@ +"""Tests for JSON codec extract_field() method with RootModel support""" + +from enum import Enum + +import pytest +from pydantic import BaseModel, RootModel + +from asyncapi_python.contrib.codec.json import JsonCodec + + +# Test models +class SimpleMessage(BaseModel): + """Regular BaseModel for testing""" + + chat_id: int + message: str + + +class NestedUser(BaseModel): + """Nested model for path traversal testing""" + + id: str + name: str + + +class MessageWithNested(BaseModel): + """Model with nested fields""" + + user: NestedUser + content: str + + +class Severity(str, Enum): + """Enum for testing enum extraction""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class MessageWithEnum(BaseModel): + """Model with enum field""" + + severity: Severity + description: str + + +class ComplexData(BaseModel): + """Complex nested data""" + + items: list[str] + metadata: dict[str, str] + + +class MessageWithComplex(BaseModel): + """Model with complex types""" + + data: ComplexData + + +# RootModel wrappers +class SimpleRootModel(RootModel[SimpleMessage]): + """Single-level RootModel wrapper""" + + root: SimpleMessage + + +class InnerRootModel(RootModel[NestedUser]): + """Inner RootModel for nested testing""" + + root: NestedUser + + +class OuterMessageWithRootModel(BaseModel): + """Message containing a RootModel field""" + + user: InnerRootModel + content: str + + +class DoubleRootModel(RootModel[SimpleRootModel]): + """Nested RootModel (RootModel containing RootModel)""" + + root: SimpleRootModel + + +# Tests +def test_extract_field_from_base_model(): + """Test extracting fields from regular BaseModel""" + codec = JsonCodec(SimpleMessage) + message = SimpleMessage(chat_id=123, message="hello") + + result = codec.extract_field(message, "$message.payload#/chat_id") + assert result == "123" + + result = codec.extract_field(message, "$message.payload#/message") + assert result == "hello" + + +def test_extract_field_from_root_model(): + """Test extracting fields from single-level RootModel wrapper""" + codec = JsonCodec(SimpleRootModel) + wrapped = SimpleRootModel.model_validate({"chat_id": 456, "message": "world"}) + + # Should unwrap RootModel and access fields on the root + result = codec.extract_field(wrapped, "$message.payload#/chat_id") + assert result == "456" + + result = codec.extract_field(wrapped, "$message.payload#/message") + assert result == "world" + + +def test_extract_field_from_nested_root_model(): + """Test extracting fields from nested RootModel (RootModel containing RootModel)""" + codec = JsonCodec(DoubleRootModel) + + # Create nested RootModel: DoubleRootModel -> SimpleRootModel -> SimpleMessage + inner = SimpleRootModel.model_validate({"chat_id": 789, "message": "nested"}) + wrapped = DoubleRootModel.model_validate(inner.model_dump()) + + # Should recursively unwrap both RootModel layers + result = codec.extract_field(wrapped, "$message.payload#/chat_id") + assert result == "789" + + result = codec.extract_field(wrapped, "$message.payload#/message") + assert result == "nested" + + +def test_extract_field_nested_path(): + """Test extracting nested fields using path like $message.payload#/user/id""" + codec = JsonCodec(MessageWithNested) + message = MessageWithNested( + user=NestedUser(id="user123", name="Alice"), content="test" + ) + + result = codec.extract_field(message, "$message.payload#/user/id") + assert result == "user123" + + result = codec.extract_field(message, "$message.payload#/user/name") + assert result == "Alice" + + +def test_extract_field_nested_path_with_root_model(): + """Test extracting nested fields when intermediate field is a RootModel""" + codec = JsonCodec(OuterMessageWithRootModel) + + # The user field is a RootModel wrapper + user_wrapped = InnerRootModel.model_validate({"id": "user456", "name": "Bob"}) + message = OuterMessageWithRootModel(user=user_wrapped, content="test") + + # Should unwrap the RootModel at the intermediate step + result = codec.extract_field(message, "$message.payload#/user/id") + assert result == "user456" + + result = codec.extract_field(message, "$message.payload#/user/name") + assert result == "Bob" + + +def test_extract_field_enum_value(): + """Test extracting enum values (should return the enum value, not the enum object)""" + codec = JsonCodec(MessageWithEnum) + message = MessageWithEnum(severity=Severity.HIGH, description="critical alert") + + result = codec.extract_field(message, "$message.payload#/severity") + assert result == "high" # Should extract the value, not "Severity.HIGH" + + +def test_extract_field_complex_type(): + """Test extracting complex types (should JSON serialize)""" + codec = JsonCodec(MessageWithComplex) + message = MessageWithComplex( + data=ComplexData(items=["a", "b", "c"], metadata={"key": "value"}) + ) + + result = codec.extract_field(message, "$message.payload#/data") + # Should be JSON serialized + assert '"items": ["a", "b", "c"]' in result + assert '"metadata": {"key": "value"}' in result + + +def test_extract_field_invalid_location(): + """Test error handling for invalid location format""" + codec = JsonCodec(SimpleMessage) + message = SimpleMessage(chat_id=123, message="hello") + + with pytest.raises(ValueError, match="Invalid location format"): + codec.extract_field(message, "invalid/location") + + with pytest.raises(ValueError, match="Invalid location format"): + codec.extract_field(message, "#/chat_id") + + +def test_extract_field_missing_path(): + """Test error handling for non-existent paths""" + codec = JsonCodec(SimpleMessage) + message = SimpleMessage(chat_id=123, message="hello") + + with pytest.raises(ValueError, match="Path 'nonexistent' not found in payload"): + codec.extract_field(message, "$message.payload#/nonexistent") + + +def test_extract_field_missing_nested_path(): + """Test error handling for non-existent nested paths""" + codec = JsonCodec(MessageWithNested) + message = MessageWithNested( + user=NestedUser(id="user123", name="Alice"), content="test" + ) + + with pytest.raises( + ValueError, match="Path 'user/nonexistent' not found in payload" + ): + codec.extract_field(message, "$message.payload#/user/nonexistent") + + +def test_extract_field_primitive_types() -> None: + """Test extraction returns proper string representations of primitive types""" + + class PrimitiveMessage(BaseModel): + str_field: str + int_field: int + float_field: float + bool_field: bool + + codec = JsonCodec(PrimitiveMessage) + message = PrimitiveMessage( + str_field="test", int_field=42, float_field=3.14, bool_field=True + ) + + assert codec.extract_field(message, "$message.payload#/str_field") == "test" + assert codec.extract_field(message, "$message.payload#/int_field") == "42" + assert codec.extract_field(message, "$message.payload#/float_field") == "3.14" + assert codec.extract_field(message, "$message.payload#/bool_field") == "True" diff --git a/tests/core/endpoint/test_parameterized_subscriptions.py b/tests/core/endpoint/test_parameterized_subscriptions.py index 038af0a..d23f141 100644 --- a/tests/core/endpoint/test_parameterized_subscriptions.py +++ b/tests/core/endpoint/test_parameterized_subscriptions.py @@ -17,7 +17,7 @@ from asyncapi_python.kernel.document.channel import AddressParameter, ChannelBindings from asyncapi_python.kernel.document.message import Message from asyncapi_python.kernel.endpoint import Publisher, Subscriber -from pydantic import BaseModel +from pydantic import BaseModel, RootModel class AlertMessage(BaseModel): @@ -632,3 +632,142 @@ async def handle_alert(msg: AlertMessage) -> None: # Should raise ValueError when starting with pytest.raises(ValueError, match="Unexpected parameters"): await subscriber.start() + + +async def test_publisher_extracts_parameters_from_root_model(): + """Publisher should extract parameters from RootModel-wrapped payloads.""" + from asyncapi_python.contrib.wire.in_memory import get_bus, reset_bus + + # Reset the bus for clean test + reset_bus() + wire = InMemoryWire() + + import types + + # Create RootModel wrapper for alert message + class AlertRootModel(RootModel[AlertMessage]): + """RootModel wrapper for testing parameter extraction""" + + root: AlertMessage + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertRootModel + codec_factory = JsonCodecFactory(test_module) + + # Create parameterized channel directly + channel = Channel( + key="test_channel", + address="alerts.{location}.{severity}", + title="Test Channel", + summary=None, + description=None, + servers=[], + messages={ + "TestMessage": Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + }, + parameters={ + "location": AddressParameter( + key="location", + description="Location code", + location="$message.payload#/location", + ), + "severity": AddressParameter( + key="severity", + description="Severity level", + location="$message.payload#/severity", + ), + }, + tags=[], + external_docs=None, + bindings=ChannelBindings( + amqp=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ) + ), + ) + + # Create operation with parameterized channel + operation = Operation( + key="test_op", + action="send", + title=None, + summary=None, + description=None, + channel=channel, + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + publisher = Publisher( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + # Capture the channel name used for publishing by intercepting bus.publish + captured_channel: str | None = None + bus = get_bus() + original_publish = bus.publish + + async def mock_publish(channel_name: str, message: Any) -> None: + nonlocal captured_channel + captured_channel = channel_name + await original_publish(channel_name, message) + + bus.publish = mock_publish # type: ignore + + await publisher.start() + + # Send message wrapped in RootModel + wrapped_message = AlertRootModel.model_validate( + {"location": "NYC", "severity": "high", "data": "test"} + ) + await publisher(wrapped_message) + + # Should extract parameters from RootModel and build correct address + assert captured_channel == "alerts.NYC.high" + + await publisher.stop() diff --git a/uv.lock b/uv.lock index 0ee18d7..4bd3842 100644 --- a/uv.lock +++ b/uv.lock @@ -64,7 +64,7 @@ wheels = [ [[package]] name = "asyncapi-python" -version = "0.3.0rc7" +version = "0.3.0rc9" source = { editable = "." } dependencies = [ { name = "pydantic" }, From b3577a126dff534373364b6c9a40fb14b91634ae Mon Sep 17 00:00:00 2001 From: Yaroslav Petrov Date: Sun, 16 Nov 2025 11:55:01 +0000 Subject: [PATCH 2/3] Add arguments to amqp channels --- .../contrib/wire/amqp/config.py | 3 +++ .../contrib/wire/amqp/consumer.py | 14 +++++++++++ .../contrib/wire/amqp/producer.py | 23 +++++++++++++++---- .../contrib/wire/amqp/resolver.py | 21 +++++++++++++++++ .../kernel/document/bindings.py | 8 +++++-- 5 files changed, 63 insertions(+), 6 deletions(-) diff --git a/src/asyncapi_python/contrib/wire/amqp/config.py b/src/asyncapi_python/contrib/wire/amqp/config.py index e1ab25b..3a62f78 100644 --- a/src/asyncapi_python/contrib/wire/amqp/config.py +++ b/src/asyncapi_python/contrib/wire/amqp/config.py @@ -25,6 +25,7 @@ class AmqpConfig: binding_type: AmqpBindingType = AmqpBindingType.QUEUE queue_properties: dict[str, Any] = field(default_factory=lambda: {}) binding_arguments: dict[str, Any] = field(default_factory=lambda: {}) + arguments: dict[str, Any] = field(default_factory=lambda: {}) def to_producer_args(self) -> dict[str, Any]: """Convert to AmqpProducer constructor arguments""" @@ -34,6 +35,7 @@ def to_producer_args(self) -> dict[str, Any]: "exchange_type": self.exchange_type, "routing_key": self.routing_key, "queue_properties": self.queue_properties, + "arguments": self.arguments, } def to_consumer_args(self) -> dict[str, Any]: @@ -46,4 +48,5 @@ def to_consumer_args(self) -> dict[str, Any]: "binding_type": self.binding_type, "queue_properties": self.queue_properties, "binding_arguments": self.binding_arguments, + "arguments": self.arguments, } diff --git a/src/asyncapi_python/contrib/wire/amqp/consumer.py b/src/asyncapi_python/contrib/wire/amqp/consumer.py index 1530644..3569473 100644 --- a/src/asyncapi_python/contrib/wire/amqp/consumer.py +++ b/src/asyncapi_python/contrib/wire/amqp/consumer.py @@ -36,6 +36,7 @@ def __init__( binding_type: AmqpBindingType = AmqpBindingType.QUEUE, queue_properties: dict[str, Any] | None = None, binding_arguments: dict[str, Any] | None = None, + arguments: dict[str, Any] | None = None, ): self._connection = connection self._queue_name = queue_name @@ -45,6 +46,7 @@ def __init__( self._binding_type = binding_type self._queue_properties = queue_properties or {} self._binding_arguments = binding_arguments or {} + self._arguments = arguments or {} self._channel: AbstractChannel | None = None self._queue: AbstractQueue | None = None self._exchange: AbstractExchange | None = None @@ -67,6 +69,7 @@ async def start(self) -> None: durable=self._queue_properties.get("durable", True), exclusive=self._queue_properties.get("exclusive", False), auto_delete=self._queue_properties.get("auto_delete", False), + arguments=self._arguments, ) # Simple queue binding pattern (default exchange) @@ -76,6 +79,7 @@ async def start(self) -> None: durable=self._queue_properties.get("durable", True), exclusive=self._queue_properties.get("exclusive", False), auto_delete=self._queue_properties.get("auto_delete", False), + arguments=self._arguments, ) # Routing key binding pattern (pub/sub with named exchange) @@ -87,24 +91,28 @@ async def start(self) -> None: name=self._exchange_name, type=ExchangeType.DIRECT, durable=True, + arguments=self._arguments, ) case "topic": self._exchange = await self._channel.declare_exchange( name=self._exchange_name, type=ExchangeType.TOPIC, durable=True, + arguments=self._arguments, ) case "fanout": self._exchange = await self._channel.declare_exchange( name=self._exchange_name, type=ExchangeType.FANOUT, durable=True, + arguments=self._arguments, ) case "headers": self._exchange = await self._channel.declare_exchange( name=self._exchange_name, type=ExchangeType.HEADERS, durable=True, + arguments=self._arguments, ) case unknown_type: raise ValueError(f"Unsupported exchange type: {unknown_type}") @@ -115,6 +123,7 @@ async def start(self) -> None: durable=self._queue_properties.get("durable", False), exclusive=self._queue_properties.get("exclusive", True), auto_delete=self._queue_properties.get("auto_delete", True), + arguments=self._arguments, ) # Bind queue to exchange with routing key @@ -129,24 +138,28 @@ async def start(self) -> None: name=self._exchange_name, type=ExchangeType.FANOUT, durable=True, + arguments=self._arguments, ) case "headers": self._exchange = await self._channel.declare_exchange( name=self._exchange_name, type=ExchangeType.HEADERS, durable=True, + arguments=self._arguments, ) case "topic": self._exchange = await self._channel.declare_exchange( name=self._exchange_name, type=ExchangeType.TOPIC, durable=True, + arguments=self._arguments, ) case "direct": self._exchange = await self._channel.declare_exchange( name=self._exchange_name, type=ExchangeType.DIRECT, durable=True, + arguments=self._arguments, ) case unknown_type: raise ValueError(f"Unsupported exchange type: {unknown_type}") @@ -157,6 +170,7 @@ async def start(self) -> None: durable=self._queue_properties.get("durable", False), exclusive=self._queue_properties.get("exclusive", True), auto_delete=self._queue_properties.get("auto_delete", True), + arguments=self._arguments, ) # Bind queue to exchange with binding arguments (for headers exchange) diff --git a/src/asyncapi_python/contrib/wire/amqp/producer.py b/src/asyncapi_python/contrib/wire/amqp/producer.py index 82da03c..e12fa9e 100644 --- a/src/asyncapi_python/contrib/wire/amqp/producer.py +++ b/src/asyncapi_python/contrib/wire/amqp/producer.py @@ -31,6 +31,7 @@ def __init__( exchange_type: str = "direct", routing_key: str = "", queue_properties: dict[str, Any] | None = None, + arguments: dict[str, Any] | None = None, ): self._connection = connection self._queue_name = queue_name @@ -38,6 +39,7 @@ def __init__( self._exchange_type = exchange_type self._routing_key = routing_key self._queue_properties = queue_properties or {} + self._arguments = arguments or {} self._channel: AbstractChannel | None = None self._target_exchange: AbstractExchange | None = None self._started = False @@ -61,27 +63,40 @@ async def start(self) -> None: durable=self._queue_properties.get("durable", True), exclusive=self._queue_properties.get("exclusive", False), auto_delete=self._queue_properties.get("auto_delete", False), + arguments=self._arguments, ) # Named exchange patterns case (exchange_name, "direct"): self._target_exchange = await self._channel.declare_exchange( - name=exchange_name, type=ExchangeType.DIRECT, durable=True + name=exchange_name, + type=ExchangeType.DIRECT, + durable=True, + arguments=self._arguments, ) case (exchange_name, "topic"): self._target_exchange = await self._channel.declare_exchange( - name=exchange_name, type=ExchangeType.TOPIC, durable=True + name=exchange_name, + type=ExchangeType.TOPIC, + durable=True, + arguments=self._arguments, ) case (exchange_name, "fanout"): self._target_exchange = await self._channel.declare_exchange( - name=exchange_name, type=ExchangeType.FANOUT, durable=True + name=exchange_name, + type=ExchangeType.FANOUT, + durable=True, + arguments=self._arguments, ) case (exchange_name, "headers"): self._target_exchange = await self._channel.declare_exchange( - name=exchange_name, type=ExchangeType.HEADERS, durable=True + name=exchange_name, + type=ExchangeType.HEADERS, + durable=True, + arguments=self._arguments, ) case (exchange_name, unknown_type): diff --git a/src/asyncapi_python/contrib/wire/amqp/resolver.py b/src/asyncapi_python/contrib/wire/amqp/resolver.py index 9f58255..ec6a71a 100644 --- a/src/asyncapi_python/contrib/wire/amqp/resolver.py +++ b/src/asyncapi_python/contrib/wire/amqp/resolver.py @@ -116,6 +116,7 @@ def resolve_amqp_config( "exclusive": True, "auto_delete": True, }, + arguments={}, ) # Reply channel with explicit address - check if direct queue or topic exchange @@ -133,6 +134,7 @@ def resolve_amqp_config( "exclusive": True, "auto_delete": True, }, + arguments={}, ) else: # Topic-based reply pattern - shared exchange with filtering @@ -143,6 +145,7 @@ def resolve_amqp_config( routing_key=app_id, # Filter messages by app_id binding_type=AmqpBindingType.REPLY, queue_properties={"durable": True, "exclusive": False}, + arguments={}, ) # Reply channel with binding - defer to binding resolution @@ -192,6 +195,7 @@ def resolve_amqp_config( routing_key=resolved_address, binding_type=AmqpBindingType.QUEUE, queue_properties={"durable": True, "exclusive": False}, + arguments={}, ) # Operation name pattern (fallback) @@ -204,6 +208,7 @@ def resolve_amqp_config( routing_key=op_name, binding_type=AmqpBindingType.QUEUE, queue_properties={"durable": True, "exclusive": False}, + arguments={}, ) # No match - reject creation @@ -245,6 +250,7 @@ def resolve_queue_binding( # Extract queue properties queue_config = getattr(binding, "queue", None) queue_properties = {"durable": True, "exclusive": False} # Defaults + arguments: dict[str, Any] = {} if queue_config: if hasattr(queue_config, "durable"): queue_properties["durable"] = queue_config.durable @@ -252,6 +258,8 @@ def resolve_queue_binding( queue_properties["exclusive"] = queue_config.exclusive if hasattr(queue_config, "auto_delete"): queue_properties["auto_delete"] = queue_config.auto_delete + if hasattr(queue_config, "arguments") and queue_config.arguments: + arguments = queue_config.arguments return AmqpConfig( queue_name=queue_name, @@ -259,6 +267,7 @@ def resolve_queue_binding( routing_key=queue_name, # For default exchange, routing_key = queue_name binding_type=AmqpBindingType.QUEUE, queue_properties=queue_properties, + arguments=arguments, ) @@ -303,6 +312,11 @@ def resolve_routing_key_binding( if exchange_config and hasattr(exchange_config, "type"): exchange_type = exchange_config.type + # Extract exchange arguments + arguments: dict[str, Any] = {} + if exchange_config and hasattr(exchange_config, "arguments") and exchange_config.arguments: + arguments = exchange_config.arguments + # Determine routing key - this is where wildcards are allowed match (getattr(binding, "routingKey", None), channel.address, operation_name): case (routing_key, _, _) if routing_key: @@ -327,6 +341,7 @@ def resolve_routing_key_binding( routing_key=resolved_routing_key, binding_type=AmqpBindingType.ROUTING_KEY, queue_properties={"durable": False, "exclusive": True, "auto_delete": True}, + arguments=arguments, ) @@ -366,6 +381,11 @@ def resolve_exchange_binding( if exchange_config and hasattr(exchange_config, "type"): exchange_type = exchange_config.type + # Extract exchange arguments + arguments: dict[str, Any] = {} + if exchange_config and hasattr(exchange_config, "arguments") and exchange_config.arguments: + arguments = exchange_config.arguments + # Extract binding arguments for headers exchange from dataclass binding_args: dict[str, Any] = {} # Note: bindingKeys is not part of AmqpChannelBinding spec @@ -379,4 +399,5 @@ def resolve_exchange_binding( binding_type=AmqpBindingType.EXCHANGE, queue_properties={"durable": False, "exclusive": True, "auto_delete": True}, binding_arguments=binding_args, + arguments=arguments, ) diff --git a/src/asyncapi_python/kernel/document/bindings.py b/src/asyncapi_python/kernel/document/bindings.py index ba5abbf..97b6f92 100644 --- a/src/asyncapi_python/kernel/document/bindings.py +++ b/src/asyncapi_python/kernel/document/bindings.py @@ -26,13 +26,14 @@ class AmqpExchange: durable: Optional[bool] = None auto_delete: Optional[bool] = None vhost: Optional[str] = None + arguments: Optional[Dict[str, Any]] = None def __repr__(self) -> str: """Custom repr to handle enum properly for code generation.""" from asyncapi_python.kernel.document.bindings import AmqpExchangeType _ = AmqpExchangeType # Explicitly reference the import - return f"spec.AmqpExchange(name={self.name!r}, type=spec.AmqpExchangeType.{self.type.name}, durable={self.durable!r}, auto_delete={self.auto_delete!r}, vhost={self.vhost!r})" + return f"spec.AmqpExchange(name={self.name!r}, type=spec.AmqpExchangeType.{self.type.name}, durable={self.durable!r}, auto_delete={self.auto_delete!r}, vhost={self.vhost!r}, arguments={self.arguments!r})" @dataclass @@ -44,10 +45,11 @@ class AmqpQueue: exclusive: Optional[bool] = None auto_delete: Optional[bool] = None vhost: Optional[str] = None + arguments: Optional[Dict[str, Any]] = None def __repr__(self) -> str: """Custom repr for code generation.""" - return f"spec.AmqpQueue(name={self.name!r}, durable={self.durable!r}, exclusive={self.exclusive!r}, auto_delete={self.auto_delete!r}, vhost={self.vhost!r})" + return f"spec.AmqpQueue(name={self.name!r}, durable={self.durable!r}, exclusive={self.exclusive!r}, auto_delete={self.auto_delete!r}, vhost={self.vhost!r}, arguments={self.arguments!r})" @dataclass @@ -159,6 +161,7 @@ def create_amqp_binding_from_dict(binding_dict: Dict[str, Any]) -> AmqpChannelBi exclusive=queue_config.get("exclusive"), auto_delete=queue_config.get("auto_delete"), vhost=queue_config.get("vhost"), + arguments=queue_config.get("arguments"), ) elif binding_type == "routingKey" and "exchange" in binding_dict: exchange_config = binding_dict["exchange"] @@ -176,6 +179,7 @@ def create_amqp_binding_from_dict(binding_dict: Dict[str, Any]) -> AmqpChannelBi durable=exchange_config.get("durable"), auto_delete=exchange_config.get("auto_delete"), vhost=exchange_config.get("vhost"), + arguments=exchange_config.get("arguments"), ) return binding From 5f519cc105be1971721cf2a701e332f2e9a11016 Mon Sep 17 00:00:00 2001 From: Yaroslav Petrov Date: Sun, 16 Nov 2025 11:57:13 +0000 Subject: [PATCH 3/3] Format code --- src/asyncapi_python/contrib/wire/amqp/resolver.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/asyncapi_python/contrib/wire/amqp/resolver.py b/src/asyncapi_python/contrib/wire/amqp/resolver.py index ec6a71a..c9d54ae 100644 --- a/src/asyncapi_python/contrib/wire/amqp/resolver.py +++ b/src/asyncapi_python/contrib/wire/amqp/resolver.py @@ -314,7 +314,11 @@ def resolve_routing_key_binding( # Extract exchange arguments arguments: dict[str, Any] = {} - if exchange_config and hasattr(exchange_config, "arguments") and exchange_config.arguments: + if ( + exchange_config + and hasattr(exchange_config, "arguments") + and exchange_config.arguments + ): arguments = exchange_config.arguments # Determine routing key - this is where wildcards are allowed @@ -383,7 +387,11 @@ def resolve_exchange_binding( # Extract exchange arguments arguments: dict[str, Any] = {} - if exchange_config and hasattr(exchange_config, "arguments") and exchange_config.arguments: + if ( + exchange_config + and hasattr(exchange_config, "arguments") + and exchange_config.arguments + ): arguments = exchange_config.arguments # Extract binding arguments for headers exchange from dataclass