Skip to content

Commit c21022d

Browse files
committed
WIP
1 parent 7a43217 commit c21022d

File tree

1 file changed

+61
-39
lines changed

1 file changed

+61
-39
lines changed

pulp-glue/pulp_glue/common/openapi.py

Lines changed: 61 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from urllib.parse import urlencode, urljoin
1313

1414
import aiofiles
15+
import aiofiles.os
1516
import aiohttp
1617
import requests
1718
import urllib3
@@ -199,7 +200,7 @@ def __init__(
199200
self._base_url: str = base_url
200201
self._doc_path: str = doc_path
201202
self._dry_run: bool = dry_run
202-
self._headers = CIMultiDict(headers or {})
203+
self._headers = CIMultiDict(headers)
203204
self._verify_ssl = verify_ssl
204205

205206
self._auth_provider = auth_provider
@@ -268,6 +269,9 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
268269
return _ssl_context
269270

270271
def load_api(self, refresh_cache: bool = False) -> None:
272+
asyncio.run(self._load_api(refresh_cache=refresh_cache))
273+
274+
async def _load_api(self, refresh_cache: bool = False) -> None:
271275
# TODO: Find a way to invalidate caches on upstream change
272276
xdg_cache_home: str = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
273277
apidoc_cache: str = os.path.join(
@@ -279,17 +283,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
279283
try:
280284
if refresh_cache:
281285
raise IOError()
282-
with open(apidoc_cache, "rb") as f:
283-
data: bytes = f.read()
286+
async with aiofiles.open(apidoc_cache, mode="rb") as f:
287+
data: bytes = await f.read()
284288
self._parse_api(data)
285289
except Exception:
286290
# Try again with a freshly downloaded version
287-
data = asyncio.run(self._download_api())
291+
data = await self._download_api()
288292
self._parse_api(data)
289293
# Write to cache as it seems to be valid
290-
os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
291-
with open(apidoc_cache, "bw") as f:
292-
f.write(data)
294+
await aiofiles.os.makedirs(os.path.dirname(apidoc_cache), exist_ok=True)
295+
async with aiofiles.open(apidoc_cache, mode="bw") as f:
296+
await f.write(data)
293297

294298
def _parse_api(self, data: bytes) -> None:
295299
self.api_spec: t.Dict[str, t.Any] = json.loads(data)
@@ -306,11 +310,10 @@ def _parse_api(self, data: bytes) -> None:
306310

307311
async def _download_api(self) -> bytes:
308312
try:
309-
connector = aiohttp.TCPConnector(ssl=self.ssl_context)
310-
async with aiohttp.ClientSession(
311-
connector=connector, headers=self._headers, ssl=self.ssl_context
312-
) as session:
313-
async with session.get(urljoin(self._base_url, self._doc_path)) as response:
313+
async with aiohttp.ClientSession(headers=self._headers) as session:
314+
async with session.get(
315+
urljoin(self._base_url, self._doc_path), ssl=self.ssl_context
316+
) as response:
314317
response.raise_for_status()
315318
data = await response.read()
316319
if "Correlation-Id" in response.headers:
@@ -508,7 +511,20 @@ def _render_request_body(
508511

509512
return content_type, data, files
510513

511-
def _send_request(
514+
async def _http_middleware(
515+
self, request: aiohttp.ClientRequest, handler: aiohttp.ClientHandlerType
516+
) -> aiohttp.ClientResponse:
517+
for key, value in request.headers.items():
518+
self._debug_callback(2, f" {key}: {value}")
519+
if request.body is not None:
520+
self._debug_callback(3, f"{request.body!r}")
521+
if self._dry_run and request.method.upper() not in SAFE_METHODS:
522+
raise UnsafeCallError(_("Call aborted due to safe mode"))
523+
524+
response = await handler(request)
525+
return response
526+
527+
async def _send_request(
512528
self,
513529
path_spec: t.Dict[str, t.Any],
514530
method: str,
@@ -519,6 +535,7 @@ def _send_request(
519535
validate_body: bool = True,
520536
) -> _Response:
521537
method_spec = path_spec[method]
538+
_headers = CIMultiDict(headers)
522539
content_type, data, files = self._render_request_body(method_spec, body, validate_body)
523540
security: t.Optional[t.List[t.Dict[str, t.List[str]]]] = method_spec.get(
524541
"security", self.api_spec.get("security")
@@ -536,14 +553,14 @@ def _send_request(
536553
# For we encode the json on our side.
537554
# Somehow this does not work properly for multipart...
538555
if content_type is not None and content_type.startswith("application/json"):
539-
headers["content-type"] = content_type
556+
_headers["Content-Type"] = content_type
540557
request = self._session.prepare_request(
541558
requests.Request(
542559
method,
543560
url,
544561
auth=auth,
545562
params=params,
546-
headers=headers,
563+
headers=_headers,
547564
data=data,
548565
files=files,
549566
)
@@ -552,35 +569,40 @@ def _send_request(
552569
assert request.headers["content-type"].startswith(
553570
content_type
554571
), f"{request.headers['content-type']} != {content_type}"
555-
for key, value in request.headers.items():
556-
self._debug_callback(2, f" {key}: {value}")
557-
if request.body is not None:
558-
self._debug_callback(3, f"{request.body!r}")
559-
if self._dry_run and method.upper() not in SAFE_METHODS:
560-
raise UnsafeCallError(_("Call aborted due to safe mode"))
561572
try:
562-
response = self._session.send(request)
573+
async with aiohttp.ClientSession(
574+
headers=self._headers, middlewares=(self._http_middleware,)
575+
) as session:
576+
async with session.request(
577+
method,
578+
url,
579+
params=params,
580+
headers=_headers,
581+
data=data,
582+
# files=files,
583+
ssl=self.ssl_context,
584+
) as response:
585+
response.raise_for_status()
586+
response_body = await response.read()
563587
except requests.TooManyRedirects as e:
564588
assert e.response is not None
565589
raise OpenAPIError(
566-
_("Received redirect to '{url}'. Please check your CLI configuration.").format(
590+
_("Received redirect to '{url}'. Please check your configuration.").format(
567591
url=e.response.headers["location"]
568592
)
569593
)
570594
except requests.RequestException as e:
571595
raise OpenAPIError(str(e))
572-
self._debug_callback(
573-
1, _("Response: {status_code}").format(status_code=response.status_code)
574-
)
596+
self._debug_callback(1, _("Response: {status_code}").format(status_code=response.status))
575597
for key, value in response.headers.items():
576598
self._debug_callback(2, f" {key}: {value}")
577599
if response.text:
578600
self._debug_callback(3, f"{response.text}")
579601
if "Correlation-Id" in response.headers:
580602
self._set_correlation_id(response.headers["Correlation-Id"])
581-
if response.status_code == 401:
603+
if response.status == 401:
582604
raise PulpAuthenticationFailed(method_spec["operationId"])
583-
if response.status_code == 403:
605+
if response.status == 403:
584606
raise PulpNotAutorized(method_spec["operationId"])
585607
try:
586608
response.raise_for_status()
@@ -589,9 +611,7 @@ def _send_request(
589611
raise PulpHTTPError(str(e.response.text), e.response.status_code)
590612
else:
591613
raise PulpException(str(e))
592-
return _Response(
593-
status_code=response.status_code, headers=response.headers, body=response.content
594-
)
614+
return _Response(status_code=response.status, headers=response.headers, body=response_body)
595615

596616
def _parse_response(self, method_spec: t.Dict[str, t.Any], response: _Response) -> t.Any:
597617
if response.status_code == 204:
@@ -676,14 +696,16 @@ def call(
676696
2, "\n".join([f" {key}=={value}" for key, value in query_params.items()])
677697
)
678698

679-
response = self._send_request(
680-
path_spec,
681-
method,
682-
url,
683-
query_params,
684-
headers,
685-
body,
686-
validate_body=validate_body,
699+
response = asyncio.run(
700+
self._send_request(
701+
path_spec,
702+
method,
703+
url,
704+
query_params,
705+
headers,
706+
body,
707+
validate_body=validate_body,
708+
)
687709
)
688710

689711
return self._parse_response(method_spec, response)

0 commit comments

Comments
 (0)