22
33from dataclasses import dataclass , field
44from enum import Enum
5- from typing import Dict , List , Optional , Set
6-
7- import stringcase
5+ from typing import Dict , List , Optional , Set , Iterable , Generator
86
97from .properties import Property , property_from_dict , ListProperty , RefProperty , EnumProperty
10- from .responses import Response , response_from_dict
118from .reference import Reference
9+ from .responses import Response , response_from_dict
1210
1311
14- class ParameterLocation (Enum ):
12+ class ParameterLocation (str , Enum ):
1513 """ The places Parameters can be put when calling an Endpoint """
1614
1715 QUERY = "query"
1816 PATH = "path"
1917
2018
21- @dataclass
22- class Parameter :
23- """ A parameter in an Endpoint """
24-
25- location : ParameterLocation
26- property : Property
27-
28- @staticmethod
29- def from_dict (d : Dict , / ) -> Parameter :
30- """ Construct a parameter from it's OpenAPI dict form """
31- return Parameter (
32- location = ParameterLocation (d ["in" ]),
33- property = property_from_dict (name = d ["name" ], required = d ["required" ], data = d ["schema" ]),
34- )
35-
36-
37- def _import_string_from_reference (reference : Reference , prefix : str = "" ) -> str :
19+ def import_string_from_reference (reference : Reference , prefix : str = "" ) -> str :
20+ """ Create a string which is used to import a reference """
3821 return f"from { prefix } .{ reference .module_name } import { reference .class_name } "
3922
4023
@@ -52,11 +35,24 @@ def from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, EndpointCollection]
5235 endpoints_by_tag : Dict [str , EndpointCollection ] = {}
5336 for path , path_data in d .items ():
5437 for method , method_data in path_data .items ():
55- parameters : List [Parameter ] = []
38+ query_parameters : List [Property ] = []
39+ path_parameters : List [Property ] = []
5640 responses : List [Response ] = []
57- for param_dict in method_data .get ("parameters" , []):
58- parameters .append (Parameter .from_dict (param_dict ))
5941 tag = method_data .get ("tags" , ["default" ])[0 ]
42+ collection = endpoints_by_tag .setdefault (tag , EndpointCollection (tag = tag ))
43+ for param_dict in method_data .get ("parameters" , []):
44+ prop = property_from_dict (
45+ name = param_dict ["name" ], required = param_dict ["required" ], data = param_dict ["schema" ]
46+ )
47+ if isinstance (prop , (ListProperty , RefProperty , EnumProperty )) and prop .reference :
48+ collection .relative_imports .add (import_string_from_reference (prop .reference , prefix = "..models" ))
49+ if param_dict ["in" ] == ParameterLocation .QUERY :
50+ query_parameters .append (prop )
51+ elif param_dict ["in" ] == ParameterLocation .PATH :
52+ path_parameters .append (prop )
53+ else :
54+ raise ValueError (f"Don't know where to put this parameter: { param_dict } " )
55+
6056 for code , response_dict in method_data ["responses" ].items ():
6157 response = response_from_dict (status_code = int (code ), data = response_dict )
6258 responses .append (response )
@@ -69,17 +65,17 @@ def from_dict(d: Dict[str, Dict[str, Dict]], /) -> Dict[str, EndpointCollection]
6965 method = method ,
7066 description = method_data .get ("description" ),
7167 name = method_data ["operationId" ],
72- parameters = parameters ,
68+ query_parameters = query_parameters ,
69+ path_parameters = path_parameters ,
7370 responses = responses ,
7471 form_body_reference = form_body_reference ,
7572 requires_security = method_data .get ("security" ),
7673 )
7774
78- collection = endpoints_by_tag .setdefault (tag , EndpointCollection (tag = tag ))
7975 collection .endpoints .append (endpoint )
8076 if form_body_reference :
8177 collection .relative_imports .add (
82- _import_string_from_reference (form_body_reference , prefix = "..models" )
78+ import_string_from_reference (form_body_reference , prefix = "..models" )
8379 )
8480 return endpoints_by_tag
8581
@@ -94,7 +90,8 @@ class Endpoint:
9490 method : str
9591 description : Optional [str ]
9692 name : str
97- parameters : List [Parameter ]
93+ query_parameters : List [Property ]
94+ path_parameters : List [Property ]
9895 responses : List [Response ]
9996 requires_security : bool
10097 form_body_reference : Optional [Reference ]
@@ -118,7 +115,7 @@ class Schema:
118115 These will all be converted to dataclasses in the client
119116 """
120117
121- title : str
118+ reference : Reference
122119 required_properties : List [Property ]
123120 optional_properties : List [Property ]
124121 description : str
@@ -139,10 +136,10 @@ def from_dict(d: Dict, /) -> Schema:
139136 required_properties .append (p )
140137 else :
141138 optional_properties .append (p )
142- if isinstance (p , (ListProperty , RefProperty )) and p .reference :
143- relative_imports .add (_import_string_from_reference (p .reference ))
139+ if isinstance (p , (ListProperty , RefProperty , EnumProperty )) and p .reference :
140+ relative_imports .add (import_string_from_reference (p .reference ))
144141 schema = Schema (
145- title = stringcase . pascalcase (d ["title" ]),
142+ reference = Reference (d ["title" ]),
146143 required_properties = required_properties ,
147144 optional_properties = optional_properties ,
148145 relative_imports = relative_imports ,
@@ -156,7 +153,7 @@ def dict(d: Dict, /) -> Dict[str, Schema]:
156153 result = {}
157154 for data in d .values ():
158155 s = Schema .from_dict (data )
159- result [s .title ] = s
156+ result [s .reference . class_name ] = s
160157 return result
161158
162159
@@ -172,29 +169,44 @@ class OpenAPI:
172169 endpoint_collections_by_tag : Dict [str , EndpointCollection ]
173170 enums : Dict [str , EnumProperty ]
174171
172+ @staticmethod
173+ def check_enums (schemas : Iterable [Schema ], collections : Iterable [EndpointCollection ]) -> Dict [str , EnumProperty ]:
174+ enums : Dict [str , EnumProperty ] = {}
175+
176+ def _iterate_properties () -> Generator [Property ]:
177+ for schema in schemas :
178+ yield from schema .required_properties
179+ yield from schema .optional_properties
180+ for collection in collections :
181+ for endpoint in collection .endpoints :
182+ yield from endpoint .path_parameters
183+ yield from endpoint .query_parameters
184+
185+ for prop in _iterate_properties ():
186+ if not isinstance (prop , EnumProperty ):
187+ continue
188+
189+ if prop .reference .class_name in enums :
190+ # We already have an enum with this name, make sure the values match
191+ assert (
192+ prop .values == enums [prop .reference .class_name ].values
193+ ), f"Encountered conflicting enum named { prop .reference .class_name } "
194+
195+ enums [prop .reference .class_name ] = prop
196+ return enums
197+
175198 @staticmethod
176199 def from_dict (d : Dict , / ) -> OpenAPI :
177200 """ Create an OpenAPI from dict """
178201 schemas = Schema .dict (d ["components" ]["schemas" ])
179- enums : Dict [str , EnumProperty ] = {}
180- for schema in schemas .values ():
181- for prop in schema .required_properties + schema .optional_properties :
182- if not isinstance (prop , EnumProperty ):
183- continue
184- schema .relative_imports .add (f"from .{ prop .name } import { prop .class_name } " )
185- if prop .class_name in enums :
186- # We already have an enum with this name, make sure the values match
187- assert (
188- prop .values == enums [prop .class_name ].values
189- ), f"Encountered conflicting enum named { prop .class_name } "
190-
191- enums [prop .class_name ] = prop
202+ endpoint_collections_by_tag = EndpointCollection .from_dict (d ["paths" ])
203+ enums = OpenAPI .check_enums (schemas .values (), endpoint_collections_by_tag .values ())
192204
193205 return OpenAPI (
194206 title = d ["info" ]["title" ],
195207 description = d ["info" ]["description" ],
196208 version = d ["info" ]["version" ],
197- endpoint_collections_by_tag = EndpointCollection . from_dict ( d [ "paths" ]) ,
209+ endpoint_collections_by_tag = endpoint_collections_by_tag ,
198210 schemas = schemas ,
199211 security_schemes = d ["components" ]["securitySchemes" ],
200212 enums = enums ,
0 commit comments