|
8 | 8 | """ |
9 | 9 |
|
10 | 10 | from __future__ import annotations |
| 11 | +import typing |
11 | 12 |
|
12 | 13 | import marshmallow |
13 | 14 | from marshmallow.utils import is_collection |
@@ -59,9 +60,37 @@ def __init__( |
59 | 60 | self.schema_name_resolver = schema_name_resolver |
60 | 61 | self.spec = spec |
61 | 62 | self.init_attribute_functions() |
| 63 | + self.init_parameter_attribute_functions() |
62 | 64 | # Schema references |
63 | 65 | self.refs: dict = {} |
64 | 66 |
|
| 67 | + def init_parameter_attribute_functions(self) -> None: |
| 68 | + self.parameter_attribute_functions = [ |
| 69 | + self.field2required, |
| 70 | + self.list2param, |
| 71 | + ] |
| 72 | + |
| 73 | + def add_parameter_attribute_function(self, func) -> None: |
| 74 | + """Method to add a field parameter function to the list of field |
| 75 | + parameter functions that will be called on a field to convert it to a |
| 76 | + field parameter. |
| 77 | +
|
| 78 | + :param func func: the field parameter function to add |
| 79 | + The attribute function will be bound to the |
| 80 | + `OpenAPIConverter <apispec.ext.marshmallow.openapi.OpenAPIConverter>` |
| 81 | + instance. |
| 82 | + It will be called for each field in a schema with |
| 83 | + `self <apispec.ext.marshmallow.openapi.OpenAPIConverter>` and a |
| 84 | + `field <marshmallow.fields.Field>` instance |
| 85 | + positional arguments and `ret <dict>` keyword argument. |
| 86 | + May mutate `ret`. |
| 87 | + User added field parameter functions will be called after all built-in |
| 88 | + field parameter functions in the order they were added. |
| 89 | + """ |
| 90 | + bound_func = func.__get__(self) |
| 91 | + setattr(self, func.__name__, bound_func) |
| 92 | + self.parameter_attribute_functions.append(bound_func) |
| 93 | + |
65 | 94 | def resolve_nested_schema(self, schema): |
66 | 95 | """Return the OpenAPI representation of a marshmallow Schema. |
67 | 96 |
|
@@ -150,34 +179,57 @@ def schema2parameters( |
150 | 179 |
|
151 | 180 | def _field2parameter( |
152 | 181 | self, field: marshmallow.fields.Field, *, name: str, location: str |
153 | | - ): |
| 182 | + ) -> dict: |
154 | 183 | """Return an OpenAPI parameter as a `dict`, given a marshmallow |
155 | 184 | :class:`Field <marshmallow.Field>`. |
156 | 185 |
|
157 | 186 | https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#parameterObject |
158 | 187 | """ |
159 | 188 | ret: dict = {"in": location, "name": name} |
160 | 189 |
|
| 190 | + prop = self.field2property(field) |
| 191 | + if self.openapi_version.major < 3: |
| 192 | + ret.update(prop) |
| 193 | + else: |
| 194 | + if "description" in prop: |
| 195 | + ret["description"] = prop.pop("description") |
| 196 | + ret["schema"] = prop |
| 197 | + |
| 198 | + for param_attr_func in self.parameter_attribute_functions: |
| 199 | + ret.update(param_attr_func(field, ret=ret)) |
| 200 | + |
| 201 | + return ret |
| 202 | + |
| 203 | + def field2required( |
| 204 | + self, field: marshmallow.fields.Field, **kwargs: typing.Any |
| 205 | + ) -> dict: |
| 206 | + """Return the dictionary of OpenAPI parameter attributes for a required field. |
| 207 | +
|
| 208 | + :param Field field: A marshmallow field. |
| 209 | + :rtype: dict |
| 210 | + """ |
| 211 | + ret = {} |
161 | 212 | partial = getattr(field.parent, "partial", False) |
162 | 213 | ret["required"] = field.required and ( |
163 | 214 | not partial |
164 | 215 | or (is_collection(partial) and field.name not in partial) # type:ignore |
165 | 216 | ) |
| 217 | + return ret |
166 | 218 |
|
167 | | - prop = self.field2property(field) |
168 | | - multiple = isinstance(field, marshmallow.fields.List) |
| 219 | + def list2param(self, field: marshmallow.fields.Field, **kwargs: typing.Any) -> dict: |
| 220 | + """Return a dictionary of parameter properties from |
| 221 | + :class:`List <marshmallow.fields.List` fields. |
169 | 222 |
|
170 | | - if self.openapi_version.major < 3: |
171 | | - if multiple: |
| 223 | + :param Field field: A marshmallow field. |
| 224 | + :rtype: dict |
| 225 | + """ |
| 226 | + ret: dict = {} |
| 227 | + if isinstance(field, marshmallow.fields.List): |
| 228 | + if self.openapi_version.major < 3: |
172 | 229 | ret["collectionFormat"] = "multi" |
173 | | - ret.update(prop) |
174 | | - else: |
175 | | - if multiple: |
| 230 | + else: |
176 | 231 | ret["explode"] = True |
177 | 232 | ret["style"] = "form" |
178 | | - if prop.get("description", None): |
179 | | - ret["description"] = prop.pop("description") |
180 | | - ret["schema"] = prop |
181 | 233 | return ret |
182 | 234 |
|
183 | 235 | def schema2jsonschema(self, schema): |
|
0 commit comments