Skip to content

Commit 29fd373

Browse files
addressing copilot suggestions
1 parent e60dcab commit 29fd373

File tree

3 files changed

+47
-8
lines changed

3 files changed

+47
-8
lines changed

aws_lambda_powertools/event_handler/openapi/exceptions.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,5 +49,3 @@ class SchemaValidationError(ValidationException):
4949

5050
class OpenAPIMergeError(Exception):
5151
"""Exception raised when there's a conflict during OpenAPI merge."""
52-
53-
pass

aws_lambda_powertools/event_handler/openapi/merge.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ def _discover_resolver_files(
110110
found_files: set[Path] = set()
111111

112112
for pat in patterns:
113-
# Add recursive prefix if needed
114-
glob_pattern = f"**/{pat}" if recursive and not pat.startswith("**/") else pat
113+
# Handle recursive flag: add **/ prefix if recursive, strip **/ if not
114+
if recursive and not pat.startswith("**/"):
115+
glob_pattern = f"**/{pat}"
116+
elif not recursive and pat.startswith("**/"):
117+
glob_pattern = pat[3:] # Strip **/ prefix
118+
else:
119+
glob_pattern = pat
115120

116121
for file_path in root.glob(glob_pattern):
117122
if file_path.is_file() and not _is_excluded(file_path, root, exclude):
@@ -289,6 +294,7 @@ def __init__(
289294
self._discovered_files: list[Path] = []
290295
self._resolver_name: str = "app"
291296
self._on_conflict = on_conflict
297+
self._cached_schema: dict[str, Any] | None = None
292298

293299
def discover(
294300
self,
@@ -339,15 +345,23 @@ def discover(
339345
return self._discovered_files
340346

341347
def add_file(self, file_path: str | Path, resolver_name: str | None = None) -> None:
342-
"""Add a specific file to be included in the merge."""
348+
"""Add a specific file to be included in the merge.
349+
350+
Note: Must be called before get_openapi_schema(). Adding files after
351+
schema generation will not affect the cached result.
352+
"""
343353
path = Path(file_path).resolve()
344354
if path not in self._discovered_files:
345355
self._discovered_files.append(path)
346356
if resolver_name:
347357
self._resolver_name = resolver_name
348358

349359
def add_schema(self, schema: dict[str, Any]) -> None:
350-
"""Add a pre-generated OpenAPI schema to be merged."""
360+
"""Add a pre-generated OpenAPI schema to be merged.
361+
362+
Note: Must be called before get_openapi_schema(). Adding schemas after
363+
schema generation will not affect the cached result.
364+
"""
351365
self._schemas.append(_model_to_dict(schema))
352366

353367
def get_openapi_schema(self) -> dict[str, Any]:
@@ -357,6 +371,8 @@ def get_openapi_schema(self) -> dict[str, Any]:
357371
Loads all discovered resolver files, extracts their OpenAPI schemas,
358372
and merges them into a single unified specification.
359373
374+
The schema is cached after the first generation for performance.
375+
360376
Returns
361377
-------
362378
dict[str, Any]
@@ -367,6 +383,9 @@ def get_openapi_schema(self) -> dict[str, Any]:
367383
OpenAPIMergeError
368384
If on_conflict="error" and duplicate path+method combinations are found.
369385
"""
386+
if self._cached_schema is not None:
387+
return self._cached_schema
388+
370389
# Load schemas from discovered files
371390
for file_path in self._discovered_files:
372391
try:
@@ -376,7 +395,8 @@ def get_openapi_schema(self) -> dict[str, Any]:
376395
except (ImportError, AttributeError, FileNotFoundError) as e: # pragma: no cover
377396
logger.warning(f"Failed to load resolver from {file_path}: {e}")
378397

379-
return self._merge_schemas()
398+
self._cached_schema = self._merge_schemas()
399+
return self._cached_schema
380400

381401
def get_openapi_json_schema(self) -> str:
382402
"""
@@ -486,7 +506,12 @@ def _handle_conflict(self, method: str, path: str, target: dict, operation: Any)
486506
target[path][method] = operation
487507

488508
def _merge_components(self, source: dict[str, Any], target: dict[str, dict[str, Any]]) -> None:
489-
"""Merge components from source into target."""
509+
"""Merge components from source into target.
510+
511+
Note: Components with the same name are silently overwritten (last wins).
512+
This is intentional as component conflicts are typically user errors
513+
(e.g., two handlers defining different 'User' schemas).
514+
"""
490515
for component_type, components in source.items():
491516
target.setdefault(component_type, {}).update(components)
492517

tests/functional/event_handler/_pydantic/test_openapi_merge.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,19 @@ def test_openapi_merge_tags_from_schema():
351351
schema = merge.get_openapi_schema()
352352
tag_names = [t["name"] for t in schema.get("tags", [])]
353353
assert "handler-tag" in tag_names
354+
355+
356+
def test_openapi_merge_schema_is_cached():
357+
# GIVEN an OpenAPIMerge with discovered files
358+
merge = OpenAPIMerge(title="Cached API", version="1.0.0")
359+
merge.discover(path=MERGE_HANDLERS_PATH, pattern="**/users_handler.py")
360+
361+
# WHEN calling get_openapi_schema multiple times
362+
schema1 = merge.get_openapi_schema()
363+
schema2 = merge.get_openapi_schema()
364+
365+
# THEN it should return the same cached object
366+
assert schema1 is schema2
367+
368+
# AND paths should not be duplicated
369+
assert len([p for p in schema1["paths"] if p == "/users"]) == 1

0 commit comments

Comments
 (0)