From 31e470954a50a0d9f1ce8f1a7927e4a60c5ec7f9 Mon Sep 17 00:00:00 2001 From: Yaroslav Petrov Date: Sat, 15 Nov 2025 15:25:01 +0000 Subject: [PATCH] Add safe root model handling for path searching in json codec --- pyproject.toml | 2 +- src/asyncapi_python/contrib/codec/json.py | 32 ++- tests/core/codec/test_json_codec.py | 232 ++++++++++++++++++ .../test_parameterized_subscriptions.py | 141 ++++++++++- uv.lock | 2 +- 5 files changed, 397 insertions(+), 12 deletions(-) create mode 100644 tests/core/codec/test_json_codec.py diff --git a/pyproject.toml b/pyproject.toml index 88a99fa..bd4b11f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "asyncapi-python" -version = "0.3.0rc8" +version = "0.3.0rc9" license = { text = "Apache-2.0" } description = "Easily generate type-safe and async Python applications from AsyncAPI 3 specifications." authors = [{ name = "Yaroslav Petrov", email = "yaroslav.v.petrov@gmail.com" }] diff --git a/src/asyncapi_python/contrib/codec/json.py b/src/asyncapi_python/contrib/codec/json.py index ce5aded..437101c 100644 --- a/src/asyncapi_python/contrib/codec/json.py +++ b/src/asyncapi_python/contrib/codec/json.py @@ -1,9 +1,9 @@ import json from enum import Enum from types import ModuleType -from typing import ClassVar, Type +from typing import Any, ClassVar, Type -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, RootModel, ValidationError from asyncapi_python.kernel.codec import Codec, CodecFactory from asyncapi_python.kernel.document.message import Message @@ -31,13 +31,16 @@ def decode(self, payload: bytes) -> BaseModel: def extract_field(self, payload: BaseModel, location: str) -> str: """Extract field from Pydantic model using location path. + Handles both regular BaseModel and RootModel wrappers. RootModel instances + are automatically unwrapped (recursively) to access the underlying data. + Examples: "$message.payload#/userId" → payload.userId → "123" "$message.payload#/user/id" → payload.user.id → "456" "$message.payload#/items" → payload.items → "[1, 2, 3]" Args: - payload: Pydantic BaseModel instance + payload: Pydantic BaseModel instance (may be RootModel wrapper) location: Location expression like "$message.payload#/userId" Returns: @@ -56,18 +59,29 @@ def extract_field(self, payload: BaseModel, location: str) -> str: parts = [p for p in path.split("/") if p] try: - value = payload + value: Any = payload for part in parts: - value = getattr(value, part) + # Recursively unwrap any RootModel wrappers before accessing attributes + while isinstance(value, RootModel): + value = value.root # type: ignore[assignment, misc] + value = getattr(value, part) # type: ignore[arg-type] + + # Unwrap final value if it's a RootModel + while isinstance(value, RootModel): + value = value.root # type: ignore[assignment, misc] # Convert to string - if isinstance(value, (str, int, float, bool)): - return str(value) - elif isinstance(value, Enum): + # Check Enum FIRST (before str/int/etc) because str/int Enums are also instances of str/int + if isinstance(value, Enum): # Handle Enum types - extract the value attribute return str(value.value) + elif isinstance(value, (str, int, float, bool)): + return str(value) + elif isinstance(value, BaseModel): + # Pydantic models: dump to dict then JSON serialize + return json.dumps(value.model_dump()) else: - # Complex types: JSON serialize + # Other complex types: JSON serialize directly return json.dumps(value) except AttributeError as e: diff --git a/tests/core/codec/test_json_codec.py b/tests/core/codec/test_json_codec.py new file mode 100644 index 0000000..4e52310 --- /dev/null +++ b/tests/core/codec/test_json_codec.py @@ -0,0 +1,232 @@ +"""Tests for JSON codec extract_field() method with RootModel support""" + +from enum import Enum + +import pytest +from pydantic import BaseModel, RootModel + +from asyncapi_python.contrib.codec.json import JsonCodec + + +# Test models +class SimpleMessage(BaseModel): + """Regular BaseModel for testing""" + + chat_id: int + message: str + + +class NestedUser(BaseModel): + """Nested model for path traversal testing""" + + id: str + name: str + + +class MessageWithNested(BaseModel): + """Model with nested fields""" + + user: NestedUser + content: str + + +class Severity(str, Enum): + """Enum for testing enum extraction""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class MessageWithEnum(BaseModel): + """Model with enum field""" + + severity: Severity + description: str + + +class ComplexData(BaseModel): + """Complex nested data""" + + items: list[str] + metadata: dict[str, str] + + +class MessageWithComplex(BaseModel): + """Model with complex types""" + + data: ComplexData + + +# RootModel wrappers +class SimpleRootModel(RootModel[SimpleMessage]): + """Single-level RootModel wrapper""" + + root: SimpleMessage + + +class InnerRootModel(RootModel[NestedUser]): + """Inner RootModel for nested testing""" + + root: NestedUser + + +class OuterMessageWithRootModel(BaseModel): + """Message containing a RootModel field""" + + user: InnerRootModel + content: str + + +class DoubleRootModel(RootModel[SimpleRootModel]): + """Nested RootModel (RootModel containing RootModel)""" + + root: SimpleRootModel + + +# Tests +def test_extract_field_from_base_model(): + """Test extracting fields from regular BaseModel""" + codec = JsonCodec(SimpleMessage) + message = SimpleMessage(chat_id=123, message="hello") + + result = codec.extract_field(message, "$message.payload#/chat_id") + assert result == "123" + + result = codec.extract_field(message, "$message.payload#/message") + assert result == "hello" + + +def test_extract_field_from_root_model(): + """Test extracting fields from single-level RootModel wrapper""" + codec = JsonCodec(SimpleRootModel) + wrapped = SimpleRootModel.model_validate({"chat_id": 456, "message": "world"}) + + # Should unwrap RootModel and access fields on the root + result = codec.extract_field(wrapped, "$message.payload#/chat_id") + assert result == "456" + + result = codec.extract_field(wrapped, "$message.payload#/message") + assert result == "world" + + +def test_extract_field_from_nested_root_model(): + """Test extracting fields from nested RootModel (RootModel containing RootModel)""" + codec = JsonCodec(DoubleRootModel) + + # Create nested RootModel: DoubleRootModel -> SimpleRootModel -> SimpleMessage + inner = SimpleRootModel.model_validate({"chat_id": 789, "message": "nested"}) + wrapped = DoubleRootModel.model_validate(inner.model_dump()) + + # Should recursively unwrap both RootModel layers + result = codec.extract_field(wrapped, "$message.payload#/chat_id") + assert result == "789" + + result = codec.extract_field(wrapped, "$message.payload#/message") + assert result == "nested" + + +def test_extract_field_nested_path(): + """Test extracting nested fields using path like $message.payload#/user/id""" + codec = JsonCodec(MessageWithNested) + message = MessageWithNested( + user=NestedUser(id="user123", name="Alice"), content="test" + ) + + result = codec.extract_field(message, "$message.payload#/user/id") + assert result == "user123" + + result = codec.extract_field(message, "$message.payload#/user/name") + assert result == "Alice" + + +def test_extract_field_nested_path_with_root_model(): + """Test extracting nested fields when intermediate field is a RootModel""" + codec = JsonCodec(OuterMessageWithRootModel) + + # The user field is a RootModel wrapper + user_wrapped = InnerRootModel.model_validate({"id": "user456", "name": "Bob"}) + message = OuterMessageWithRootModel(user=user_wrapped, content="test") + + # Should unwrap the RootModel at the intermediate step + result = codec.extract_field(message, "$message.payload#/user/id") + assert result == "user456" + + result = codec.extract_field(message, "$message.payload#/user/name") + assert result == "Bob" + + +def test_extract_field_enum_value(): + """Test extracting enum values (should return the enum value, not the enum object)""" + codec = JsonCodec(MessageWithEnum) + message = MessageWithEnum(severity=Severity.HIGH, description="critical alert") + + result = codec.extract_field(message, "$message.payload#/severity") + assert result == "high" # Should extract the value, not "Severity.HIGH" + + +def test_extract_field_complex_type(): + """Test extracting complex types (should JSON serialize)""" + codec = JsonCodec(MessageWithComplex) + message = MessageWithComplex( + data=ComplexData(items=["a", "b", "c"], metadata={"key": "value"}) + ) + + result = codec.extract_field(message, "$message.payload#/data") + # Should be JSON serialized + assert '"items": ["a", "b", "c"]' in result + assert '"metadata": {"key": "value"}' in result + + +def test_extract_field_invalid_location(): + """Test error handling for invalid location format""" + codec = JsonCodec(SimpleMessage) + message = SimpleMessage(chat_id=123, message="hello") + + with pytest.raises(ValueError, match="Invalid location format"): + codec.extract_field(message, "invalid/location") + + with pytest.raises(ValueError, match="Invalid location format"): + codec.extract_field(message, "#/chat_id") + + +def test_extract_field_missing_path(): + """Test error handling for non-existent paths""" + codec = JsonCodec(SimpleMessage) + message = SimpleMessage(chat_id=123, message="hello") + + with pytest.raises(ValueError, match="Path 'nonexistent' not found in payload"): + codec.extract_field(message, "$message.payload#/nonexistent") + + +def test_extract_field_missing_nested_path(): + """Test error handling for non-existent nested paths""" + codec = JsonCodec(MessageWithNested) + message = MessageWithNested( + user=NestedUser(id="user123", name="Alice"), content="test" + ) + + with pytest.raises( + ValueError, match="Path 'user/nonexistent' not found in payload" + ): + codec.extract_field(message, "$message.payload#/user/nonexistent") + + +def test_extract_field_primitive_types() -> None: + """Test extraction returns proper string representations of primitive types""" + + class PrimitiveMessage(BaseModel): + str_field: str + int_field: int + float_field: float + bool_field: bool + + codec = JsonCodec(PrimitiveMessage) + message = PrimitiveMessage( + str_field="test", int_field=42, float_field=3.14, bool_field=True + ) + + assert codec.extract_field(message, "$message.payload#/str_field") == "test" + assert codec.extract_field(message, "$message.payload#/int_field") == "42" + assert codec.extract_field(message, "$message.payload#/float_field") == "3.14" + assert codec.extract_field(message, "$message.payload#/bool_field") == "True" diff --git a/tests/core/endpoint/test_parameterized_subscriptions.py b/tests/core/endpoint/test_parameterized_subscriptions.py index 038af0a..d23f141 100644 --- a/tests/core/endpoint/test_parameterized_subscriptions.py +++ b/tests/core/endpoint/test_parameterized_subscriptions.py @@ -17,7 +17,7 @@ from asyncapi_python.kernel.document.channel import AddressParameter, ChannelBindings from asyncapi_python.kernel.document.message import Message from asyncapi_python.kernel.endpoint import Publisher, Subscriber -from pydantic import BaseModel +from pydantic import BaseModel, RootModel class AlertMessage(BaseModel): @@ -632,3 +632,142 @@ async def handle_alert(msg: AlertMessage) -> None: # Should raise ValueError when starting with pytest.raises(ValueError, match="Unexpected parameters"): await subscriber.start() + + +async def test_publisher_extracts_parameters_from_root_model(): + """Publisher should extract parameters from RootModel-wrapped payloads.""" + from asyncapi_python.contrib.wire.in_memory import get_bus, reset_bus + + # Reset the bus for clean test + reset_bus() + wire = InMemoryWire() + + import types + + # Create RootModel wrapper for alert message + class AlertRootModel(RootModel[AlertMessage]): + """RootModel wrapper for testing parameter extraction""" + + root: AlertMessage + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertRootModel + codec_factory = JsonCodecFactory(test_module) + + # Create parameterized channel directly + channel = Channel( + key="test_channel", + address="alerts.{location}.{severity}", + title="Test Channel", + summary=None, + description=None, + servers=[], + messages={ + "TestMessage": Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + }, + parameters={ + "location": AddressParameter( + key="location", + description="Location code", + location="$message.payload#/location", + ), + "severity": AddressParameter( + key="severity", + description="Severity level", + location="$message.payload#/severity", + ), + }, + tags=[], + external_docs=None, + bindings=ChannelBindings( + amqp=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ) + ), + ) + + # Create operation with parameterized channel + operation = Operation( + key="test_op", + action="send", + title=None, + summary=None, + description=None, + channel=channel, + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + publisher = Publisher( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + # Capture the channel name used for publishing by intercepting bus.publish + captured_channel: str | None = None + bus = get_bus() + original_publish = bus.publish + + async def mock_publish(channel_name: str, message: Any) -> None: + nonlocal captured_channel + captured_channel = channel_name + await original_publish(channel_name, message) + + bus.publish = mock_publish # type: ignore + + await publisher.start() + + # Send message wrapped in RootModel + wrapped_message = AlertRootModel.model_validate( + {"location": "NYC", "severity": "high", "data": "test"} + ) + await publisher(wrapped_message) + + # Should extract parameters from RootModel and build correct address + assert captured_channel == "alerts.NYC.high" + + await publisher.stop() diff --git a/uv.lock b/uv.lock index 0ee18d7..4bd3842 100644 --- a/uv.lock +++ b/uv.lock @@ -64,7 +64,7 @@ wheels = [ [[package]] name = "asyncapi-python" -version = "0.3.0rc7" +version = "0.3.0rc9" source = { editable = "." } dependencies = [ { name = "pydantic" },