Skip to content

Commit 83a49a5

Browse files
committed
marshmallow ext OpenAPIConverter: add add_parameter_attribute_function
1 parent 7f71135 commit 83a49a5

File tree

2 files changed

+105
-23
lines changed

2 files changed

+105
-23
lines changed

src/apispec/ext/marshmallow/openapi.py

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99

1010
from __future__ import annotations
11+
import typing
1112

1213
import marshmallow
1314
from marshmallow.utils import is_collection
@@ -59,9 +60,37 @@ def __init__(
5960
self.schema_name_resolver = schema_name_resolver
6061
self.spec = spec
6162
self.init_attribute_functions()
63+
self.init_parameter_attribute_functions()
6264
# Schema references
6365
self.refs: dict = {}
6466

67+
def init_parameter_attribute_functions(self) -> None:
68+
self.parameter_attribute_functions = [
69+
self.field2required,
70+
self.list2param,
71+
]
72+
73+
def add_parameter_attribute_function(self, func) -> None:
74+
"""Method to add a field parameter function to the list of field
75+
parameter functions that will be called on a field to convert it to a
76+
field parameter.
77+
78+
:param func func: the field parameter function to add
79+
The attribute function will be bound to the
80+
`OpenAPIConverter <apispec.ext.marshmallow.openapi.OpenAPIConverter>`
81+
instance.
82+
It will be called for each field in a schema with
83+
`self <apispec.ext.marshmallow.openapi.OpenAPIConverter>` and a
84+
`field <marshmallow.fields.Field>` instance
85+
positional arguments and `ret <dict>` keyword argument.
86+
May mutate `ret`.
87+
User added field parameter functions will be called after all built-in
88+
field parameter functions in the order they were added.
89+
"""
90+
bound_func = func.__get__(self)
91+
setattr(self, func.__name__, bound_func)
92+
self.parameter_attribute_functions.append(bound_func)
93+
6594
def resolve_nested_schema(self, schema):
6695
"""Return the OpenAPI representation of a marshmallow Schema.
6796
@@ -150,30 +179,52 @@ def schema2parameters(
150179

151180
def _field2parameter(
152181
self, field: marshmallow.fields.Field, *, name: str, location: str
153-
):
182+
) -> dict:
154183
"""Return an OpenAPI parameter as a `dict`, given a marshmallow
155184
:class:`Field <marshmallow.Field>`.
156185
157186
https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject
158187
"""
159188
ret: dict = {"in": location, "name": name}
160189

161-
partial = getattr(field.parent, "partial", False)
162-
ret["required"] = field.required and (
163-
not partial
164-
or (is_collection(partial) and field.name not in partial) # type:ignore
165-
)
166-
167190
prop = self.field2property(field)
168191
if self.openapi_version.major < 3:
169192
ret.update(prop)
170193
else:
171-
if prop.get("description", None):
194+
if "description" in prop:
172195
ret["description"] = prop.pop("description")
173196
ret["schema"] = prop
174197

175-
multiple = isinstance(field, marshmallow.fields.List)
176-
if multiple:
198+
for param_attr_func in self.parameter_attribute_functions:
199+
ret.update(param_attr_func(field, ret=ret))
200+
201+
return ret
202+
203+
def field2required(
204+
self, field: marshmallow.fields.Field, **kwargs: typing.Any
205+
) -> dict:
206+
"""Return the dictionary of OpenAPI parameter attributes for a required field.
207+
208+
:param Field field: A marshmallow field.
209+
:rtype: dict
210+
"""
211+
ret = {}
212+
partial = getattr(field.parent, "partial", False)
213+
ret["required"] = field.required and (
214+
not partial
215+
or (is_collection(partial) and field.name not in partial) # type:ignore
216+
)
217+
return ret
218+
219+
def list2param(self, field: marshmallow.fields.Field, **kwargs: typing.Any) -> dict:
220+
"""Return a dictionary of parameter properties from
221+
:class:`List <marshmallow.fields.List` fields.
222+
223+
:param Field field: A marshmallow field.
224+
:rtype: dict
225+
"""
226+
ret: dict = {}
227+
if isinstance(field, marshmallow.fields.List):
177228
if self.openapi_version.major < 3:
178229
ret["collectionFormat"] = "multi"
179230
else:

tests/test_ext_marshmallow_openapi.py

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -207,20 +207,36 @@ class NotASchema:
207207

208208

209209
class TestMarshmallowSchemaToParameters:
210-
@pytest.mark.parametrize("ListClass", [fields.List, CustomList])
211-
def test_field_multiple(self, ListClass, openapi):
212-
field = ListClass(fields.Str)
213-
res = openapi._field2parameter(field, name="field", location="query")
214-
assert res["in"] == "query"
215-
if openapi.openapi_version.major < 3:
216-
assert res["type"] == "array"
217-
assert res["items"]["type"] == "string"
218-
assert res["collectionFormat"] == "multi"
210+
def test_custom_properties_for_custom_fields(self, spec_fixture):
211+
class DelimitedList(fields.List):
212+
"""Delimited list field"""
213+
214+
def delimited_list2param(self, field, **kwargs):
215+
ret: dict = {}
216+
if isinstance(field, DelimitedList):
217+
if self.openapi_version.major < 3:
218+
ret["collectionFormat"] = "csv"
219+
else:
220+
ret["explode"] = False
221+
ret["style"] = "form"
222+
return ret
223+
224+
spec_fixture.marshmallow_plugin.converter.add_parameter_attribute_function(
225+
delimited_list2param
226+
)
227+
228+
class MySchema(Schema):
229+
delimited_list = DelimitedList(fields.Int)
230+
231+
param = spec_fixture.marshmallow_plugin.converter.schema2parameters(
232+
MySchema(), location="query"
233+
)[0]
234+
235+
if spec_fixture.openapi.openapi_version.major < 3:
236+
assert param["collectionFormat"] == "csv"
219237
else:
220-
assert res["schema"]["type"] == "array"
221-
assert res["schema"]["items"]["type"] == "string"
222-
assert res["style"] == "form"
223-
assert res["explode"] is True
238+
assert param["explode"] is False
239+
assert param["style"] == "form"
224240

225241
def test_field_required(self, openapi):
226242
field = fields.Str(required=True)
@@ -252,6 +268,21 @@ class UserSchema(Schema):
252268
param = next(p for p in res_nodump if p["name"] == "partial_field")
253269
assert param["required"] is False
254270

271+
@pytest.mark.parametrize("ListClass", [fields.List, CustomList])
272+
def test_field_list(self, ListClass, openapi):
273+
field = ListClass(fields.Str)
274+
res = openapi._field2parameter(field, name="field", location="query")
275+
assert res["in"] == "query"
276+
if openapi.openapi_version.major < 3:
277+
assert res["type"] == "array"
278+
assert res["items"]["type"] == "string"
279+
assert res["collectionFormat"] == "multi"
280+
else:
281+
assert res["schema"]["type"] == "array"
282+
assert res["schema"]["items"]["type"] == "string"
283+
assert res["style"] == "form"
284+
assert res["explode"] is True
285+
255286
# json/body is invalid for OpenAPI 3
256287
@pytest.mark.parametrize("openapi", ("2.0",), indirect=True)
257288
def test_schema_body(self, openapi):

0 commit comments

Comments
 (0)