diff --git a/examples/amqp-topic/.gitignore b/examples/amqp-topic/.gitignore new file mode 100644 index 0000000..91eb213 --- /dev/null +++ b/examples/amqp-topic/.gitignore @@ -0,0 +1,7 @@ +# Generated code directories +publisher/ +subscriber1/ +subscriber2/ + +# Virtual environment +.venv/ \ No newline at end of file diff --git a/examples/amqp-topic/Makefile b/examples/amqp-topic/Makefile new file mode 100644 index 0000000..bb60a75 --- /dev/null +++ b/examples/amqp-topic/Makefile @@ -0,0 +1,64 @@ +.PHONY: venv install generate publisher subscriber1 subscriber2 clean help + +# Virtual environment +VENV_NAME := .venv +PYTHON := $(VENV_NAME)/bin/python +PIP := $(VENV_NAME)/bin/pip +CODEGEN := $(VENV_NAME)/bin/asyncapi-python-codegen + +help: + @echo "Available targets:" + @echo " make venv - Create virtual environment" + @echo " make install - Install dependencies" + @echo " make generate - Generate code from AsyncAPI specs" + @echo " make publisher - Run the publisher" + @echo " make subscriber1 - Run subscriber 1" + @echo " make subscriber2 - Run subscriber 2" + @echo " make clean - Remove virtual environment and generated code" + +venv: + @echo "Creating virtual environment..." + python3 -m venv $(VENV_NAME) + @echo "✅ Virtual environment created" + +install: venv + @echo "Installing dependencies..." + $(PIP) install -e ../../[amqp,codegen] + @echo "✅ Dependencies installed" + +generate: install + @echo "Generating publisher code..." + $(CODEGEN) spec/publisher.asyncapi.yaml publisher --force + @echo "✅ Publisher code generated" + @echo "" + @echo "Generating subscriber1 code..." + $(CODEGEN) spec/subscriber1.asyncapi.yaml subscriber1 --force + @echo "✅ Subscriber1 code generated" + @echo "" + @echo "Generating subscriber2 code..." + $(CODEGEN) spec/subscriber2.asyncapi.yaml subscriber2 --force + @echo "✅ Subscriber2 code generated" + +publisher: generate + @echo "Starting publisher..." + @echo "" + $(PYTHON) main-publisher.py + +subscriber1: generate + @echo "Starting subscriber 1..." + @echo "" + $(PYTHON) main-subscriber1.py + +subscriber2: generate + @echo "Starting subscriber 2..." + @echo "" + $(PYTHON) main-subscriber2.py + +clean: + @echo "Cleaning up..." + rm -rf $(VENV_NAME) + rm -rf publisher/ + rm -rf subscriber1/ + rm -rf subscriber2/ + rm -rf __pycache__ + @echo "✅ Cleanup complete" diff --git a/examples/amqp-topic/README.md b/examples/amqp-topic/README.md new file mode 100644 index 0000000..24a796c --- /dev/null +++ b/examples/amqp-topic/README.md @@ -0,0 +1,138 @@ +# AMQP Topic Exchange Example + +Demonstrates **parameterized channels with wildcard subscriptions** using AMQP topic exchanges. + +## Overview + +Weather alert system showing: +- Publishers send to specific routing keys (concrete parameters) +- Subscribers use wildcards (`*` and `#`) for pattern matching +- Topic exchange routes messages based on routing key patterns + +## Architecture + +``` +Topic Exchange: weather_alerts +Channel: weather.{location}.{severity} + +Routing Keys: + weather.NYC.high + weather.LA.low + weather.CHI.critical +``` + +## Project Structure + +``` +examples/amqp-topic/ +├── spec/ +│ ├── common.asyncapi.yaml # Shared channel/message definitions +│ ├── publisher.asyncapi.yaml # Publisher app spec +│ ├── subscriber1.asyncapi.yaml # Subscriber 1 spec +│ └── subscriber2.asyncapi.yaml # Subscriber 2 spec +├── main-publisher.py # Publisher implementation +├── main-subscriber1.py # Subscriber 1 implementation +├── main-subscriber2.py # Subscriber 2 implementation +├── Makefile # Build and run commands +└── README.md +``` + +## Usage + +### 1. Generate Code + +```bash +make generate +``` + +This generates type-safe Python code from AsyncAPI specs: +- `publisher/` - from `spec/publisher.asyncapi.yaml` +- `subscriber1/` - from `spec/subscriber1.asyncapi.yaml` +- `subscriber2/` - from `spec/subscriber2.asyncapi.yaml` + +### 2. Run Publisher + +```bash +make publisher +``` + +Publishes weather alerts to the topic exchange. + +### 3. Run Subscribers + +Terminal 1: +```bash +make subscriber1 +``` + +Terminal 2: +```bash +make subscriber2 +``` + +## Key Features + +### Parameterized Channels + +Channel address: `weather.{location}.{severity}` + +Parameters are extracted from message payload: +```python +WeatherAlert( + location="NYC", # → {location} + severity="high", # → {severity} + ... +) +# Creates routing key: weather.NYC.high +``` + +### Wildcard Subscriptions + +Subscribers can use AMQP wildcards for pattern matching: +- `*` - Matches exactly one word +- `#` - Matches zero or more words + +**This Example**: +- **Subscriber 1**: `weather.NYC.*` - All NYC alerts (any severity) + - Uses `parameters={"location": "NYC"}` + - Receives: NYC-HIGH +- **Subscriber 2**: `weather.*.critical` - Critical alerts (any location) + - Uses `parameters={"severity": "critical"}` + - Receives: CHI-CRITICAL + +**Other Possible Patterns**: +- `weather.LA.*` - All LA alerts +- `weather.*.high` - High severity alerts from any location +- `weather.*.*` - ALL weather alerts (empty parameters) + +### Parameter Validation + +The runtime enforces: +- ✅ All required parameters must be provided +- ✅ Exact match required (strict validation) +- ✅ Queue bindings reject wildcards (concrete values only) +- ✅ Routing key bindings accept wildcards (pattern matching) + +## Development + +### Clean Up + +```bash +make clean +``` + +Removes virtual environment and generated code. + +### Help + +```bash +make help +``` + +Shows available Makefile targets. + +## Learn More + +- [AsyncAPI Specification](https://www.asyncapi.com/) +- [AMQP Topic Exchanges](https://www.rabbitmq.com/tutorials/tutorial-five-python.html) +- [AsyncAPI Python Documentation](https://github.com/yourorg/asyncapi-python) diff --git a/examples/amqp-topic/main-publisher.py b/examples/amqp-topic/main-publisher.py new file mode 100644 index 0000000..86d2d66 --- /dev/null +++ b/examples/amqp-topic/main-publisher.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +""" +Weather Alert Publisher + +Publishes weather alerts to an AMQP topic exchange with dynamic routing keys. +The routing key is built from the message payload fields (location and severity). + +Example usage: + python main-publisher.py +""" + +import asyncio +from datetime import datetime, timezone +from os import environ + +from asyncapi_python.contrib.wire.amqp import AmqpWire +from publisher import Application +from publisher.messages.json import WeatherAlert, Severity + +# AMQP connection URI (can be overridden via environment variable) +AMQP_URI = environ.get("AMQP_URI", "amqp://guest:guest@localhost") + +# Initialize application with AMQP wire +app = Application(AmqpWire(AMQP_URI)) + + +async def main() -> None: + """Main publisher routine""" + print("🌤️ Weather Alert Publisher") + print("=" * 50) + print(f"Connecting to: {AMQP_URI}") + + # Start the application + await app.start() + print("✅ Connected to AMQP broker") + print() + + # Sample weather alerts to publish + alerts = [ + WeatherAlert( + location="NYC", + severity=Severity.HIGH, + temperature=95, + description="Heat wave warning in effect. Stay hydrated!", + timestamp=datetime.now(timezone.utc), + ), + WeatherAlert( + location="LA", + severity=Severity.LOW, + temperature=72, + description="Sunny and pleasant weather expected.", + timestamp=datetime.now(timezone.utc), + ), + WeatherAlert( + location="CHI", + severity=Severity.CRITICAL, + temperature=5, + description="Severe winter storm approaching. Travel not recommended.", + timestamp=datetime.now(timezone.utc), + ), + WeatherAlert( + location="MIA", + severity=Severity.MEDIUM, + temperature=88, + description="Scattered thunderstorms expected this afternoon.", + timestamp=datetime.now(timezone.utc), + ), + WeatherAlert( + location="SEA", + severity=Severity.LOW, + temperature=65, + description="Light rain throughout the day.", + timestamp=datetime.now(timezone.utc), + ), + ] + + # Publish each alert + print("📡 Publishing weather alerts...") + print() + + for alert in alerts: + # The routing key will be dynamically built as: weather.{location}.{severity} + # For example: weather.NYC.high, weather.LA.low, etc. + await app.producer.publish_weather_alert(alert) + + print(f"✉️ Published alert:") + print(f" Routing Key: weather.{alert.location}.{alert.severity.value}") + print(f" Location: {alert.location}") + print(f" Severity: {alert.severity.value}") + print(f" Temperature: {alert.temperature}°F") + print(f" Description: {alert.description}") + print() + + # Small delay between messages for visibility + await asyncio.sleep(0.5) + + print(f"✅ Published {len(alerts)} weather alerts") + print() + + # Stop the application + await app.stop() + print("👋 Disconnected from AMQP broker") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/amqp-topic/main-subscriber1.py b/examples/amqp-topic/main-subscriber1.py new file mode 100644 index 0000000..902cbd2 --- /dev/null +++ b/examples/amqp-topic/main-subscriber1.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +Weather Alert Subscriber 1 + +Subscribes to weather alerts from an AMQP topic exchange. +This is the first subscriber - demonstrates receiving all alerts. + +Example usage: + python main-subscriber1.py +""" + +import asyncio +import signal +from os import environ +from types import FrameType + +from asyncapi_python.contrib.wire.amqp import AmqpWire +from subscriber1 import Application +from subscriber1.messages.json import WeatherAlert + +# AMQP connection URI (can be overridden via environment variable) +AMQP_URI = environ.get("AMQP_URI", "amqp://guest:guest@localhost") + +# Initialize application with AMQP wire +app = Application(AmqpWire(AMQP_URI)) + +# Shutdown event +shutdown_event = asyncio.Event() + + +def signal_handler(signum: int, frame: FrameType | None) -> None: + """Handle shutdown signals""" + print("\n⚠️ Shutdown signal received") + shutdown_event.set() + + +# Register signal handlers +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + + +@app.consumer.receive_weather_alert(parameters={"location": "NYC"}) +async def handle_weather_alert(alert: WeatherAlert) -> None: + """ + Handle incoming weather alerts for NYC. + + This handler subscribes to weather.NYC.* pattern to receive + all NYC alerts regardless of severity. + """ + # Determine severity emoji + severity_emoji = { + "low": "🟢", + "medium": "🟡", + "high": "🟠", + "critical": "🔴", + }.get(alert.severity.value, "⚪") + + print(f"\n{severity_emoji} Weather Alert Received [SUBSCRIBER 1]") + print(f" Location: {alert.location}") + print(f" Severity: {alert.severity.value.upper()}") + print(f" Temperature: {alert.temperature}°F") + print(f" Description: {alert.description}") + print(f" Timestamp: {alert.timestamp}") + + +async def main() -> None: + """Main subscriber routine""" + print("🌤️ Weather Alert Subscriber 1") + print("=" * 50) + print(f"Connecting to: {AMQP_URI}") + + # Start the application + await app.start() + print("✅ Connected to AMQP broker") + print("👂 Listening for weather alerts...") + print(" (Press Ctrl+C to stop)") + print() + + # Wait for shutdown signal + await shutdown_event.wait() + + # Stop the application + print("\n🛑 Stopping subscriber...") + await app.stop() + print("👋 Disconnected from AMQP broker") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/amqp-topic/main-subscriber2.py b/examples/amqp-topic/main-subscriber2.py new file mode 100644 index 0000000..edd51b2 --- /dev/null +++ b/examples/amqp-topic/main-subscriber2.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +""" +Weather Alert Subscriber 2 + +Subscribes to weather alerts from an AMQP topic exchange. +This is the second subscriber - demonstrates multiple independent consumers. + +Example usage: + python main-subscriber2.py +""" + +import asyncio +import signal +from os import environ +from types import FrameType + +from asyncapi_python.contrib.wire.amqp import AmqpWire +from subscriber2 import Application +from subscriber2.messages.json import WeatherAlert + +# AMQP connection URI (can be overridden via environment variable) +AMQP_URI = environ.get("AMQP_URI", "amqp://guest:guest@localhost") + +# Initialize application with AMQP wire +app = Application(AmqpWire(AMQP_URI)) + +# Shutdown event +shutdown_event = asyncio.Event() + + +def signal_handler(signum: int, frame: FrameType | None) -> None: + """Handle shutdown signals""" + print("\n⚠️ Shutdown signal received") + shutdown_event.set() + + +# Register signal handlers +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + + +@app.consumer.receive_weather_alert(parameters={"severity": "critical"}) +async def handle_weather_alert(alert: WeatherAlert) -> None: + """ + Handle incoming critical weather alerts. + + This handler subscribes to weather.*.critical pattern to receive + all critical alerts regardless of location. + """ + # Determine severity emoji + severity_emoji = { + "low": "🟢", + "medium": "🟡", + "high": "🟠", + "critical": "🔴", + }.get(alert.severity.value, "⚪") + + print(f"\n{severity_emoji} Weather Alert Received [SUBSCRIBER 2]") + print(f" Location: {alert.location}") + print(f" Severity: {alert.severity.value.upper()}") + print(f" Temperature: {alert.temperature}°F") + print(f" Description: {alert.description}") + print(f" Timestamp: {alert.timestamp}") + + +async def main() -> None: + """Main subscriber routine""" + print("🌤️ Weather Alert Subscriber 2") + print("=" * 50) + print(f"Connecting to: {AMQP_URI}") + + # Start the application + await app.start() + print("✅ Connected to AMQP broker") + print("👂 Listening for weather alerts...") + print(" (Press Ctrl+C to stop)") + print() + + # Wait for shutdown signal + await shutdown_event.wait() + + # Stop the application + print("\n🛑 Stopping subscriber...") + await app.stop() + print("👋 Disconnected from AMQP broker") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/amqp-topic/spec/common.asyncapi.yaml b/examples/amqp-topic/spec/common.asyncapi.yaml new file mode 100644 index 0000000..28c7d92 --- /dev/null +++ b/examples/amqp-topic/spec/common.asyncapi.yaml @@ -0,0 +1,81 @@ +asyncapi: "3.0.0" + +info: + title: Weather Alert System - Common Definitions + version: 1.0.0 + description: | + Shared channel and message definitions for a weather alert system using AMQP topic exchange. + Demonstrates parameterized channels where routing keys are dynamically built from message payload. + +channels: + weatherAlerts: + address: weather.{location}.{severity} + title: Weather Alert Channel + description: | + Topic-based weather alerts with dynamic routing keys. + Parameters are extracted from message payload at runtime. + parameters: + location: + description: Geographic location code (e.g., NYC, LA, CHI) + location: $message.payload#/location + severity: + description: Alert severity level + location: $message.payload#/severity + bindings: + amqp: + is: routingKey + exchange: + name: weather_alerts + type: topic + durable: true + autoDelete: false + vhost: / + messages: + WeatherAlert: + name: WeatherAlert + title: Weather Alert Message + contentType: application/json + payload: + type: object + required: + - location + - severity + - temperature + - description + - timestamp + properties: + location: + type: string + description: Geographic location code + pattern: '^[A-Z]{2,5}$' + examples: + - NYC + - LA + - CHI + severity: + type: string + description: Alert severity level + enum: + - low + - medium + - high + - critical + examples: + - high + temperature: + type: number + description: Current temperature in Fahrenheit + minimum: -50 + maximum: 150 + examples: + - 95 + description: + type: string + description: Human-readable alert description + maxLength: 500 + examples: + - Heat wave warning in effect + timestamp: + type: string + format: date-time + description: Alert timestamp in ISO 8601 format diff --git a/examples/amqp-topic/spec/publisher.asyncapi.yaml b/examples/amqp-topic/spec/publisher.asyncapi.yaml new file mode 100644 index 0000000..72a98f8 --- /dev/null +++ b/examples/amqp-topic/spec/publisher.asyncapi.yaml @@ -0,0 +1,25 @@ +asyncapi: "3.0.0" + +info: + title: Weather Alert Publisher + version: 1.0.0 + description: | + Publisher service that sends weather alerts to a topic exchange. + Uses parameterized channels to dynamically route messages based on location and severity. + +operations: + publishWeatherAlert: + action: send + title: Publish Weather Alert + description: | + Sends a weather alert to the topic exchange. + The routing key is dynamically built from the message payload: + - location field becomes the {location} parameter + - severity field becomes the {severity} parameter + + Example: payload with location="NYC" and severity="high" + creates routing key "weather.NYC.high" + channel: + $ref: "./common.asyncapi.yaml#/channels/weatherAlerts" + messages: + - $ref: "./common.asyncapi.yaml#/channels/weatherAlerts/messages/WeatherAlert" diff --git a/examples/amqp-topic/spec/subscriber1.asyncapi.yaml b/examples/amqp-topic/spec/subscriber1.asyncapi.yaml new file mode 100644 index 0000000..0adc2c4 --- /dev/null +++ b/examples/amqp-topic/spec/subscriber1.asyncapi.yaml @@ -0,0 +1,24 @@ +asyncapi: "3.0.0" + +info: + title: Weather Alert Subscriber 1 + version: 1.0.0 + description: | + First subscriber service that receives all weather alerts from the topic exchange. + Demonstrates topic subscription with wildcard routing keys. + +operations: + receiveWeatherAlert: + action: receive + title: Receive Weather Alert + description: | + Receives weather alerts from the topic exchange. + This subscriber binds to all routing keys matching "weather.*.#" + to receive all alerts regardless of location or severity. + channel: + $ref: "./common.asyncapi.yaml#/channels/weatherAlerts" + messages: + - $ref: "./common.asyncapi.yaml#/channels/weatherAlerts/messages/WeatherAlert" + bindings: + amqp: + ack: true diff --git a/examples/amqp-topic/spec/subscriber2.asyncapi.yaml b/examples/amqp-topic/spec/subscriber2.asyncapi.yaml new file mode 100644 index 0000000..9079af5 --- /dev/null +++ b/examples/amqp-topic/spec/subscriber2.asyncapi.yaml @@ -0,0 +1,24 @@ +asyncapi: "3.0.0" + +info: + title: Weather Alert Subscriber 2 + version: 1.0.0 + description: | + Second subscriber service that receives all weather alerts from the topic exchange. + Demonstrates multiple independent consumers on the same channel. + +operations: + receiveWeatherAlert: + action: receive + title: Receive Weather Alert + description: | + Receives weather alerts from the topic exchange. + This is a second independent subscriber that also receives all alerts, + demonstrating that multiple consumers can subscribe to the same topic. + channel: + $ref: "./common.asyncapi.yaml#/channels/weatherAlerts" + messages: + - $ref: "./common.asyncapi.yaml#/channels/weatherAlerts/messages/WeatherAlert" + bindings: + amqp: + ack: true diff --git a/pyproject.toml b/pyproject.toml index 45aad4c..21bf464 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "asyncapi-python" -version = "0.3.0rc5" +version = "0.3.0rc6" license = { text = "Apache-2.0" } description = "Easily generate type-safe and async Python applications from AsyncAPI 3 specifications." authors = [{ name = "Yaroslav Petrov", email = "yaroslav.v.petrov@gmail.com" }] diff --git a/src/asyncapi_python/contrib/codec/json.py b/src/asyncapi_python/contrib/codec/json.py index 9872660..ce5aded 100644 --- a/src/asyncapi_python/contrib/codec/json.py +++ b/src/asyncapi_python/contrib/codec/json.py @@ -1,4 +1,5 @@ import json +from enum import Enum from types import ModuleType from typing import ClassVar, Type @@ -27,6 +28,51 @@ def decode(self, payload: bytes) -> BaseModel: except (json.JSONDecodeError, ValidationError, UnicodeDecodeError) as e: raise ValueError(f"Failed to decode JSON payload: {e}") + def extract_field(self, payload: BaseModel, location: str) -> str: + """Extract field from Pydantic model using location path. + + Examples: + "$message.payload#/userId" → payload.userId → "123" + "$message.payload#/user/id" → payload.user.id → "456" + "$message.payload#/items" → payload.items → "[1, 2, 3]" + + Args: + payload: Pydantic BaseModel instance + location: Location expression like "$message.payload#/userId" + + Returns: + str: Extracted value converted to string + + Raises: + ValueError: If location format is invalid or path doesn't exist in payload + """ + # Parse location: "$message.payload#/userId" → "/userId" + if not location.startswith("$message.payload#/"): + raise ValueError(f"Invalid location format: {location}") + + path = location.replace("$message.payload#/", "") + + # Navigate path: "user/id" → ["user", "id"] + parts = [p for p in path.split("/") if p] + + try: + value = payload + for part in parts: + value = getattr(value, part) + + # Convert to string + if isinstance(value, (str, int, float, bool)): + return str(value) + elif isinstance(value, Enum): + # Handle Enum types - extract the value attribute + return str(value.value) + else: + # Complex types: JSON serialize + return json.dumps(value) + + except AttributeError as e: + raise ValueError(f"Path '{path}' not found in payload: {e}") + class JsonCodecFactory(CodecFactory[BaseModel, bytes]): """Factory for creating JSON codecs for Pydantic models diff --git a/src/asyncapi_python/contrib/wire/amqp/resolver.py b/src/asyncapi_python/contrib/wire/amqp/resolver.py index a4e5c0c..4054c11 100644 --- a/src/asyncapi_python/contrib/wire/amqp/resolver.py +++ b/src/asyncapi_python/contrib/wire/amqp/resolver.py @@ -1,13 +1,79 @@ """Binding resolution with comprehensive pattern matching""" +import re from typing import Any from asyncapi_python.kernel.document.bindings import AmqpChannelBinding from asyncapi_python.kernel.document.channel import Channel from asyncapi_python.kernel.wire import EndpointParams +from asyncapi_python.kernel.wire.utils import ( + substitute_parameters, + validate_parameters_strict, +) from .config import AmqpBindingType, AmqpConfig -from .utils import substitute_parameters, validate_parameters_strict + + +def _validate_no_wildcards_in_queue(param_values: dict[str, str]) -> None: + """Validate that parameter values don't contain AMQP wildcards when using queue bindings. + + AMQP queue names are literal - they don't support pattern matching. + Only topic exchange routing keys support wildcards (* and #). + + Args: + param_values: Dictionary of parameter values to check + + Raises: + ValueError: If any parameter value contains wildcard characters + """ + wildcards_found: list[str] = [] + for param_name, param_value in param_values.items(): + if "*" in param_value or "#" in param_value: + wildcards_found.append(f"{param_name}={param_value}") + + if wildcards_found: + raise ValueError( + f"AMQP queue bindings do not support wildcard patterns ('*' or '#'). " + f"Found wildcards in parameters: {', '.join(wildcards_found)}. " + f"Use 'is: routingKey' with a topic exchange for pattern matching, " + f"or provide concrete parameter values for queue bindings." + ) + + +def _substitute_routing_key_with_wildcards( + template: str, param_values: dict[str, str] +) -> str: + """ + Substitute parameters in routing key template, using wildcards for missing parameters. + + For topic exchange bindings, missing parameters are replaced with '*' (single-word wildcard). + If no parameters are provided and template has placeholders, all are replaced with '*'. + Parameter values can also explicitly contain wildcards ('*' or '#'). + + Args: + template: Template string with {param} placeholders (e.g., "weather.{location}.{severity}") + param_values: Dictionary of parameter values (can be empty, partial, or contain wildcards) + + Returns: + Resolved routing key with wildcards for missing parameters + + Examples: + - template="weather.{location}.{severity}", params={} → "weather.*.*" + - template="weather.{location}.{severity}", params={"location": "NYC"} → "weather.NYC.*" + - template="weather.{location}.{severity}", params={"severity": "#"} → "weather.*.#" + """ + # Find all {param} placeholders + placeholders = re.findall(r"\{(\w+)\}", template) + + # Build substitution dict - use provided value or '*' wildcard for missing params + substitutions = {p: param_values.get(p, "*") for p in placeholders} + + # Perform substitution + result = template + for key, value in substitutions.items(): + result = result.replace(f"{{{key}}}", value) + + return result def resolve_amqp_config( @@ -27,10 +93,7 @@ def resolve_amqp_config( param_values = params["parameters"] or {} is_reply = params["is_reply"] - # Strict parameter validation first - validate_parameters_strict(channel, param_values) - - # Extract AMQP binding if present + # Extract AMQP binding if present (validation will be done later based on binding type) amqp_binding = None if channel.bindings and hasattr(channel.bindings, "amqp") and channel.bindings.amqp: amqp_binding = channel.bindings.amqp @@ -123,6 +186,9 @@ def resolve_amqp_config( # Channel address pattern (with parameter substitution) case (False, None, address, _) if address: + # Strict validation for implicit queue binding + validate_parameters_strict(channel, param_values) + _validate_no_wildcards_in_queue(param_values) resolved_address = substitute_parameters(address, param_values) return AmqpConfig( queue_name=resolved_address, @@ -134,6 +200,9 @@ def resolve_amqp_config( # Operation name pattern (fallback) case (False, None, None, op_name) if op_name: + # Strict validation for implicit queue binding + validate_parameters_strict(channel, param_values) + _validate_no_wildcards_in_queue(param_values) return AmqpConfig( queue_name=op_name, exchange_name="", # Default exchange @@ -156,7 +225,18 @@ def resolve_queue_binding( channel: Channel, operation_name: str, ) -> AmqpConfig: - """Resolve AMQP queue binding configuration""" + """Resolve AMQP queue binding configuration + + Queue bindings require: + - All channel parameters must be provided (strict validation) + - No wildcards allowed in parameter values + """ + + # Strict validation: all parameters required, exact match + validate_parameters_strict(channel, param_values) + + # Validate no wildcards in queue binding parameters + _validate_no_wildcards_in_queue(param_values) # Determine queue name with precedence match (getattr(binding, "queue", None), channel.address, operation_name): @@ -197,9 +277,20 @@ def resolve_routing_key_binding( channel: Channel, operation_name: str, ) -> AmqpConfig: - """Resolve AMQP routing key binding configuration for pub/sub patterns""" + """Resolve AMQP routing key binding configuration for pub/sub patterns + + For routing key bindings: + - All channel-defined parameters must be provided (strict validation) + - Parameter values can explicitly contain wildcards ('*' or '#') + - Wildcards are allowed for topic exchange pattern matching + """ + + # Strict validation: all parameters required, exact match + validate_parameters_strict(channel, param_values) # Determine exchange name and type + # For exchange name, we need concrete values (no wildcards) + # If param_values has placeholders, use them; otherwise use literal exchange name exchange_config = getattr(binding, "exchange", None) match ( exchange_config and getattr(exchange_config, "name", None), @@ -207,9 +298,14 @@ def resolve_routing_key_binding( operation_name, ): case (exchange_name, _, _) if exchange_name: - resolved_exchange = substitute_parameters(exchange_name, param_values) + # Exchange name should be literal (no parameter substitution for exchange names) + resolved_exchange = exchange_name case (None, address, _) if address: - resolved_exchange = substitute_parameters(address, param_values) + # If address is used for exchange, check if it has parameters + # If it does, use wildcards; if not, use as-is + resolved_exchange = _substitute_routing_key_with_wildcards( + address, param_values + ) case (None, None, op_name) if op_name: resolved_exchange = op_name case _: @@ -220,12 +316,18 @@ def resolve_routing_key_binding( if exchange_config and hasattr(exchange_config, "type"): exchange_type = exchange_config.type - # Determine routing key + # Determine routing key - this is where wildcards are allowed match (getattr(binding, "routingKey", None), channel.address, operation_name): case (routing_key, _, _) if routing_key: - resolved_routing_key = substitute_parameters(routing_key, param_values) + # Use wildcard substitution for routing keys + resolved_routing_key = _substitute_routing_key_with_wildcards( + routing_key, param_values + ) case (None, address, _) if address: - resolved_routing_key = substitute_parameters(address, param_values) + # Use wildcard substitution for routing keys from address + resolved_routing_key = _substitute_routing_key_with_wildcards( + address, param_values + ) case (None, None, op_name) if op_name: resolved_routing_key = op_name case _: diff --git a/src/asyncapi_python/contrib/wire/amqp/utils.py b/src/asyncapi_python/contrib/wire/amqp/utils.py index b28bfbd..1d10bdf 100644 --- a/src/asyncapi_python/contrib/wire/amqp/utils.py +++ b/src/asyncapi_python/contrib/wire/amqp/utils.py @@ -1,60 +1,19 @@ -"""Parameter validation and substitution utilities""" - -# TODO: This thing should be general wire utils, not tied to specific wire +"""AMQP-specific parameter validation utilities""" import re from asyncapi_python.kernel.document.channel import Channel - - -def validate_parameters_strict(channel: Channel, provided: dict[str, str]) -> None: - """ - Strict parameter validation - all defined parameters must be provided. - Raises ValueError with detailed message if any parameters are missing. - """ - if not channel.parameters: - return # No parameters defined, nothing to validate - - required = set(channel.parameters.keys()) - provided_keys = set(provided.keys()) - - missing = required - provided_keys - if missing: - raise ValueError( - f"Missing required parameters for channel '{channel.address}': {missing}. " - f"Required: {sorted(required)}, Provided: {sorted(provided_keys)}" - ) - - extra = provided_keys - required - if extra: - raise ValueError( - f"Unexpected parameters for channel '{channel.address}': {extra}. " - f"Expected: {sorted(required)}, Provided: {sorted(provided_keys)}" - ) - - -def substitute_parameters(template: str, parameters: dict[str, str]) -> str: - """ - Substitute {param} placeholders with actual values. - All placeholders must have corresponding parameter values. - """ - # Find all {param} placeholders - placeholders = re.findall(r"\{(\w+)\}", template) - - # Check for undefined placeholders - undefined = [p for p in placeholders if p not in parameters] - if undefined: - raise ValueError( - f"Template '{template}' references undefined parameters: {undefined}. " - f"Available parameters: {sorted(parameters.keys())}" - ) - - # Perform substitution - result = template - for key, value in parameters.items(): - result = result.replace(f"{{{key}}}", value) - - return result +from asyncapi_python.kernel.wire.utils import ( + substitute_parameters, + validate_parameters_strict, +) + +# Re-export for backward compatibility +__all__ = [ + "validate_parameters_strict", + "substitute_parameters", + "validate_channel_template", +] def validate_channel_template( diff --git a/src/asyncapi_python/kernel/codec.py b/src/asyncapi_python/kernel/codec.py index ed66fab..842259e 100644 --- a/src/asyncapi_python/kernel/codec.py +++ b/src/asyncapi_python/kernel/codec.py @@ -12,6 +12,21 @@ def encode(self, payload: T_DecodedPayload) -> T_EncodedPayload: ... def decode(self, payload: T_EncodedPayload) -> T_DecodedPayload: ... + def extract_field(self, payload: T_DecodedPayload, location: str) -> str: + """Extract field value from decoded payload using location expression. + + Args: + payload: Decoded payload (Pydantic model, Protobuf object, etc.) + location: Location expression like "$message.payload#/userId" + + Returns: + str: Extracted value converted to string + + Raises: + ValueError: If location path doesn't exist in payload + """ + ... + class CodecFactory(ABC, Generic[T_DecodedPayload, T_EncodedPayload]): """A codec factory diff --git a/src/asyncapi_python/kernel/endpoint/abc/__init__.py b/src/asyncapi_python/kernel/endpoint/abc/__init__.py new file mode 100644 index 0000000..d3ef77a --- /dev/null +++ b/src/asyncapi_python/kernel/endpoint/abc/__init__.py @@ -0,0 +1,16 @@ +"""Abstract base classes and interfaces for endpoints. + +This module provides the foundational abstractions for all endpoint implementations. +""" + +from .base import AbstractEndpoint +from .interfaces import Receive, Send +from .params import EndpointParams, HandlerParams + +__all__ = [ + "AbstractEndpoint", + "EndpointParams", + "HandlerParams", + "Receive", + "Send", +] diff --git a/src/asyncapi_python/kernel/endpoint/abc.py b/src/asyncapi_python/kernel/endpoint/abc/base.py similarity index 57% rename from src/asyncapi_python/kernel/endpoint/abc.py rename to src/asyncapi_python/kernel/endpoint/abc/base.py index 60d5581..d7669c2 100644 --- a/src/asyncapi_python/kernel/endpoint/abc.py +++ b/src/asyncapi_python/kernel/endpoint/abc/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Callable, Generic, TypedDict, Union, overload +from typing import Any, Callable, TypedDict from typing_extensions import NotRequired, Required, Unpack @@ -7,23 +7,9 @@ from asyncapi_python.kernel.document import Operation from asyncapi_python.kernel.wire import AbstractWireFactory -from ..typing import BatchConfig, Handler, T_Input, T_Output +from .params import EndpointParams - -class EndpointParams(TypedDict, total=False): - """Optional parameters for endpoint configuration""" - - service_name: str # Service name for generating app_id - default_rpc_timeout: ( - float | None - ) # Default timeout in seconds for RPC client requests (default: 180.0), or None to disable - disable_handler_validation: bool # Opt-out of handler enforcement for testing - - -class HandlerParams(TypedDict): - """Parameters for message handlers""" - - pass # Currently empty, but extensible for future parameters like queue, routing_key, etc. +__all__ = ["AbstractEndpoint"] class AbstractEndpoint(ABC): @@ -108,56 +94,77 @@ def _try_codecs( f"Failed to {operation} payload with any available codec. Last error: {last_error}" ) + def _extract_parameters(self, payload: Any) -> dict[str, str]: + """Extract channel parameters from decoded payload. + + Uses the channel parameter definitions to extract values from the payload + using the codec's extract_field method. Parameters without a location are skipped. + + Args: + payload: The decoded message payload + + Returns: + Dictionary mapping parameter names to extracted string values + + Raises: + ValueError: If parameter extraction fails for any parameter + """ + parameters: dict[str, str] = {} + for param_name, param_def in self._operation.channel.parameters.items(): + if param_def.location: + try: + # Use first codec (all should extract consistently) + codec = self._codecs[0] + value = codec.extract_field(payload, param_def.location) + parameters[param_name] = value + except ValueError as e: + raise ValueError(f"Failed to extract parameter '{param_name}': {e}") + return parameters + + def _build_address(self, parameters: dict[str, str]) -> str: + """Build address from channel template and parameters. + + Replaces {param_name} placeholders in the channel address with the + corresponding parameter values. + + Args: + parameters: Dictionary of parameter names to values + + Returns: + The fully resolved address string + + Raises: + ValueError: If channel address is None + """ + address = self._operation.channel.address + if address is None: + raise ValueError( + "Channel address is None, cannot build parameterized address" + ) + for param_name, param_value in parameters.items(): + address = address.replace(f"{{{param_name}}}", param_value) + return address + + def _build_address_with_parameters(self, payload: Any) -> str | None: + """Extract parameters from payload and build address if needed. + + Convenience method that extracts parameters and builds the address in one call. + Returns None if no parameters are defined or extracted. + + Args: + payload: The decoded message payload + + Returns: + The resolved address string, or None if no parameters to extract + + Raises: + ValueError: If parameter extraction or address building fails + """ + parameters = self._extract_parameters(payload) + return self._build_address(parameters) if parameters else None + @abstractmethod async def start(self, **params: Unpack[StartParams]) -> None: ... @abstractmethod async def stop(self) -> None: ... - - -class Send(ABC, Generic[T_Input, T_Output]): - """An interface that sending endpoint implements""" - - class RouterInputs(TypedDict): - """Base inputs for send endpoints. Router subclasses can extend this with specific parameters.""" - - pass # Empty for now, extensible for future fields - - @abstractmethod - async def __call__( - self, payload: T_Input, /, **kwargs: Unpack[RouterInputs] - ) -> T_Output: ... - - -class Receive(ABC, Generic[T_Input, T_Output]): - - @overload - def __call__( - self, fn: Handler[T_Input, T_Output] - ) -> Handler[T_Input, T_Output]: ... - - @overload - def __call__( - self, - fn: None = None, - *, - batch: BatchConfig, - **kwargs: Unpack[HandlerParams], - ) -> Callable[[Handler[T_Input, T_Output]], Handler[T_Input, T_Output]]: ... - - @overload - def __call__( - self, fn: None = None, **kwargs: Unpack[HandlerParams] - ) -> Callable[[Handler[T_Input, T_Output]], Handler[T_Input, T_Output]]: ... - - @abstractmethod - def __call__( - self, - fn: Handler[T_Input, T_Output] | None = None, - *, - batch: BatchConfig | None = None, - **kwargs: Unpack[HandlerParams], - ) -> Union[ - Handler[T_Input, T_Output], - Callable[[Handler[T_Input, T_Output]], Handler[T_Input, T_Output]], - ]: ... diff --git a/src/asyncapi_python/kernel/endpoint/abc/interfaces.py b/src/asyncapi_python/kernel/endpoint/abc/interfaces.py new file mode 100644 index 0000000..62057a3 --- /dev/null +++ b/src/asyncapi_python/kernel/endpoint/abc/interfaces.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import Callable, Generic, TypedDict, Union, overload + +from typing_extensions import Unpack + +from ...typing import BatchConfig, Handler, T_Input, T_Output +from .params import HandlerParams + +__all__ = ["Send", "Receive"] + + +class Send(ABC, Generic[T_Input, T_Output]): + """An interface that sending endpoint implements""" + + class RouterInputs(TypedDict): + """Base inputs for send endpoints. Router subclasses can extend this with specific parameters.""" + + pass # Empty for now, extensible for future fields + + @abstractmethod + async def __call__( + self, payload: T_Input, /, **kwargs: Unpack[RouterInputs] + ) -> T_Output: ... + + +class Receive(ABC, Generic[T_Input, T_Output]): + + @overload + def __call__( + self, fn: Handler[T_Input, T_Output] + ) -> Handler[T_Input, T_Output]: ... + + @overload + def __call__( + self, + fn: None = None, + *, + batch: BatchConfig, + **kwargs: Unpack[HandlerParams], + ) -> Callable[[Handler[T_Input, T_Output]], Handler[T_Input, T_Output]]: ... + + @overload + def __call__( + self, fn: None = None, **kwargs: Unpack[HandlerParams] + ) -> Callable[[Handler[T_Input, T_Output]], Handler[T_Input, T_Output]]: ... + + @abstractmethod + def __call__( + self, + fn: Handler[T_Input, T_Output] | None = None, + *, + batch: BatchConfig | None = None, + **kwargs: Unpack[HandlerParams], + ) -> Union[ + Handler[T_Input, T_Output], + Callable[[Handler[T_Input, T_Output]], Handler[T_Input, T_Output]], + ]: ... diff --git a/src/asyncapi_python/kernel/endpoint/abc/params.py b/src/asyncapi_python/kernel/endpoint/abc/params.py new file mode 100644 index 0000000..a9da127 --- /dev/null +++ b/src/asyncapi_python/kernel/endpoint/abc/params.py @@ -0,0 +1,22 @@ +from typing import TypedDict +from typing_extensions import NotRequired + +__all__ = ["EndpointParams", "HandlerParams"] + + +class EndpointParams(TypedDict, total=False): + """Optional parameters for endpoint configuration""" + + service_name: str # Service name for generating app_id + default_rpc_timeout: ( + float | None + ) # Default timeout in seconds for RPC client requests (default: 180.0), or None to disable + disable_handler_validation: bool # Opt-out of handler enforcement for testing + + +class HandlerParams(TypedDict): + """Parameters for message handlers""" + + parameters: NotRequired[ + dict[str, str] + ] # Channel parameter values for subscription (e.g., {"location": "*", "severity": "high"}) diff --git a/src/asyncapi_python/kernel/endpoint/publisher.py b/src/asyncapi_python/kernel/endpoint/publisher.py index 1ebb265..754a272 100644 --- a/src/asyncapi_python/kernel/endpoint/publisher.py +++ b/src/asyncapi_python/kernel/endpoint/publisher.py @@ -60,6 +60,9 @@ async def __call__( if not self._producer: raise UninitializedError() + # Extract parameters and build address (if parameters exist) + address_override = self._build_address_with_parameters(payload) + # Encode payload using main message codecs encoded_payload = self._encode_message(payload) @@ -69,4 +72,6 @@ async def __call__( ) # Send via producer - await self._producer.send_batch([wire_message]) + await self._producer.send_batch( + [wire_message], address_override=address_override + ) diff --git a/src/asyncapi_python/kernel/endpoint/rpc_client.py b/src/asyncapi_python/kernel/endpoint/rpc_client.py index a87fb5b..fb771cd 100644 --- a/src/asyncapi_python/kernel/endpoint/rpc_client.py +++ b/src/asyncapi_python/kernel/endpoint/rpc_client.py @@ -122,6 +122,9 @@ async def __call__( # type: ignore[override] ) try: + # Extract parameters and build address (if parameters exist) + address_override = self._build_address_with_parameters(payload) + # Encode request payload encoded_payload: bytes = self._encode_message(payload) @@ -134,7 +137,9 @@ async def __call__( # type: ignore[override] ) # Send request - await self._producer.send_batch([wire_message]) + await self._producer.send_batch( + [wire_message], address_override=address_override + ) # Wait for response with timeout (handled by global background task) try: diff --git a/src/asyncapi_python/kernel/endpoint/rpc_server.py b/src/asyncapi_python/kernel/endpoint/rpc_server.py index de396ea..d304113 100644 --- a/src/asyncapi_python/kernel/endpoint/rpc_server.py +++ b/src/asyncapi_python/kernel/endpoint/rpc_server.py @@ -36,6 +36,9 @@ def __init__(self, **kwargs: Unpack[AbstractEndpoint.Inputs]): self._handler_location: str | None = None self._batch_config: BatchConfig | None = None self._consume_task: asyncio.Task[None] | None = None + self._subscription_parameters: dict[str, str] = ( + {} + ) # Parameters for subscription (wildcards or concrete values) async def start(self, **params: Unpack[AbstractEndpoint.StartParams]) -> None: """Initialize the RPC server endpoint""" @@ -63,7 +66,7 @@ async def start(self, **params: Unpack[AbstractEndpoint.StartParams]) -> None: # Create consumer for receiving requests self._consumer = await self._wire.create_consumer( channel=self._operation.channel, - parameters={}, + parameters=self._subscription_parameters, op_bindings=self._operation.bindings, is_reply=False, ) @@ -218,6 +221,10 @@ def _register_handler( f"Each RPC server endpoint must have exactly one handler." ) + # Extract subscription parameters if provided + if "parameters" in params: + self._subscription_parameters = params["parameters"] + # Determine if this is a batch handler by checking if batch config exists if batch_config is not None: self._batch_handler = handler # type: ignore diff --git a/src/asyncapi_python/kernel/endpoint/subscriber.py b/src/asyncapi_python/kernel/endpoint/subscriber.py index d37ecad..b153feb 100644 --- a/src/asyncapi_python/kernel/endpoint/subscriber.py +++ b/src/asyncapi_python/kernel/endpoint/subscriber.py @@ -23,6 +23,9 @@ def __init__(self, **kwargs: Unpack[AbstractEndpoint.Inputs]): self._handler_location: str | None = None self._batch_config: BatchConfig | None = None self._consume_task: asyncio.Task[None] | None = None + self._subscription_parameters: dict[str, str] = ( + {} + ) # Parameters for subscription (wildcards or concrete values) async def start(self, **params: Unpack[AbstractEndpoint.StartParams]) -> None: """Initialize the subscriber endpoint""" @@ -46,7 +49,7 @@ async def start(self, **params: Unpack[AbstractEndpoint.StartParams]) -> None: # Create consumer from wire factory self._consumer = await self._wire.create_consumer( channel=self._operation.channel, - parameters={}, + parameters=self._subscription_parameters, op_bindings=self._operation.bindings, is_reply=False, ) @@ -166,6 +169,10 @@ def _register_handler( f"Each subscriber endpoint must have exactly one handler." ) + # Extract subscription parameters if provided + if "parameters" in params: + self._subscription_parameters = params["parameters"] + # Determine if this is a batch handler by checking if batch config exists if batch_config is not None: self._batch_handler = handler # type: ignore diff --git a/src/asyncapi_python/kernel/wire/utils.py b/src/asyncapi_python/kernel/wire/utils.py new file mode 100644 index 0000000..99e597f --- /dev/null +++ b/src/asyncapi_python/kernel/wire/utils.py @@ -0,0 +1,61 @@ +"""Universal parameter validation utilities for all wire implementations.""" + +import re + +from asyncapi_python.kernel.document.channel import Channel + + +def validate_parameters_strict(channel: Channel, provided: dict[str, str]) -> None: + """ + Strict parameter validation - exact match required. + Raises ValueError if any parameters are missing or unexpected. + """ + required = set(channel.parameters.keys() if channel.parameters else []) + provided_keys = set(provided.keys()) + + # Check for missing parameters (only if channel has params defined) + if channel.parameters: + missing = required - provided_keys + if missing: + raise ValueError( + f"Missing required parameters for channel '{channel.address}': {missing}. " + f"Required: {sorted(required)}, Provided: {sorted(provided_keys)}" + ) + + # Check for extra parameters (ALWAYS - even if channel has no params) + extra = provided_keys - required + if extra: + if channel.parameters: + raise ValueError( + f"Unexpected parameters for channel '{channel.address}': {extra}. " + f"Expected: {sorted(required)}, Provided: {sorted(provided_keys)}" + ) + else: + raise ValueError( + f"Unexpected parameters for channel '{channel.address}': {extra}. " + f"Channel defines no parameters, but received: {sorted(provided_keys)}" + ) + + +def substitute_parameters(template: str, parameters: dict[str, str]) -> str: + """ + Substitute {param} placeholders with actual values. + All placeholders must have corresponding parameter values. + """ + # Find all {param} placeholders + placeholders = re.findall(r"\{(\w+)\}", template) + + # Check for undefined placeholders + undefined = [p for p in placeholders if p not in parameters] + if undefined: + raise ValueError( + f"Template '{template}' references undefined parameters: {undefined}. " + f"Available parameters: {sorted(parameters.keys())}" + ) + + # Perform substitution + result = template + for key, value in parameters.items(): + result = result.replace(f"{{{key}}}", value) + + return result diff --git a/src/asyncapi_python_codegen/generators/main.py b/src/asyncapi_python_codegen/generators/main.py index 3e1c580..4e6e753 100644 --- a/src/asyncapi_python_codegen/generators/main.py +++ b/src/asyncapi_python_codegen/generators/main.py @@ -55,11 +55,9 @@ def generate(self, spec_path: Path, output_dir: Path, force: bool = False) -> No ) # Generate parameter TypedDicts for parameterized channels - import yaml - - with spec_path.open() as f: - spec = yaml.safe_load(f) - parameter_models_code = self.parameter_generator.generate_parameter_models(spec) + parameter_models_code = self.parameter_generator.generate_parameter_models( + list(operations.values()) + ) # Legacy compatibility - extract messages for router generation messages = self.message_generator.extract_messages(operations) diff --git a/src/asyncapi_python_codegen/generators/parameters.py b/src/asyncapi_python_codegen/generators/parameters.py index 0379584..f482a00 100644 --- a/src/asyncapi_python_codegen/generators/parameters.py +++ b/src/asyncapi_python_codegen/generators/parameters.py @@ -11,40 +11,65 @@ class ParameterGenerator: """Generates TypedDict classes for channel parameters.""" - def generate_parameter_models(self, spec: dict[str, Any]) -> str: - """Generate TypedDict models for all channel parameters.""" - channels = spec.get("channels", {}) + def generate_parameter_models(self, operations: list[Any]) -> str: + """Generate TypedDict models from operations' resolved channels.""" + from asyncapi_python.kernel.document import Operation + parameter_schemas: dict[str, Any] = {} - # Collect all parameter definitions from channels - for channel_name, channel_def in channels.items(): - if "{" in channel_name and "parameters" in channel_def: - # Generate TypedDict name from channel pattern - dict_name = self._channel_to_dict_name(channel_name) + # Collect unique parameterized channels from all operations + seen_addresses: set[str] = set() + + for op in operations: + if not isinstance(op, Operation): + continue + + channel = op.channel + if not channel or not channel.address: + continue + + # Check if channel has parameters in address + if "{" not in channel.address or "}" not in channel.address: + continue + + # Skip if we've already processed this channel address + if channel.address in seen_addresses: + continue + seen_addresses.add(channel.address) + + # Skip if channel has no parameters defined + if not channel.parameters: + continue - # Build schema for this channel's parameters - properties: dict[str, Any] = {} - required: list[str] = [] + # Generate TypedDict name from channel address pattern + dict_name = self._channel_to_dict_name(channel.address) - for param_name, param_def in channel_def["parameters"].items(): - # Skip parameters that have a 'location' field - if isinstance(param_def, dict) and "location" in param_def: - continue + # Build schema for this channel's parameters + properties: dict[str, Any] = {} + required: list[str] = [] + for param_name, param_obj in channel.parameters.items(): + # For parameters with 'location' field (used by publishers for extraction), + # generate as 'str' type for subscriber wildcard support + if hasattr(param_obj, "location") and param_obj.location: + properties[param_name] = {"type": "string"} + else: # Convert parameter definition to JSON Schema property - properties[param_name] = self._param_to_schema(param_def) # type: ignore[arg-type] - # All channel parameters are required - required.append(param_name) - - # Only create TypedDict if there are properties after filtering - if properties: - parameter_schemas[dict_name] = { - "type": "object", - "properties": properties, - "required": required, - "additionalProperties": False, - "title": dict_name, - } + # For AddressParameter objects without location, use default schema + properties[param_name] = {"type": "string"} + + # All channel parameters are required + required.append(param_name) + + # Only create TypedDict if there are properties + if properties: + parameter_schemas[dict_name] = { + "type": "object", + "properties": properties, + "required": required, + "additionalProperties": False, + "title": dict_name, + } if not parameter_schemas: return self._generate_empty_parameters() diff --git a/src/asyncapi_python_codegen/parser/document_loader.py b/src/asyncapi_python_codegen/parser/document_loader.py index 7192237..c391202 100644 --- a/src/asyncapi_python_codegen/parser/document_loader.py +++ b/src/asyncapi_python_codegen/parser/document_loader.py @@ -9,11 +9,15 @@ from .references import load_yaml_file -def extract_all_operations(yaml_path: Path) -> dict[str, Operation]: +def extract_all_operations( + yaml_path: Path, validate: bool = True, fail_on_error: bool = True +) -> dict[str, Operation]: """Extract all operations from AsyncAPI document. Args: yaml_path: Path to AsyncAPI YAML file + validate: Whether to run validation rules (default: True) + fail_on_error: Whether to raise ValidationError on errors (default: True) Returns: Dictionary mapping operation IDs to Operation dataclasses @@ -21,6 +25,7 @@ def extract_all_operations(yaml_path: Path) -> dict[str, Operation]: Raises: RuntimeError: If file cannot be loaded or parsed ValueError: If document structure is invalid + ValidationError: If validation fails and fail_on_error is True """ # Load the main document with parsing_context(yaml_path): @@ -66,6 +71,30 @@ def extract_all_operations(yaml_path: Path) -> dict[str, Operation]: f"Failed to extract operation '{operation_id}': {e}" ) from e + # Run validation if enabled + if validate: + from ..validation import Severity, ValidationError, validate_spec + + try: + issues = validate_spec( + spec=document, + operations=operations, + spec_path=yaml_path, + fail_on_error=fail_on_error, + ) + + # Print warnings and info even if not failing + for issue in issues: + if issue.severity in (Severity.WARNING, Severity.INFO): + print(f" {issue}") + + except ValidationError as e: + # Print all issues before re-raising + print("\n❌ Document validation failed:\n") + for issue in e.issues: + print(f" {issue}\n") + raise + return operations diff --git a/src/asyncapi_python_codegen/parser/extractors.py b/src/asyncapi_python_codegen/parser/extractors.py index b81230f..4bd45ff 100644 --- a/src/asyncapi_python_codegen/parser/extractors.py +++ b/src/asyncapi_python_codegen/parser/extractors.py @@ -59,7 +59,7 @@ def extract_address_parameter(data: YamlDocument) -> AddressParameter: """Extract AddressParameter from YAML data.""" return AddressParameter( description=data.get("description"), - location=data.get("location", ""), + location=data.get("location", ""), # Validation will catch if missing key="", # TODO: Pass actual parameter key from extraction context ) diff --git a/src/asyncapi_python_codegen/validation/__init__.py b/src/asyncapi_python_codegen/validation/__init__.py new file mode 100644 index 0000000..ca19d3a --- /dev/null +++ b/src/asyncapi_python_codegen/validation/__init__.py @@ -0,0 +1,75 @@ +"""AsyncAPI document validation system. + +This module provides a pluggable validation system for AsyncAPI 3.0 documents. +All validation is performed through rules that can be extended or disabled. + +Example: + from asyncapi_python_codegen.validation import validate_spec, ValidationError + + try: + issues = validate_spec(spec, operations, spec_path) + except ValidationError as e: + for error in e.errors: + print(f"{error.path}: {error.message}") +""" + +from pathlib import Path +from typing import Any + +from asyncapi_python.kernel.document import Operation + +from .base import get_registry, rule +from .context import ValidationContext +from .errors import Severity, ValidationError, ValidationIssue + +__all__ = [ + "validate_spec", + "rule", + "ValidationError", + "ValidationIssue", + "Severity", + "ValidationContext", +] + + +def validate_spec( + spec: dict[str, Any], + operations: dict[str, Operation] | None, + spec_path: Path, + categories: list[str] | None = None, + fail_on_error: bool = True, +) -> list[ValidationIssue]: + """ + Validate an AsyncAPI spec using registered rules. + + Args: + spec: Raw AsyncAPI YAML as dict + operations: Parsed operations (None if parsing failed) + spec_path: Path to the YAML file + categories: List of rule categories to run (None = all) + fail_on_error: If True, raise ValidationError on errors + + Returns: + List of all validation issues (errors, warnings, info) + + Raises: + ValidationError: If validation fails and fail_on_error is True + """ + # Import core rules to ensure they're registered + from . import core # noqa: F401 # pyright: ignore[reportUnusedImport] + + # Import protocol-specific rules (AMQP, etc.) + from . import protocol # noqa: F401 # pyright: ignore[reportUnusedImport] + + # Create validation context + context = ValidationContext(spec=spec, spec_path=spec_path, operations=operations) + + # Run validation + registry = get_registry() + issues = registry.validate(context, categories) + + # Raise on errors if requested + if fail_on_error and any(issue.severity == Severity.ERROR for issue in issues): + raise ValidationError(issues) + + return issues diff --git a/src/asyncapi_python_codegen/validation/base.py b/src/asyncapi_python_codegen/validation/base.py new file mode 100644 index 0000000..20cb9b2 --- /dev/null +++ b/src/asyncapi_python_codegen/validation/base.py @@ -0,0 +1,122 @@ +"""Base validation infrastructure: registry and decorators.""" + +from typing import Callable + +from .context import ValidationContext +from .errors import Severity, ValidationIssue + +# Type alias for validation rule functions +RuleFunction = Callable[[ValidationContext], list[ValidationIssue]] + + +class RuleRegistry: + """Registry for validation rules organized by category.""" + + def __init__(self): + """Initialize empty registry.""" + self._rules: dict[str, list[RuleFunction]] = {} + + def register(self, category: str, rule_func: RuleFunction) -> RuleFunction: + """ + Register a validation rule in a category. + + Args: + category: Category name (e.g., "core", "protocol.amqp", "codegen") + rule_func: Function that validates and returns issues + + Returns: + The rule function (for decorator pattern) + """ + if category not in self._rules: + self._rules[category] = [] + + self._rules[category].append(rule_func) + return rule_func + + def get_rules(self, category: str) -> list[RuleFunction]: + """Get all rules in a category.""" + return self._rules.get(category, []) + + def get_all_categories(self) -> list[str]: + """Get list of all registered categories.""" + return list(self._rules.keys()) + + def validate( + self, context: ValidationContext, categories: list[str] | None = None + ) -> list[ValidationIssue]: + """ + Run validation rules and collect issues. + + Args: + context: Validation context with document data + categories: List of categories to validate (None = default: ["core", "protocol.amqp"]) + + Returns: + List of all validation issues found + """ + issues: list[ValidationIssue] = [] + + # Determine which categories to run + if categories is None: + # Default: validate core rules + AMQP protocol rules + categories = ["core", "protocol.amqp"] + + # Run all rules in specified categories + for category in categories: + for rule_func in self.get_rules(category): + try: + rule_issues = rule_func(context) + issues.extend(rule_issues) + except Exception as e: + # If a rule crashes, report it as a validation issue + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Validation rule '{rule_func.__name__}' failed: {e}", + path="$", + rule="rule-execution-error", + ) + ) + + return issues + + +# Global registry instance +_global_registry = RuleRegistry() + + +def rule(*tags: str) -> Callable[[RuleFunction], RuleFunction]: + """ + Decorator to register a validation rule with one or more tags. + + Args: + *tags: One or more tag names (e.g., "core", "protocol.amqp", "requires-amqp") + + Returns: + Decorator function + + Example: + @rule("core") + def my_rule(ctx: ValidationContext) -> list[ValidationIssue]: + if "asyncapi" not in ctx.spec: + return [ValidationIssue(...)] + return [] + + @rule("core", "protocol.amqp") + def amqp_rule(ctx: ValidationContext) -> list[ValidationIssue]: + # AMQP-specific validation + return [] + """ + + def decorator(func: RuleFunction) -> RuleFunction: + # Register function under ALL provided tags + for tag in tags: + _global_registry.register(tag, func) + return func + + return decorator + + +def get_registry() -> RuleRegistry: + """Get the global rule registry.""" + return _global_registry diff --git a/src/asyncapi_python_codegen/validation/context.py b/src/asyncapi_python_codegen/validation/context.py new file mode 100644 index 0000000..21dde5e --- /dev/null +++ b/src/asyncapi_python_codegen/validation/context.py @@ -0,0 +1,36 @@ +"""Validation context providing access to document data.""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from asyncapi_python.kernel.document import Operation + + +@dataclass +class ValidationContext: + """Context passed to validation rules containing document data.""" + + spec: dict[str, Any] # Raw AsyncAPI YAML as dict + spec_path: Path # Path to the YAML file + operations: dict[str, Operation] | None = None # Parsed operations (if available) + + def get_channels(self) -> dict[str, Any]: + """Get all channels from the spec.""" + return self.spec.get("channels", {}) + + def get_operations_spec(self) -> dict[str, Any]: + """Get operations section from raw spec.""" + return self.spec.get("operations", {}) + + def get_components(self) -> dict[str, Any]: + """Get components section from spec.""" + return self.spec.get("components", {}) + + def get_info(self) -> dict[str, Any]: + """Get info section from spec.""" + return self.spec.get("info", {}) + + def get_servers(self) -> dict[str, Any]: + """Get servers section from spec.""" + return self.spec.get("servers", {}) diff --git a/src/asyncapi_python_codegen/validation/core/__init__.py b/src/asyncapi_python_codegen/validation/core/__init__.py new file mode 100644 index 0000000..9834641 --- /dev/null +++ b/src/asyncapi_python_codegen/validation/core/__init__.py @@ -0,0 +1,495 @@ +"""Core AsyncAPI 3.0 validation rules. + +These rules validate AsyncAPI 3.0 specification compliance and catch common errors. +Many of these rules address issues documented in BUG.md. +""" + +# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false + +import re +from typing import Any + +from ..base import rule +from ..context import ValidationContext +from ..errors import Severity, ValidationIssue + + +@rule("core") +def required_asyncapi_version(ctx: ValidationContext) -> list[ValidationIssue]: + """Validate that asyncapi field exists and is version 3.x.""" + if "asyncapi" not in ctx.spec: + return [ + ValidationIssue( + severity=Severity.ERROR, + message="Missing required 'asyncapi' field", + path="$", + rule="required-asyncapi-version", + suggestion="Add 'asyncapi: 3.0.0' at the root level", + ) + ] + + version = ctx.spec["asyncapi"] + if not isinstance(version, str) or not version.startswith("3."): + return [ + ValidationIssue( + severity=Severity.ERROR, + message=f"Unsupported AsyncAPI version: {version}", + path="$.asyncapi", + rule="required-asyncapi-version", + suggestion="This library supports AsyncAPI 3.x", + ) + ] + + return [] + + +@rule("core") +def required_operations_or_channels(ctx: ValidationContext) -> list[ValidationIssue]: + """Validate that at least operations or channels section exists.""" + has_operations = "operations" in ctx.spec and ctx.spec["operations"] + has_channels = "channels" in ctx.spec and ctx.spec["channels"] + + if not has_operations and not has_channels: + return [ + ValidationIssue( + severity=Severity.ERROR, + message="Document must have at least 'operations' or 'channels' section", + path="$", + rule="required-operations-or-channels", + ) + ] + + return [] + + +@rule("core") +def operations_is_dict(ctx: ValidationContext) -> list[ValidationIssue]: + """Validate that operations section is a dict.""" + if "operations" in ctx.spec: + if not isinstance(ctx.spec["operations"], dict): + return [ + ValidationIssue( + severity=Severity.ERROR, + message="'operations' must be an object/dict", + path="$.operations", + rule="operations-is-dict", + ) + ] + + return [] + + +@rule("core") +def channel_address_matches_parameters(ctx: ValidationContext) -> list[ValidationIssue]: + """ + Validate that channel address contains placeholders for all parameters. + + FIX BUG.md: This rule checks the 'address' field, not the channel key! + The parameter generator was incorrectly checking channel keys for {placeholders}. + """ + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + # CRITICAL FIX: Use 'address' field, not channel key! + address = channel_def.get("address", "") + parameters = channel_def.get("parameters", {}) + + if not parameters: + continue # No parameters to validate + + # Extract placeholders from address using regex + placeholders = set(re.findall(r"\{([^}]+)\}", address)) + param_names = set(parameters.keys()) + + # Check 1: All parameters should appear in address + missing_in_address = param_names - placeholders + if missing_in_address: + # Check if any have location field - they don't need to be in address + missing_without_location = { + name + for name in missing_in_address + if not parameters.get(name, {}).get("location") + } + + if missing_without_location: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Parameters defined but not used in address: {missing_without_location}", + path=f"$.channels.{channel_key}.parameters", + rule="channel-address-matches-parameters", + suggestion=f"Add to address: {address}.{{{', '.join(missing_without_location)}}}", + ) + ) + + # Check 2: All placeholders should have parameter definitions + undefined_params = placeholders - param_names + if undefined_params: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Address uses undefined parameters: {undefined_params}", + path=f"$.channels.{channel_key}.address", + rule="channel-address-matches-parameters", + suggestion="Define these parameters in the 'parameters' section", + ) + ) + + return issues + + +@rule("core") +def parameter_requires_location(ctx: ValidationContext) -> list[ValidationIssue]: + """All parameters MUST have a location field.""" + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + parameters = channel_def.get("parameters", {}) + for param_name, param_def in parameters.items(): + if not isinstance(param_def, dict): + continue + + location = param_def.get("location", "") + if not location: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Parameter '{param_name}' must have 'location' field", + path=f"$.channels.{channel_key}.parameters.{param_name}", + rule="parameter-requires-location", + suggestion="Add location: $message.payload#/fieldName", + ) + ) + + return issues + + +@rule("core") +def parameter_location_syntax_valid(ctx: ValidationContext) -> list[ValidationIssue]: + """ + Validate that parameter location fields use valid runtime expression syntax. + + FIX BUG.md: Validates location field syntax (though runtime extraction not implemented). + """ + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + parameters = channel_def.get("parameters", {}) + for param_name, param_def in parameters.items(): + if not isinstance(param_def, dict): + continue + + location = param_def.get("location") + if not location: + continue + + # Validate location syntax + if not isinstance(location, str): + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Parameter '{param_name}' location must be a string", + path=f"$.channels.{channel_key}.parameters.{param_name}.location", + rule="parameter-location-syntax-valid", + ) + ) + continue + + # Check for valid runtime expression pattern + # Valid: "$message.header#/userId", "$message.payload#/user/id" + if not location.startswith("$message."): + issues.append( + ValidationIssue( + severity=Severity.WARNING, + message=f"Parameter '{param_name}' location should start with '$message.'", + path=f"$.channels.{channel_key}.parameters.{param_name}.location", + rule="parameter-location-syntax-valid", + suggestion="Use format: $message.payload#/path or $message.header#/path", + ) + ) + + return issues + + +@rule("core") +def location_must_be_payload(ctx: ValidationContext) -> list[ValidationIssue]: + """Location must use $message.payload#/ format (headers not supported).""" + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + parameters = channel_def.get("parameters", {}) + for param_name, param_def in parameters.items(): + if not isinstance(param_def, dict): + continue + + location = param_def.get("location") + if location and not location.startswith("$message.payload#/"): + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Parameter '{param_name}' location must start with '$message.payload#/'", + path=f"$.channels.{channel_key}.parameters.{param_name}.location", + rule="location-must-be-payload", + suggestion="Use format: $message.payload#/path/to/field", + ) + ) + + return issues + + +@rule("core") +def location_path_exists_in_schema(ctx: ValidationContext) -> list[ValidationIssue]: + """Validate location path exists in message payload schemas.""" + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + parameters = channel_def.get("parameters", {}) + messages = channel_def.get("messages", {}) + + for param_name, param_def in parameters.items(): + if not isinstance(param_def, dict): + continue + + location = param_def.get("location") + if not location: + continue + + # Parse path from location + path = location.replace("$message.payload#/", "") + parts = [p for p in path.split("/") if p] + + # Check if path exists in ANY message schema + path_found = False + for msg_def in messages.values(): + if not isinstance(msg_def, dict): + continue + if _path_exists_in_schema(msg_def.get("payload"), parts): + path_found = True + break + + if not path_found and messages: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Parameter '{param_name}' location path '{path}' not found in message schemas", + path=f"$.channels.{channel_key}.parameters.{param_name}.location", + rule="location-path-exists-in-schema", + ) + ) + + return issues + + +def _path_exists_in_schema(schema: dict[str, Any] | None, parts: list[str]) -> bool: + """Helper to check if path exists in JSON schema.""" + if not schema or not parts: + return False + + current = schema + for part in parts: + if current.get("type") == "object": + props = current.get("properties", {}) + if part in props: + current = props[part] + else: + return False + else: + return False + return True + + +@rule("core") +def operation_references_valid_channel(ctx: ValidationContext) -> list[ValidationIssue]: + """Validate that operations reference channels that exist.""" + issues = [] + channels = ctx.get_channels() + operations_spec = ctx.get_operations_spec() + + for op_id, op_def in operations_spec.items(): + if not isinstance(op_def, dict): + continue + + channel_ref = op_def.get("channel") + if channel_ref: + # Handle both direct reference and $ref + if isinstance(channel_ref, dict) and "$ref" in channel_ref: + # TODO: Resolve $ref and validate + continue + elif isinstance(channel_ref, str): + # Direct channel reference + if channel_ref not in channels: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Operation references non-existent channel '{channel_ref}'", + path=f"$.operations.{op_id}.channel", + rule="operation-references-valid-channel", + ) + ) + + return issues + + +@rule("core") +def operation_messages_ignored(ctx: ValidationContext) -> list[ValidationIssue]: + """Warn when operation.messages is specified but will be ignored. + + In AsyncAPI 3.0, when an operation references a channel, the channel's messages + are used, and any messages specified directly on the operation are ignored. + """ + issues = [] + operations_spec = ctx.get_operations_spec() + + for op_id, op_def in operations_spec.items(): + if not isinstance(op_def, dict): + continue + + # Check if operation has both channel reference and messages + has_channel = "channel" in op_def and op_def["channel"] + has_messages = "messages" in op_def and op_def["messages"] + + if has_channel and has_messages: + issues.append( + ValidationIssue( + severity=Severity.WARNING, + message=f"Operation '{op_id}' specifies 'messages' but they will be ignored", + path=f"$.operations.{op_id}.messages", + rule="operation-messages-ignored", + suggestion="Remove 'messages' from operation - channel messages are used instead", + ) + ) + + return issues + + +@rule("core") +def valid_operation_action(ctx: ValidationContext) -> list[ValidationIssue]: + """Validate that operation action is 'send' or 'receive'.""" + issues = [] + operations_spec = ctx.get_operations_spec() + + for op_id, op_def in operations_spec.items(): + if not isinstance(op_def, dict): + continue + + action = op_def.get("action") + if action and action not in ["send", "receive"]: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Operation action must be 'send' or 'receive', got '{action}'", + path=f"$.operations.{op_id}.action", + rule="valid-operation-action", + ) + ) + + return issues + + +@rule("core") +def channel_id_no_braces(ctx: ValidationContext) -> list[ValidationIssue]: + """Prohibit curly braces in channel identifiers. + + Channel IDs (keys) should not contain {braces}. Parameters belong in the + channel's address field, not in the channel identifier itself. + """ + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + # Check if channel ID contains braces + if "{" in channel_key or "}" in channel_key: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Channel ID '{channel_key}' must not contain curly braces", + path=f"$.channels.{channel_key}", + rule="channel-id-no-braces", + suggestion="Use a simple identifier for the channel key and put parameters in the 'address' field", + ) + ) + + return issues + + +@rule("core") +def channel_has_address_if_not_reference( + ctx: ValidationContext, +) -> list[ValidationIssue]: + """Validate that channels have an address field (can be null for reusable channels).""" + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + # Skip if this is a reference + if "$ref" in channel_def: + continue + + # Address field should exist (can be null) + if "address" not in channel_def: + issues.append( + ValidationIssue( + severity=Severity.WARNING, + message=f"Channel '{channel_key}' has no 'address' field", + path=f"$.channels.{channel_key}", + rule="channel-has-address", + suggestion="Add 'address' field or set to null for reusable channels", + ) + ) + + return issues + + +@rule("core") +def channel_address_same_as_id(ctx: ValidationContext) -> list[ValidationIssue]: + """Warn when channel address is identical to channel ID. + + When the address field is the same as the channel identifier, it's redundant. + The address field should only be specified when it differs from the channel ID + or when it contains parameters. + """ + issues = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + # Skip if this is a reference + if "$ref" in channel_def: + continue + + # Get the address field + address = channel_def.get("address") + + # Check if address is identical to channel ID + if address and address == channel_key: + issues.append( + ValidationIssue( + severity=Severity.WARNING, + message=f"Channel '{channel_key}' has address identical to its ID", + path=f"$.channels.{channel_key}.address", + rule="channel-address-same-as-id", + suggestion="Remove redundant 'address' field or use null if the address should match the channel ID", + ) + ) + + return issues diff --git a/src/asyncapi_python_codegen/validation/errors.py b/src/asyncapi_python_codegen/validation/errors.py new file mode 100644 index 0000000..750cc50 --- /dev/null +++ b/src/asyncapi_python_codegen/validation/errors.py @@ -0,0 +1,66 @@ +"""Validation error types and severity levels.""" + +from dataclasses import dataclass +from enum import Enum + + +class Severity(Enum): + """Severity level of a validation issue.""" + + ERROR = "error" # Blocks code generation + WARNING = "warning" # Continues with warning + INFO = "info" # Informational only + + +@dataclass +class ValidationIssue: + """A single validation issue found in a document.""" + + severity: Severity + message: str + path: str # JSONPath to the error location (e.g., "$.channels.myChannel") + rule: str # Rule identifier (e.g., "channel-address-matches-parameters") + suggestion: str | None = None # Optional suggestion for fixing the issue + + def __str__(self) -> str: + """Format issue for display.""" + level_emoji = { + Severity.ERROR: "❌", + Severity.WARNING: "⚠️ ", + Severity.INFO: "ℹ️ ", + } + emoji = level_emoji.get(self.severity, "") + suggestion_str = f"\n 💡 {self.suggestion}" if self.suggestion else "" + return f"{emoji} {self.message}\n at {self.path}{suggestion_str}" + + +class ValidationError(ValueError): + """Raised when document validation fails with errors.""" + + def __init__(self, issues: list[ValidationIssue]): + """ + Initialize validation error with issues. + + Args: + issues: List of all validation issues (errors, warnings, info) + """ + self.issues = issues + self.errors = [i for i in issues if i.severity == Severity.ERROR] + self.warnings = [i for i in issues if i.severity == Severity.WARNING] + self.info = [i for i in issues if i.severity == Severity.INFO] + + if self.errors: + error_messages = "\n\n".join(str(e) for e in self.errors) + message = f"Document validation failed with {len(self.errors)} error(s):\n\n{error_messages}" + else: + message = "Document validation completed (no errors)" + + super().__init__(message) + + def has_errors(self) -> bool: + """Check if there are any errors.""" + return len(self.errors) > 0 + + def has_warnings(self) -> bool: + """Check if there are any warnings.""" + return len(self.warnings) > 0 diff --git a/src/asyncapi_python_codegen/validation/protocol/__init__.py b/src/asyncapi_python_codegen/validation/protocol/__init__.py new file mode 100644 index 0000000..147f2b1 --- /dev/null +++ b/src/asyncapi_python_codegen/validation/protocol/__init__.py @@ -0,0 +1,6 @@ +"""Protocol-specific validation rules.""" + +# Import all protocol modules to register their rules +from . import amqp # noqa: F401 + +__all__ = ["amqp"] diff --git a/src/asyncapi_python_codegen/validation/protocol/amqp.py b/src/asyncapi_python_codegen/validation/protocol/amqp.py new file mode 100644 index 0000000..37e93ac --- /dev/null +++ b/src/asyncapi_python_codegen/validation/protocol/amqp.py @@ -0,0 +1,64 @@ +"""AMQP protocol-specific validation rules.""" + +# pyright: reportUnknownVariableType=false, reportUnknownMemberType=false + +from typing import Any + +from ..base import rule +from ..context import ValidationContext +from ..errors import Severity, ValidationIssue + + +@rule("protocol.amqp") +def amqp_parameterized_channels_require_binding_type( + ctx: ValidationContext, +) -> list[ValidationIssue]: + """Parameterized channels MUST specify AMQP binding type (routingKey or queue). + + When compiling for AMQP, any channel with parameters must explicitly specify + whether the parameterized address should be interpreted as: + - "routingKey": Topic exchange routing key with pattern matching + - "queue": Direct queue name for point-to-point messaging + + This disambiguates how AMQP should handle the parameterized address. + """ + issues: list[ValidationIssue] = [] + + for channel_key, channel_def in ctx.get_channels().items(): + if not isinstance(channel_def, dict): + continue + + # Only check channels with parameters + parameters: dict[str, Any] = channel_def.get("parameters", {}) + if not parameters: + continue # No parameters, no ambiguity + + # Check for AMQP binding + bindings: dict[str, Any] = channel_def.get("bindings", {}) + amqp_binding: dict[str, Any] = bindings.get("amqp", {}) + + # Require explicit "is" field for parameterized channels + if "is" not in amqp_binding: + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Channel '{channel_key}' has parameters but AMQP binding lacks 'is' field", + path=f"$.channels.{channel_key}.bindings.amqp", + rule="amqp-parameterized-channels-require-binding-type", + suggestion="Add 'is: routingKey' for topic exchange routing or 'is: queue' for direct queue addressing", + ) + ) + elif amqp_binding["is"] not in ["routingKey", "queue"]: + # Validate the value is one of the allowed types + invalid_value: Any = amqp_binding["is"] + issues.append( + ValidationIssue( + severity=Severity.ERROR, + message=f"Channel '{channel_key}' has invalid AMQP binding type: '{invalid_value}'", + path=f"$.channels.{channel_key}.bindings.amqp.is", + rule="amqp-parameterized-channels-require-binding-type", + suggestion="Use 'is: routingKey' for topic exchanges or 'is: queue' for direct queues", + ) + ) + + return issues diff --git a/tests/codegen/test_validation.py b/tests/codegen/test_validation.py new file mode 100644 index 0000000..034032a --- /dev/null +++ b/tests/codegen/test_validation.py @@ -0,0 +1,480 @@ +"""Tests for the validation system.""" + +from pathlib import Path + +import pytest + +from asyncapi_python_codegen.parser.document_loader import extract_all_operations +from asyncapi_python_codegen.validation import ( + Severity, + ValidationError, + ValidationIssue, + rule, + validate_spec, +) +from asyncapi_python_codegen.validation.context import ValidationContext + + +def test_validation_error_raised_for_missing_asyncapi_field(tmp_path: Path): + """Test that missing asyncapi field raises ValidationError.""" + spec_file = tmp_path / "invalid.yaml" + spec_file.write_text( + """ +operations: + myOp: + action: send +""" + ) + + with pytest.raises(ValueError, match="Missing 'asyncapi' version field"): + extract_all_operations(spec_file) + + +def test_validation_error_for_invalid_channel_parameters(tmp_path: Path): + """Test that parameters not in address raise ValidationError.""" + spec_file = tmp_path / "invalid_params.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: static.address + parameters: + userId: + schema: + type: string + messages: + msg: + payload: + type: object +operations: + myOp: + action: send + channel: + $ref: '#/channels/myChannel' +""" + ) + + with pytest.raises(ValidationError) as exc_info: + extract_all_operations(spec_file) + + assert len(exc_info.value.errors) > 0 + assert any( + "not used in address" in error.message for error in exc_info.value.errors + ) + + +def test_validation_passes_for_valid_spec(tmp_path: Path): + """Test that a valid spec passes validation.""" + spec_file = tmp_path / "valid.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + userChannel: + address: user.{userId} + parameters: + userId: + location: $message.payload#/userId + messages: + userMessage: + payload: + type: object + properties: + userId: + type: string + bindings: + amqp: + is: routingKey +operations: + sendUser: + action: send + channel: + $ref: '#/channels/userChannel' + messages: + - $ref: '#/channels/userChannel/messages/userMessage' +""" + ) + + # Should not raise + operations = extract_all_operations(spec_file) + assert "sendUser" in operations + + +def test_validation_can_be_disabled(tmp_path: Path): + """Test that validation can be disabled.""" + spec_file = tmp_path / "invalid_params.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: static.address + parameters: + userId: + schema: + type: string + messages: + msg: + payload: + type: object +operations: + myOp: + action: send + channel: + $ref: '#/channels/myChannel' +""" + ) + + # Should not raise when validation is disabled + operations = extract_all_operations(spec_file, validate=False) + assert "myOp" in operations + + +def test_warnings_do_not_fail_validation(tmp_path: Path): + """Test that warnings are collected but don't fail validation.""" + spec_file = tmp_path / "with_warnings.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: user + parameters: + userId: + location: $message.payload#/userId + messages: + userMessage: + payload: + type: object + properties: + userId: + type: string + bindings: + amqp: + is: queue +operations: + myOp: + action: send + channel: + $ref: '#/channels/myChannel' +""" + ) + + # Should not raise - location warning doesn't fail + operations = extract_all_operations(spec_file) + assert "myOp" in operations + + +def test_custom_rule_registration(): + """Test that users can register custom validation rules.""" + + @rule("custom") + def custom_test_rule(ctx: ValidationContext) -> list[ValidationIssue]: + if ctx.spec.get("asyncapi") == "3.0.0": + return [ + ValidationIssue( + severity=Severity.INFO, + message="Custom rule triggered", + path="$.asyncapi", + rule="custom-test-rule", + ) + ] + return [] + + # Create a simple spec + spec = { + "asyncapi": "3.0.0", + "operations": {}, + } + + # Validate with custom category + issues = validate_spec( + spec=spec, + operations={}, + spec_path=Path("."), + categories=["custom"], + fail_on_error=False, + ) + + assert len(issues) == 1 + assert issues[0].rule == "custom-test-rule" + + +def test_parameter_with_location_warns_not_implemented(tmp_path: Path): + """Test that using location field generates a warning.""" + spec_file = tmp_path / "location.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: user + parameters: + userId: + location: $message.payload#/userId + messages: + userMessage: + payload: + type: object +operations: + myOp: + action: send + channel: + $ref: '#/channels/myChannel' +""" + ) + + # Should succeed but print warning + operations = extract_all_operations(spec_file, fail_on_error=False) + assert "myOp" in operations + + +def test_undefined_placeholders_in_address(tmp_path: Path): + """Test that undefined placeholders in address raise error.""" + spec_file = tmp_path / "undefined_params.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: user.{userId}.{role} + parameters: + userId: + schema: + type: string + messages: + msg: + payload: + type: object +operations: + myOp: + action: send + channel: + $ref: '#/channels/myChannel' +""" + ) + + with pytest.raises(ValidationError) as exc_info: + extract_all_operations(spec_file) + + assert any( + "undefined parameters" in error.message for error in exc_info.value.errors + ) + + +def test_operation_references_nonexistent_channel(tmp_path: Path): + """Test that operation referencing non-existent channel raises error.""" + spec_file = tmp_path / "bad_channel_ref.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + realChannel: + address: real + messages: + msg: + payload: + type: object +operations: + myOp: + action: send + channel: + $ref: '#/channels/fakeChannel' +""" + ) + + # Parser will fail when trying to resolve $ref (before validation runs) + with pytest.raises( + RuntimeError, match="JSON pointer segment 'fakeChannel' not found" + ): + extract_all_operations(spec_file) + + +def test_invalid_operation_action(tmp_path: Path): + """Test that invalid operation action raises error.""" + spec_file = tmp_path / "bad_action.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: test + messages: + msg: + payload: + type: object +operations: + myOp: + action: publish + channel: + $ref: '#/channels/myChannel' +""" + ) + + with pytest.raises(ValidationError) as exc_info: + extract_all_operations(spec_file) + + assert any( + "must be 'send' or 'receive'" in error.message + for error in exc_info.value.errors + ) + + +def test_amqp_parameterized_channel_without_binding_type_fails(tmp_path: Path): + """Test that parameterized channel without AMQP binding type fails validation.""" + spec_file = tmp_path / "amqp_no_binding_type.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + weatherAlerts: + address: weather.{location}.{severity} + parameters: + location: + location: $message.payload#/location + severity: + location: $message.payload#/severity + messages: + alert: + payload: + type: object + properties: + location: + type: string + severity: + type: string + bindings: + amqp: + # Missing 'is' field! + exchange: + name: weather_alerts + type: topic +operations: + publishAlert: + action: send + channel: + $ref: '#/channels/weatherAlerts' +""" + ) + + with pytest.raises(ValidationError) as exc_info: + extract_all_operations(spec_file) + + assert any("lacks 'is' field" in error.message for error in exc_info.value.errors) + + +def test_amqp_parameterized_channel_with_routing_key_passes(tmp_path: Path): + """Test that parameterized channel with is: routingKey passes validation.""" + spec_file = tmp_path / "amqp_routing_key.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + weatherAlerts: + address: weather.{location}.{severity} + parameters: + location: + location: $message.payload#/location + severity: + location: $message.payload#/severity + messages: + alert: + payload: + type: object + properties: + location: + type: string + severity: + type: string + bindings: + amqp: + is: routingKey + exchange: + name: weather_alerts + type: topic +operations: + publishAlert: + action: send + channel: + $ref: '#/channels/weatherAlerts' +""" + ) + + # Should not raise + operations = extract_all_operations(spec_file) + assert "publishAlert" in operations + + +def test_amqp_parameterized_channel_with_queue_passes(tmp_path: Path): + """Test that parameterized channel with is: queue passes validation.""" + spec_file = tmp_path / "amqp_queue.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + userNotifications: + address: user.{userId}.notifications + parameters: + userId: + location: $message.payload#/userId + messages: + notification: + payload: + type: object + properties: + userId: + type: string + bindings: + amqp: + is: queue +operations: + sendNotification: + action: send + channel: + $ref: '#/channels/userNotifications' +""" + ) + + # Should not raise + operations = extract_all_operations(spec_file) + assert "sendNotification" in operations + + +def test_amqp_parameterized_channel_with_invalid_binding_type_fails(tmp_path: Path): + """Test that parameterized channel with invalid binding type fails validation.""" + spec_file = tmp_path / "amqp_invalid_type.yaml" + spec_file.write_text( + """ +asyncapi: 3.0.0 +channels: + myChannel: + address: my.{param}.channel + parameters: + param: + location: $message.payload#/param + messages: + msg: + payload: + type: object + properties: + param: + type: string + bindings: + amqp: + is: topic # Invalid! Should be 'routingKey' or 'queue' +operations: + myOp: + action: send + channel: + $ref: '#/channels/myChannel' +""" + ) + + with pytest.raises(ValidationError) as exc_info: + extract_all_operations(spec_file) + + assert any( + "invalid" in error.message and "binding type" in error.message + for error in exc_info.value.errors + ) diff --git a/tests/codegen/test_validation_system.py b/tests/codegen/test_validation_system.py new file mode 100644 index 0000000..c1067f3 --- /dev/null +++ b/tests/codegen/test_validation_system.py @@ -0,0 +1,466 @@ +"""Unit tests for the validation system infrastructure (registry, decorators, errors).""" + +from pathlib import Path + +import pytest + +from asyncapi_python_codegen.validation import ( + Severity, + ValidationError, + ValidationIssue, + rule, + validate_spec, +) +from asyncapi_python_codegen.validation.base import RuleRegistry, get_registry +from asyncapi_python_codegen.validation.context import ValidationContext +from asyncapi_python_codegen.validation.errors import ValidationError as VE + + +class TestValidationIssue: + """Test ValidationIssue dataclass.""" + + def test_create_issue(self): + """Test creating a validation issue.""" + issue = ValidationIssue( + severity=Severity.ERROR, + message="Test error", + path="$.test", + rule="test-rule", + ) + + assert issue.severity == Severity.ERROR + assert issue.message == "Test error" + assert issue.path == "$.test" + assert issue.rule == "test-rule" + assert issue.suggestion is None + + def test_issue_with_suggestion(self): + """Test creating issue with suggestion.""" + issue = ValidationIssue( + severity=Severity.WARNING, + message="Test warning", + path="$.test", + rule="test-rule", + suggestion="Try this instead", + ) + + assert issue.suggestion == "Try this instead" + + def test_issue_string_format(self): + """Test string formatting of issues.""" + issue = ValidationIssue( + severity=Severity.ERROR, + message="Something went wrong", + path="$.channels.test", + rule="test-rule", + ) + + issue_str = str(issue) + assert "Something went wrong" in issue_str + assert "$.channels.test" in issue_str + + def test_issue_with_suggestion_format(self): + """Test string formatting includes suggestion.""" + issue = ValidationIssue( + severity=Severity.WARNING, + message="Issue here", + path="$.test", + rule="test-rule", + suggestion="Fix it like this", + ) + + issue_str = str(issue) + assert "Fix it like this" in issue_str + + +class TestValidationError: + """Test ValidationError exception.""" + + def test_create_validation_error(self): + """Test creating validation error.""" + issues = [ + ValidationIssue( + severity=Severity.ERROR, + message="Error 1", + path="$.test1", + rule="rule1", + ), + ValidationIssue( + severity=Severity.WARNING, + message="Warning 1", + path="$.test2", + rule="rule2", + ), + ] + + error = ValidationError(issues) + + assert len(error.issues) == 2 + assert len(error.errors) == 1 + assert len(error.warnings) == 1 + assert error.has_errors() + assert error.has_warnings() + + def test_validation_error_message(self): + """Test error message includes error details.""" + issues = [ + ValidationIssue( + severity=Severity.ERROR, + message="Critical error", + path="$.test", + rule="test-rule", + ) + ] + + error = ValidationError(issues) + error_msg = str(error) + + assert "Document validation failed" in error_msg + assert "Critical error" in error_msg + + def test_no_errors_validation_error(self): + """Test ValidationError with only warnings.""" + issues = [ + ValidationIssue( + severity=Severity.WARNING, + message="Just a warning", + path="$.test", + rule="test-rule", + ) + ] + + error = ValidationError(issues) + + assert not error.has_errors() + assert error.has_warnings() + + +class TestValidationContext: + """Test ValidationContext dataclass.""" + + def test_create_context(self): + """Test creating validation context.""" + spec = {"asyncapi": "3.0.0", "channels": {}} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml"), operations=None) + + assert ctx.spec == spec + assert ctx.spec_path == Path("test.yaml") + assert ctx.operations is None + + def test_get_channels(self): + """Test get_channels helper method.""" + spec = {"channels": {"ch1": {}, "ch2": {}}} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + channels = ctx.get_channels() + + assert len(channels) == 2 + assert "ch1" in channels + assert "ch2" in channels + + def test_get_channels_empty(self): + """Test get_channels when no channels.""" + spec = {"asyncapi": "3.0.0"} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + channels = ctx.get_channels() + + assert channels == {} + + def test_get_operations_spec(self): + """Test get_operations_spec helper.""" + spec = {"operations": {"op1": {}, "op2": {}}} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + ops = ctx.get_operations_spec() + + assert len(ops) == 2 + assert "op1" in ops + + def test_get_components(self): + """Test get_components helper.""" + spec = {"components": {"messages": {}}} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + components = ctx.get_components() + + assert "messages" in components + + +class TestRuleRegistry: + """Test RuleRegistry class.""" + + def test_create_registry(self): + """Test creating a registry.""" + registry = RuleRegistry() + + assert len(registry.get_all_categories()) == 0 + + def test_register_rule(self): + """Test registering a rule.""" + registry = RuleRegistry() + + def test_rule(ctx: ValidationContext) -> list[ValidationIssue]: + return [] + + registry.register("test", test_rule) + + assert "test" in registry.get_all_categories() + assert len(registry.get_rules("test")) == 1 + + def test_register_multiple_rules_same_category(self): + """Test registering multiple rules in same category.""" + registry = RuleRegistry() + + def rule1(ctx): + return [] + + def rule2(ctx): + return [] + + registry.register("test", rule1) + registry.register("test", rule2) + + assert len(registry.get_rules("test")) == 2 + + def test_register_rules_different_categories(self): + """Test registering rules in different categories.""" + registry = RuleRegistry() + + def rule1(ctx): + return [] + + def rule2(ctx): + return [] + + registry.register("cat1", rule1) + registry.register("cat2", rule2) + + assert len(registry.get_all_categories()) == 2 + assert len(registry.get_rules("cat1")) == 1 + assert len(registry.get_rules("cat2")) == 1 + + def test_validate_runs_rules(self): + """Test that validate runs registered rules.""" + registry = RuleRegistry() + called = [] + + def test_rule(ctx: ValidationContext) -> list[ValidationIssue]: + called.append(True) + return [] + + registry.register("test", test_rule) + + spec = {"asyncapi": "3.0.0"} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + registry.validate(ctx, categories=["test"]) + + assert len(called) == 1 + + def test_validate_collects_issues(self): + """Test that validate collects issues from rules.""" + registry = RuleRegistry() + + def test_rule(ctx: ValidationContext) -> list[ValidationIssue]: + return [ + ValidationIssue( + severity=Severity.ERROR, + message="Test error", + path="$.test", + rule="test-rule", + ) + ] + + registry.register("test", test_rule) + + spec = {"asyncapi": "3.0.0"} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + issues = registry.validate(ctx, categories=["test"]) + + assert len(issues) == 1 + assert issues[0].message == "Test error" + + def test_validate_runs_only_specified_categories(self): + """Test that validate only runs specified categories.""" + registry = RuleRegistry() + called = {"cat1": False, "cat2": False} + + def rule1(ctx): + called["cat1"] = True + return [] + + def rule2(ctx): + called["cat2"] = True + return [] + + registry.register("cat1", rule1) + registry.register("cat2", rule2) + + spec = {"asyncapi": "3.0.0"} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + registry.validate(ctx, categories=["cat1"]) + + assert called["cat1"] is True + assert called["cat2"] is False + + def test_validate_handles_rule_exception(self): + """Test that validate handles exceptions in rules.""" + registry = RuleRegistry() + + def broken_rule(ctx): + raise RuntimeError("Rule crashed!") + + registry.register("test", broken_rule) + + spec = {"asyncapi": "3.0.0"} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + issues = registry.validate(ctx, categories=["test"]) + + assert len(issues) == 1 + assert issues[0].severity == Severity.ERROR + assert "Rule crashed!" in issues[0].message + + +class TestRuleDecorator: + """Test @rule decorator.""" + + def test_rule_decorator_registers_function(self): + """Test that @rule decorator registers the function.""" + # Get current rule count + registry = get_registry() + initial_count = len(registry.get_rules("test-category")) + + @rule("test-category") + def my_test_rule(ctx: ValidationContext) -> list[ValidationIssue]: + return [] + + # Should have one more rule + assert len(registry.get_rules("test-category")) == initial_count + 1 + + def test_rule_decorator_preserves_function(self): + """Test that decorator preserves the function.""" + + @rule("test-category") + def my_test_rule(ctx: ValidationContext) -> list[ValidationIssue]: + return [ + ValidationIssue( + severity=Severity.INFO, + message="Test", + path="$.test", + rule="test", + ) + ] + + # Function should still be callable + spec = {"asyncapi": "3.0.0"} + ctx = ValidationContext(spec=spec, spec_path=Path("test.yaml")) + + issues = my_test_rule(ctx) + + assert len(issues) == 1 + + +class TestValidateSpecFunction: + """Test the validate_spec function.""" + + def test_validate_spec_runs_core_rules(self): + """Test that validate_spec runs core rules.""" + spec = {"asyncapi": "2.0.0"} # Wrong version + + with pytest.raises(ValidationError) as exc_info: + validate_spec(spec=spec, operations=None, spec_path=Path("test.yaml")) + + assert exc_info.value.has_errors() + + def test_validate_spec_with_valid_spec(self): + """Test validate_spec with valid spec.""" + spec = { + "asyncapi": "3.0.0", + "operations": { + "test": { + "action": "send", + } + }, + } + + # Should not raise + issues = validate_spec( + spec=spec, + operations={}, + spec_path=Path("test.yaml"), + fail_on_error=False, + ) + + # May have some issues but shouldn't fail + errors = [i for i in issues if i.severity == Severity.ERROR] + assert len(errors) >= 0 # May or may not have errors + + def test_validate_spec_fail_on_error_false(self): + """Test that fail_on_error=False doesn't raise.""" + spec = {"asyncapi": "2.0.0"} # Wrong version + + issues = validate_spec( + spec=spec, + operations=None, + spec_path=Path("test.yaml"), + fail_on_error=False, + ) + + # Should have errors but not raise + errors = [i for i in issues if i.severity == Severity.ERROR] + assert len(errors) > 0 + + def test_validate_spec_with_categories(self): + """Test validate_spec with specific categories.""" + + # Register a test rule in a custom category + @rule("custom-test") + def custom_rule(ctx): + return [ + ValidationIssue( + severity=Severity.INFO, + message="Custom rule ran", + path="$.test", + rule="custom-rule", + ) + ] + + spec = {"asyncapi": "3.0.0", "operations": {}} + + issues = validate_spec( + spec=spec, + operations={}, + spec_path=Path("test.yaml"), + categories=["custom-test"], + fail_on_error=False, + ) + + # Should have run our custom rule + assert any(i.message == "Custom rule ran" for i in issues) + + +class TestSeverityEnum: + """Test Severity enum.""" + + def test_severity_values(self): + """Test severity enum values.""" + assert Severity.ERROR.value == "error" + assert Severity.WARNING.value == "warning" + assert Severity.INFO.value == "info" + + def test_severity_comparison(self): + """Test severity can be compared.""" + error_issue = ValidationIssue( + severity=Severity.ERROR, message="", path="", rule="" + ) + warning_issue = ValidationIssue( + severity=Severity.WARNING, message="", path="", rule="" + ) + + assert error_issue.severity == Severity.ERROR + assert warning_issue.severity != Severity.ERROR + assert warning_issue.severity == Severity.WARNING diff --git a/tests/conftest.py b/tests/conftest.py index 14eb405..a06c27a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,12 +14,19 @@ import asyncio +import re from os import environ from typing import Generator import pytest from asyncapi_python.contrib.wire.in_memory import reset_bus +from asyncapi_python.kernel.document.bindings import AmqpChannelBinding +from asyncapi_python.kernel.document.channel import ( + AddressParameter, + Channel, + ChannelBindings, +) @pytest.fixture(scope="session") @@ -45,3 +52,49 @@ def reset_in_memory_bus() -> Generator[None, None, None]: reset_bus() yield reset_bus() + + +class PytestHelpers: + """Test helper functions available via pytest.helpers""" + + @staticmethod + def create_test_channel( + address: str | None = None, + binding: AmqpChannelBinding | None = None, + ) -> Channel: + """Create a minimal test channel with required fields. + + Automatically extracts parameters from address template (e.g., {location}). + """ + bindings = None + if binding: + bindings = ChannelBindings(amqp=binding) + + # Extract parameters from address template + parameters = {} + if address: + param_names = re.findall(r"\{(\w+)\}", address) + for param_name in param_names: + parameters[param_name] = AddressParameter( + key=param_name, + description=f"Test parameter {param_name}", + location=None, + ) + + return Channel( + key="test_channel", + address=address, + title=None, + summary=None, + description=None, + servers=[], + messages={}, + parameters=parameters, + tags=[], + external_docs=None, + bindings=bindings, + ) + + +# Register helpers +pytest.helpers = PytestHelpers() # type: ignore diff --git a/tests/core/endpoint/test_parameterized_subscriptions.py b/tests/core/endpoint/test_parameterized_subscriptions.py new file mode 100644 index 0000000..038af0a --- /dev/null +++ b/tests/core/endpoint/test_parameterized_subscriptions.py @@ -0,0 +1,634 @@ +"""Tests for parameterized channel subscriptions with wildcards.""" + +import asyncio +from typing import Any + +import pytest + +from asyncapi_python.contrib.codec.json import JsonCodecFactory +from asyncapi_python.contrib.wire.in_memory import InMemoryWire +from asyncapi_python.kernel.application import BaseApplication +from asyncapi_python.kernel.document import Channel, Operation +from asyncapi_python.kernel.document.bindings import ( + AmqpChannelBinding, + AmqpExchange, + AmqpExchangeType, +) +from asyncapi_python.kernel.document.channel import AddressParameter, ChannelBindings +from asyncapi_python.kernel.document.message import Message +from asyncapi_python.kernel.endpoint import Publisher, Subscriber +from pydantic import BaseModel + + +class AlertMessage(BaseModel): + """Test message with location and severity fields.""" + + location: str + severity: str + data: str + + +@pytest.fixture +def parameterized_channel() -> Channel: + """Create a parameterized channel with AMQP routing key binding.""" + return Channel( + key="test_channel", + address="alerts.{location}.{severity}", + title="Test Channel", + summary=None, + description=None, + servers=[], + messages={ + "TestMessage": Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + }, + parameters={ + "location": AddressParameter( + key="location", + description="Location code", + location="$message.payload#/location", + ), + "severity": AddressParameter( + key="severity", + description="Severity level", + location="$message.payload#/severity", + ), + }, + tags=[], + external_docs=None, + bindings=ChannelBindings( + amqp=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ) + ), + ) + + +@pytest.fixture +def queue_channel() -> Channel: + """Create a parameterized channel with AMQP queue binding.""" + return Channel( + key="test_queue", + address="queue.{priority}", + title="Test Queue", + summary=None, + description=None, + servers=[], + messages={ + "TestMessage": Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + }, + parameters={ + "priority": AddressParameter( + key="priority", + description="Priority level", + location="$message.payload#/priority", + ), + }, + tags=[], + external_docs=None, + bindings=ChannelBindings( + amqp=AmqpChannelBinding( + type="queue", + ) + ), + ) + + +async def test_subscriber_accepts_parameters(): + """Subscriber should accept parameters dict in decorator.""" + wire = InMemoryWire() + + # Create a minimal module for codec factory with messages.json structure + import types + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertMessage + codec_factory = JsonCodecFactory(test_module) + + operation = Operation( + key="test_op", + action="receive", + title=None, + summary=None, + description=None, + channel=pytest.helpers.create_test_channel( # type: ignore + address="alerts.{location}.{severity}", + binding=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ), + ), + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + subscriber = Subscriber( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + # Register handler with wildcard parameters + @subscriber(parameters={"location": "*", "severity": "high"}) + async def handle_alert(msg: AlertMessage) -> None: + pass + + # Should not raise + assert subscriber._subscription_parameters == {"location": "*", "severity": "high"} + + +async def test_subscriber_wildcard_parameters_flow_to_wire(): + """Subscriber parameters should be passed to wire factory.""" + wire = InMemoryWire() + + import types + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertMessage + codec_factory = JsonCodecFactory(test_module) + + # Mock the wire factory to capture parameters + captured_params: dict[str, str] = {} + + original_create_consumer = wire.create_consumer + + async def mock_create_consumer(**kwargs: Any) -> Any: + nonlocal captured_params + captured_params = kwargs.get("parameters", {}) + return await original_create_consumer(**kwargs) + + wire.create_consumer = mock_create_consumer # type: ignore + + operation = Operation( + key="test_op", + action="receive", + title=None, + summary=None, + description=None, + channel=pytest.helpers.create_test_channel( # type: ignore + address="alerts.{location}.{severity}", + binding=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ), + ), + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + subscriber = Subscriber( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + # Register handler with parameters + @subscriber(parameters={"location": "NYC", "severity": "*"}) + async def handle_alert(msg: AlertMessage) -> None: + pass + + # Start subscriber (should create consumer with parameters) + await subscriber.start() + + # Verify parameters were passed to wire factory + assert captured_params == {"location": "NYC", "severity": "*"} + + await subscriber.stop() + + +async def test_queue_binding_with_wildcards_raises_error(): + """Queue bindings with wildcard parameters should raise ValueError.""" + from asyncapi_python.contrib.wire.amqp import AmqpWire + + import types + + test_module = types.SimpleNamespace() + + # This would fail at runtime when creating consumer + wire = AmqpWire("amqp://guest:guest@localhost") + codec_factory = JsonCodecFactory(test_module) + + operation = Operation( + key="test_op", + action="receive", + title=None, + summary=None, + description=None, + channel=pytest.helpers.create_test_channel( # type: ignore + address="queue.{priority}", + binding=AmqpChannelBinding(type="queue"), + ), + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + subscriber = Subscriber( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + @subscriber(parameters={"priority": "*"}) + async def handle_task(msg: AlertMessage) -> None: + pass + + # Should raise ValueError when starting (wire layer validation) + with pytest.raises(ValueError, match="wildcard patterns"): + await subscriber.start() + + +async def test_default_empty_parameters(): + """Subscriber without parameters should use empty dict.""" + wire = InMemoryWire() + + import types + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertMessage + codec_factory = JsonCodecFactory(test_module) + + operation = Operation( + key="test_op", + action="receive", + title=None, + summary=None, + description=None, + channel=pytest.helpers.create_test_channel( # type: ignore + address="simple.queue", + binding=None, + ), + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + subscriber = Subscriber( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + @subscriber + async def handle_msg(msg: AlertMessage) -> None: + pass + + # Should use empty parameters + assert subscriber._subscription_parameters == {} + + await subscriber.start() + await subscriber.stop() + + +async def test_subscriber_rejects_missing_parameters(): + """Subscriber should raise ValueError when required parameters are missing.""" + from asyncapi_python.contrib.wire.amqp import AmqpWire + + wire = AmqpWire("amqp://guest:guest@localhost") + + import types + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertMessage + codec_factory = JsonCodecFactory(test_module) + + # Create channel with 2 parameters + channel = Channel( + key="test_channel", + address="alerts.{location}.{severity}", + title="Test Channel", + summary=None, + description=None, + servers=[], + messages={ + "TestMessage": Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + }, + parameters={ + "location": AddressParameter( + key="location", + description="Location code", + location="$message.payload#/location", + ), + "severity": AddressParameter( + key="severity", + description="Severity level", + location="$message.payload#/severity", + ), + }, + tags=[], + external_docs=None, + bindings=ChannelBindings( + amqp=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ) + ), + ) + + operation = Operation( + key="test_op", + action="receive", + title=None, + summary=None, + description=None, + channel=channel, + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + subscriber = Subscriber( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + # Register with only 1 parameter (missing severity) + @subscriber(parameters={"location": "NYC"}) + async def handle_alert(msg: AlertMessage) -> None: + pass + + # Should raise ValueError when starting + with pytest.raises(ValueError, match="Missing required parameters"): + await subscriber.start() + + +async def test_subscriber_rejects_extra_parameters(): + """Subscriber should raise ValueError when extra parameters are provided.""" + from asyncapi_python.contrib.wire.amqp import AmqpWire + + wire = AmqpWire("amqp://guest:guest@localhost") + + import types + + test_module = types.SimpleNamespace() + test_module.messages = types.SimpleNamespace() + test_module.messages.json = types.SimpleNamespace() + test_module.messages.json.TestMessage = AlertMessage + codec_factory = JsonCodecFactory(test_module) + + # Create channel with 1 parameter + channel = Channel( + key="test_channel", + address="alerts.{location}", + title="Test Channel", + summary=None, + description=None, + servers=[], + messages={ + "TestMessage": Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + }, + parameters={ + "location": AddressParameter( + key="location", + description="Location code", + location="$message.payload#/location", + ), + }, + tags=[], + external_docs=None, + bindings=ChannelBindings( + amqp=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="test_exchange", + type=AmqpExchangeType.TOPIC, + ), + ) + ), + ) + + operation = Operation( + key="test_op", + action="receive", + title=None, + summary=None, + description=None, + channel=channel, + messages=[ + Message( + name="TestMessage", + title="Test Message", + summary=None, + description=None, + content_type="application/json", + headers=None, + tags=[], + externalDocs=None, + bindings=None, + deprecated=None, + correlation_id=None, + traits=[], + payload={"type": "object"}, + key="", + ) + ], + reply=None, + traits=[], + security=[], + tags=[], + external_docs=None, + bindings=None, + ) + + subscriber = Subscriber( + operation=operation, + wire_factory=wire, + codec_factory=codec_factory, + ) + + # Register with 2 parameters (location + extra severity) + @subscriber(parameters={"location": "NYC", "severity": "high"}) + async def handle_alert(msg: AlertMessage) -> None: + pass + + # Should raise ValueError when starting + with pytest.raises(ValueError, match="Unexpected parameters"): + await subscriber.start() diff --git a/tests/core/wire/test_amqp_resolver.py b/tests/core/wire/test_amqp_resolver.py new file mode 100644 index 0000000..3e8d067 --- /dev/null +++ b/tests/core/wire/test_amqp_resolver.py @@ -0,0 +1,352 @@ +"""Tests for AMQP resolver validation logic.""" + +import re + +import pytest + +from asyncapi_python.contrib.wire.amqp.resolver import ( + _validate_no_wildcards_in_queue, + resolve_amqp_config, +) +from asyncapi_python.kernel.document.bindings import ( + AmqpChannelBinding, + AmqpExchange, + AmqpExchangeType, +) +from asyncapi_python.kernel.document.channel import ( + AddressParameter, + Channel, + ChannelBindings, +) +from asyncapi_python.kernel.wire import EndpointParams + + +def create_test_channel( + address: str | None = None, + binding: AmqpChannelBinding | None = None, +) -> Channel: + """Create a minimal test channel with required fields. + + Automatically extracts parameters from address template (e.g., {location}). + """ + bindings = None + if binding: + bindings = ChannelBindings(amqp=binding) + + # Extract parameters from address template + parameters = {} + if address: + param_names = re.findall(r"\{(\w+)\}", address) + for param_name in param_names: + parameters[param_name] = AddressParameter( + key=param_name, + description=f"Test parameter {param_name}", + location=None, + ) + + return Channel( + key="test_channel", + address=address, + title=None, + summary=None, + description=None, + servers=[], + messages={}, + parameters=parameters, + tags=[], + external_docs=None, + bindings=bindings, + ) + + +def test_validate_no_wildcards_accepts_concrete_values(): + """Concrete parameter values should pass validation.""" + params = {"location": "NYC", "severity": "high"} + # Should not raise + _validate_no_wildcards_in_queue(params) + + +def test_validate_no_wildcards_rejects_star_wildcard(): + """Parameter values with * wildcard should be rejected.""" + params = {"location": "*", "severity": "high"} + with pytest.raises(ValueError) as exc_info: + _validate_no_wildcards_in_queue(params) + + assert "wildcard patterns" in str(exc_info.value) + assert "location=*" in str(exc_info.value) + + +def test_validate_no_wildcards_rejects_hash_wildcard(): + """Parameter values with # wildcard should be rejected.""" + params = {"location": "NYC", "severity": "#"} + with pytest.raises(ValueError) as exc_info: + _validate_no_wildcards_in_queue(params) + + assert "wildcard patterns" in str(exc_info.value) + assert "severity=#" in str(exc_info.value) + + +def test_validate_no_wildcards_rejects_multiple_wildcards(): + """Multiple wildcard parameters should all be reported.""" + params = {"location": "*", "severity": "#", "priority": "high"} + with pytest.raises(ValueError) as exc_info: + _validate_no_wildcards_in_queue(params) + + error_msg = str(exc_info.value) + assert "wildcard patterns" in error_msg + assert "location=*" in error_msg + assert "severity=#" in error_msg + + +def test_validate_no_wildcards_accepts_empty_dict(): + """Empty parameter dict should pass validation.""" + params = {} + # Should not raise + _validate_no_wildcards_in_queue(params) + + +def test_resolve_queue_binding_rejects_wildcards(): + """Queue bindings should reject wildcard parameters.""" + # Create channel with queue binding + channel = create_test_channel( + address="weather.{location}.{severity}", + binding=AmqpChannelBinding(type="queue"), + ) + + params: EndpointParams = { + "channel": channel, + "parameters": {"location": "*", "severity": "high"}, + "op_bindings": None, + "is_reply": False, + } + + with pytest.raises(ValueError) as exc_info: + resolve_amqp_config(params, "test_op", "test_app") + + assert "wildcard patterns" in str(exc_info.value) + assert "location=*" in str(exc_info.value) + + +def test_resolve_routing_key_binding_accepts_wildcards(): + """Routing key bindings should accept wildcard parameters.""" + # Create channel with routingKey binding + channel = create_test_channel( + address="weather.{location}.{severity}", + binding=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="weather_exchange", + type=AmqpExchangeType.TOPIC, + ), + ), + ) + + params: EndpointParams = { + "channel": channel, + "parameters": {"location": "*", "severity": "high"}, + "op_bindings": None, + "is_reply": False, + } + + # Should not raise + config = resolve_amqp_config(params, "test_op", "test_app") + + # Verify the routing key was substituted with wildcards + assert config.routing_key == "weather.*.high" + + +def test_resolve_channel_address_without_binding_rejects_wildcards(): + """Channel address without explicit binding (implicit queue) should reject wildcards.""" + # Channel without bindings defaults to queue binding + channel = create_test_channel( + address="task.{priority}", + binding=None, + ) + + params: EndpointParams = { + "channel": channel, + "parameters": {"priority": "*"}, + "op_bindings": None, + "is_reply": False, + } + + with pytest.raises(ValueError) as exc_info: + resolve_amqp_config(params, "test_op", "test_app") + + assert "wildcard patterns" in str(exc_info.value) + assert "priority=*" in str(exc_info.value) + + +def test_resolve_channel_address_without_binding_accepts_concrete_params(): + """Channel address without explicit binding should accept concrete parameters.""" + channel = create_test_channel( + address="task.{priority}", + binding=None, + ) + + params: EndpointParams = { + "channel": channel, + "parameters": {"priority": "high"}, + "op_bindings": None, + "is_reply": False, + } + + # Should not raise + config = resolve_amqp_config(params, "test_op", "test_app") + + # Verify the address was substituted with concrete value + assert config.queue_name == "task.high" + + +def test_resolve_amqp_config_rejects_missing_parameters(): + """Resolver should validate all required parameters are provided.""" + from asyncapi_python.kernel.document.channel import AddressParameter + + # Create channel with 2 required parameters + channel = create_test_channel( + address="weather.{location}.{severity}", + binding=AmqpChannelBinding(type="queue"), + ) + # Add parameter definitions to channel + channel = Channel( + key=channel.key, + address=channel.address, + title=channel.title, + summary=channel.summary, + description=channel.description, + servers=channel.servers, + messages=channel.messages, + parameters={ + "location": AddressParameter( + key="location", + description="Geographic location", + location=None, + ), + "severity": AddressParameter( + key="severity", + description="Alert severity", + location=None, + ), + }, + tags=channel.tags, + external_docs=channel.external_docs, + bindings=channel.bindings, + ) + + params: EndpointParams = { + "channel": channel, + "parameters": {"location": "NYC"}, # Missing severity + "op_bindings": None, + "is_reply": False, + } + + with pytest.raises(ValueError) as exc_info: + resolve_amqp_config(params, "test_op", "test_app") + + error_msg = str(exc_info.value) + assert "Missing required parameters" in error_msg + assert "severity" in error_msg + + +def test_resolve_amqp_config_rejects_extra_parameters(): + """Resolver should reject extra parameters not defined in channel.""" + from asyncapi_python.kernel.document.channel import AddressParameter + + # Create channel with 1 required parameter + channel = create_test_channel( + address="weather.{location}", + binding=AmqpChannelBinding(type="queue"), + ) + # Add parameter definition to channel + channel = Channel( + key=channel.key, + address=channel.address, + title=channel.title, + summary=channel.summary, + description=channel.description, + servers=channel.servers, + messages=channel.messages, + parameters={ + "location": AddressParameter( + key="location", + description="Geographic location", + location=None, + ), + }, + tags=channel.tags, + external_docs=channel.external_docs, + bindings=channel.bindings, + ) + + params: EndpointParams = { + "channel": channel, + "parameters": { + "location": "NYC", + "severity": "high", # Extra parameter + }, + "op_bindings": None, + "is_reply": False, + } + + with pytest.raises(ValueError) as exc_info: + resolve_amqp_config(params, "test_op", "test_app") + + error_msg = str(exc_info.value) + assert "Unexpected parameters" in error_msg + assert "severity" in error_msg + + +def test_resolve_amqp_config_routing_key_rejects_missing_parameters(): + """Routing key bindings should also validate all parameters are provided.""" + from asyncapi_python.kernel.document.channel import AddressParameter + + # Create channel with routingKey binding + channel = create_test_channel( + address="weather.{location}.{severity}", + binding=AmqpChannelBinding( + type="routingKey", + exchange=AmqpExchange( + name="weather_exchange", + type=AmqpExchangeType.TOPIC, + ), + ), + ) + # Add parameter definitions + channel = Channel( + key=channel.key, + address=channel.address, + title=channel.title, + summary=channel.summary, + description=channel.description, + servers=channel.servers, + messages=channel.messages, + parameters={ + "location": AddressParameter( + key="location", + description="Geographic location", + location=None, + ), + "severity": AddressParameter( + key="severity", + description="Alert severity", + location=None, + ), + }, + tags=channel.tags, + external_docs=channel.external_docs, + bindings=channel.bindings, + ) + + params: EndpointParams = { + "channel": channel, + "parameters": {"location": "NYC"}, # Missing severity + "op_bindings": None, + "is_reply": False, + } + + with pytest.raises(ValueError) as exc_info: + resolve_amqp_config(params, "test_op", "test_app") + + error_msg = str(exc_info.value) + assert "Missing required parameters" in error_msg + assert "severity" in error_msg diff --git a/tests/core/wire/test_amqp_utils.py b/tests/core/wire/test_amqp_utils.py new file mode 100644 index 0000000..d12e828 --- /dev/null +++ b/tests/core/wire/test_amqp_utils.py @@ -0,0 +1,233 @@ +"""Unit tests for universal parameter validation and substitution utilities.""" + +import pytest + +from asyncapi_python.kernel.document.channel import AddressParameter, Channel +from asyncapi_python.kernel.wire.utils import ( + substitute_parameters, + validate_parameters_strict, +) + + +def create_test_channel( + address: str, parameter_keys: list[str] | None = None +) -> Channel: + """Create a test channel with specified parameters.""" + parameters = {} + if parameter_keys: + for key in parameter_keys: + parameters[key] = AddressParameter( + key=key, + description=f"Test parameter {key}", + location=None, + ) + + return Channel( + key="test_channel", + address=address, + title="Test Channel", + summary=None, + description=None, + servers=[], + messages={}, + parameters=parameters, + tags=[], + external_docs=None, + bindings=None, + ) + + +class TestValidateParametersStrict: + """Tests for strict parameter validation.""" + + def test_accepts_exact_match(self): + """Should pass when all required parameters are provided, no extras.""" + channel = create_test_channel( + "weather.{location}.{severity}", ["location", "severity"] + ) + + # Should not raise + validate_parameters_strict(channel, {"location": "NYC", "severity": "high"}) + + def test_rejects_missing_parameters(self): + """Should raise ValueError when required parameters are missing.""" + channel = create_test_channel( + "weather.{location}.{severity}", ["location", "severity"] + ) + + with pytest.raises(ValueError) as exc_info: + validate_parameters_strict(channel, {"location": "NYC"}) + + error_msg = str(exc_info.value) + assert "Missing required parameters" in error_msg + assert "severity" in error_msg + assert "weather.{location}.{severity}" in error_msg + + def test_rejects_multiple_missing_parameters(self): + """Should raise ValueError listing all missing parameters.""" + channel = create_test_channel( + "weather.{location}.{severity}.{priority}", + ["location", "severity", "priority"], + ) + + with pytest.raises(ValueError) as exc_info: + validate_parameters_strict(channel, {"location": "NYC"}) + + error_msg = str(exc_info.value) + assert "Missing required parameters" in error_msg + assert "severity" in error_msg + assert "priority" in error_msg + + def test_rejects_extra_parameters(self): + """Should raise ValueError when unexpected parameters are provided.""" + channel = create_test_channel("weather.{location}", ["location"]) + + with pytest.raises(ValueError) as exc_info: + validate_parameters_strict(channel, {"location": "NYC", "severity": "high"}) + + error_msg = str(exc_info.value) + assert "Unexpected parameters" in error_msg + assert "severity" in error_msg + + def test_rejects_multiple_extra_parameters(self): + """Should raise ValueError listing all extra parameters.""" + channel = create_test_channel("weather.{location}", ["location"]) + + with pytest.raises(ValueError) as exc_info: + validate_parameters_strict( + channel, + {"location": "NYC", "severity": "high", "priority": "1"}, + ) + + error_msg = str(exc_info.value) + assert "Unexpected parameters" in error_msg + assert "severity" in error_msg + assert "priority" in error_msg + + def test_accepts_empty_when_none_defined(self): + """Should pass when channel has no parameters and none provided.""" + channel = create_test_channel("simple.queue", []) + + # Should not raise + validate_parameters_strict(channel, {}) + + def test_rejects_params_when_none_expected(self): + """Should raise ValueError when parameters provided but none defined.""" + channel = create_test_channel("simple.queue", []) + + with pytest.raises(ValueError) as exc_info: + validate_parameters_strict(channel, {"location": "NYC"}) + + error_msg = str(exc_info.value) + assert "Unexpected parameters" in error_msg + assert "location" in error_msg + + def test_rejects_all_missing(self): + """Should raise ValueError when all parameters are missing.""" + channel = create_test_channel( + "weather.{location}.{severity}", ["location", "severity"] + ) + + with pytest.raises(ValueError) as exc_info: + validate_parameters_strict(channel, {}) + + error_msg = str(exc_info.value) + assert "Missing required parameters" in error_msg + assert "location" in error_msg + assert "severity" in error_msg + + def test_rejects_mixed_missing_and_extra(self): + """Should raise ValueError for both missing and extra parameters.""" + channel = create_test_channel( + "weather.{location}.{severity}", ["location", "severity"] + ) + + with pytest.raises(ValueError) as exc_info: + # Missing: severity + # Extra: priority + validate_parameters_strict(channel, {"location": "NYC", "priority": "high"}) + + error_msg = str(exc_info.value) + # Should fail on missing first (that's the implementation order) + assert ( + "Missing required parameters" in error_msg + or "Unexpected parameters" in error_msg + ) + + +class TestSubstituteParameters: + """Tests for parameter substitution.""" + + def test_substitutes_single_parameter(self): + """Should substitute a single parameter correctly.""" + result = substitute_parameters("weather.{location}", {"location": "NYC"}) + assert result == "weather.NYC" + + def test_substitutes_multiple_parameters(self): + """Should substitute multiple parameters correctly.""" + result = substitute_parameters( + "weather.{location}.{severity}", + {"location": "NYC", "severity": "high"}, + ) + assert result == "weather.NYC.high" + + def test_preserves_wildcards_in_values(self): + """Should preserve wildcard characters in parameter values.""" + result = substitute_parameters( + "weather.{location}.{severity}", + {"location": "*", "severity": "high"}, + ) + assert result == "weather.*.high" + + def test_handles_no_parameters(self): + """Should return template unchanged when no parameters defined.""" + result = substitute_parameters("simple.queue", {}) + assert result == "simple.queue" + + def test_fails_on_missing_parameter(self): + """Should raise ValueError when template has placeholder without value.""" + with pytest.raises(ValueError) as exc_info: + substitute_parameters("weather.{location}.{severity}", {"location": "NYC"}) + + error_msg = str(exc_info.value) + assert "undefined parameters" in error_msg + assert "severity" in error_msg + + def test_allows_extra_parameters(self): + """Should allow extra parameters in dict (only uses what's in template).""" + result = substitute_parameters( + "weather.{location}", + {"location": "NYC", "severity": "high"}, # Extra: severity + ) + assert result == "weather.NYC" + + def test_handles_complex_patterns(self): + """Should handle complex patterns with multiple dots and parameters.""" + result = substitute_parameters( + "exchange.{env}.{region}.{service}.{version}", + { + "env": "prod", + "region": "us-east-1", + "service": "api", + "version": "v2", + }, + ) + assert result == "exchange.prod.us-east-1.api.v2" + + def test_handles_parameter_at_start(self): + """Should handle parameter at start of template.""" + result = substitute_parameters("{env}.weather.alerts", {"env": "prod"}) + assert result == "prod.weather.alerts" + + def test_handles_parameter_at_end(self): + """Should handle parameter at end of template.""" + result = substitute_parameters("weather.alerts.{env}", {"env": "prod"}) + assert result == "weather.alerts.prod" + + def test_handles_consecutive_parameters(self): + """Should handle consecutive parameters separated by dot.""" + result = substitute_parameters( + "prefix.{location}.{severity}", + {"location": "NYC", "severity": "high"}, + ) + assert result == "prefix.NYC.high" diff --git a/uv.lock b/uv.lock index bc468fa..8f420ba 100644 --- a/uv.lock +++ b/uv.lock @@ -64,7 +64,7 @@ wheels = [ [[package]] name = "asyncapi-python" -version = "0.3.0rc5" +version = "0.3.0rc6" source = { editable = "." } dependencies = [ { name = "pydantic" },