Skip to content

Commit e5e8406

Browse files
authored
Merge pull request #811 from paradoxxxzero/dev
Allow openapi_version as str in marshmallow OpenAPIConverter
2 parents 986b464 + b3ecd0a commit e5e8406

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/apispec/ext/marshmallow/openapi.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
class OpenAPIConverter(FieldConverterMixin):
4141
"""Adds methods for generating OpenAPI specification from marshmallow schemas and fields.
4242
43-
:param Version openapi_version: The OpenAPI version to use.
43+
:param Version|str openapi_version: The OpenAPI version to use.
4444
Should be in the form '2.x' or '3.x.x' to comply with the OpenAPI standard.
4545
:param callable schema_name_resolver: Callable to generate the schema definition name.
4646
Receives the `Schema` class and returns the name to be used in refs within
@@ -52,11 +52,15 @@ class OpenAPIConverter(FieldConverterMixin):
5252

5353
def __init__(
5454
self,
55-
openapi_version: Version,
55+
openapi_version: Version | str,
5656
schema_name_resolver,
5757
spec: APISpec,
5858
) -> None:
59-
self.openapi_version = openapi_version
59+
self.openapi_version = (
60+
Version(openapi_version)
61+
if isinstance(openapi_version, str)
62+
else openapi_version
63+
)
6064
self.schema_name_resolver = schema_name_resolver
6165
self.spec = spec
6266
self.init_attribute_functions()

tests/test_ext_marshmallow_openapi.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from marshmallow import EXCLUDE, fields, INCLUDE, RAISE, Schema, validate
66

7-
from apispec.ext.marshmallow import MarshmallowPlugin
7+
from apispec.ext.marshmallow import MarshmallowPlugin, OpenAPIConverter
88
from apispec import exceptions, APISpec
9+
from packaging.version import Version
910

1011
from .schemas import CustomList, CustomStringField
1112
from .utils import get_schemas, build_ref, validate_spec
@@ -608,6 +609,15 @@ def test_openapi_tools_validate_v3():
608609
pytest.fail(str(error))
609610

610611

612+
def test_openapi_converter_openapi_version_types():
613+
converter_with_version = OpenAPIConverter(Version("3.1"), None, None)
614+
converter_with_str_version = OpenAPIConverter("3.1", None, None)
615+
assert (
616+
converter_with_version.openapi_version
617+
== converter_with_str_version.openapi_version
618+
)
619+
620+
611621
class TestFieldValidation:
612622
class ValidationSchema(Schema):
613623
id = fields.Int(dump_only=True)

0 commit comments

Comments
 (0)