1212from urllib .parse import urlencode , urljoin
1313
1414import aiofiles
15+ import aiofiles .os
1516import aiohttp
1617import requests
1718import urllib3
@@ -199,7 +200,7 @@ def __init__(
199200 self ._base_url : str = base_url
200201 self ._doc_path : str = doc_path
201202 self ._dry_run : bool = dry_run
202- self ._headers = CIMultiDict (headers or {} )
203+ self ._headers = CIMultiDict (headers )
203204 self ._verify_ssl = verify_ssl
204205
205206 self ._auth_provider = auth_provider
@@ -268,6 +269,9 @@ def ssl_context(self) -> t.Union[ssl.SSLContext, bool]:
268269 return _ssl_context
269270
270271 def load_api (self , refresh_cache : bool = False ) -> None :
272+ asyncio .run (self ._load_api (refresh_cache = refresh_cache ))
273+
274+ async def _load_api (self , refresh_cache : bool = False ) -> None :
271275 # TODO: Find a way to invalidate caches on upstream change
272276 xdg_cache_home : str = os .environ .get ("XDG_CACHE_HOME" ) or "~/.cache"
273277 apidoc_cache : str = os .path .join (
@@ -279,17 +283,17 @@ def load_api(self, refresh_cache: bool = False) -> None:
279283 try :
280284 if refresh_cache :
281285 raise IOError ()
282- with open (apidoc_cache , "rb" ) as f :
283- data : bytes = f .read ()
286+ async with aiofiles . open (apidoc_cache , mode = "rb" ) as f :
287+ data : bytes = await f .read ()
284288 self ._parse_api (data )
285289 except Exception :
286290 # Try again with a freshly downloaded version
287- data = asyncio . run ( self ._download_api () )
291+ data = await self ._download_api ()
288292 self ._parse_api (data )
289293 # Write to cache as it seems to be valid
290- os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
291- with open (apidoc_cache , "bw" ) as f :
292- f .write (data )
294+ await aiofiles . os .makedirs (os .path .dirname (apidoc_cache ), exist_ok = True )
295+ async with aiofiles . open (apidoc_cache , mode = "bw" ) as f :
296+ await f .write (data )
293297
294298 def _parse_api (self , data : bytes ) -> None :
295299 self .api_spec : t .Dict [str , t .Any ] = json .loads (data )
@@ -306,11 +310,10 @@ def _parse_api(self, data: bytes) -> None:
306310
307311 async def _download_api (self ) -> bytes :
308312 try :
309- connector = aiohttp .TCPConnector (ssl = self .ssl_context )
310- async with aiohttp .ClientSession (
311- connector = connector , headers = self ._headers , ssl = self .ssl_context
312- ) as session :
313- async with session .get (urljoin (self ._base_url , self ._doc_path )) as response :
313+ async with aiohttp .ClientSession (headers = self ._headers ) as session :
314+ async with session .get (
315+ urljoin (self ._base_url , self ._doc_path ), ssl = self .ssl_context
316+ ) as response :
314317 response .raise_for_status ()
315318 data = await response .read ()
316319 if "Correlation-Id" in response .headers :
@@ -508,7 +511,20 @@ def _render_request_body(
508511
509512 return content_type , data , files
510513
511- def _send_request (
514+ async def _http_middleware (
515+ self , request : aiohttp .ClientRequest , handler : aiohttp .ClientHandlerType
516+ ) -> aiohttp .ClientResponse :
517+ for key , value in request .headers .items ():
518+ self ._debug_callback (2 , f" { key } : { value } " )
519+ if request .body is not None :
520+ self ._debug_callback (3 , f"{ request .body !r} " )
521+ if self ._dry_run and request .method .upper () not in SAFE_METHODS :
522+ raise UnsafeCallError (_ ("Call aborted due to safe mode" ))
523+
524+ response = await handler (request )
525+ return response
526+
527+ async def _send_request (
512528 self ,
513529 path_spec : t .Dict [str , t .Any ],
514530 method : str ,
@@ -519,6 +535,7 @@ def _send_request(
519535 validate_body : bool = True ,
520536 ) -> _Response :
521537 method_spec = path_spec [method ]
538+ _headers = CIMultiDict (headers )
522539 content_type , data , files = self ._render_request_body (method_spec , body , validate_body )
523540 security : t .Optional [t .List [t .Dict [str , t .List [str ]]]] = method_spec .get (
524541 "security" , self .api_spec .get ("security" )
@@ -536,14 +553,14 @@ def _send_request(
536553 # For we encode the json on our side.
537554 # Somehow this does not work properly for multipart...
538555 if content_type is not None and content_type .startswith ("application/json" ):
539- headers [ "content-type " ] = content_type
556+ _headers [ "Content-Type " ] = content_type
540557 request = self ._session .prepare_request (
541558 requests .Request (
542559 method ,
543560 url ,
544561 auth = auth ,
545562 params = params ,
546- headers = headers ,
563+ headers = _headers ,
547564 data = data ,
548565 files = files ,
549566 )
@@ -552,35 +569,40 @@ def _send_request(
552569 assert request .headers ["content-type" ].startswith (
553570 content_type
554571 ), f"{ request .headers ['content-type' ]} != { content_type } "
555- for key , value in request .headers .items ():
556- self ._debug_callback (2 , f" { key } : { value } " )
557- if request .body is not None :
558- self ._debug_callback (3 , f"{ request .body !r} " )
559- if self ._dry_run and method .upper () not in SAFE_METHODS :
560- raise UnsafeCallError (_ ("Call aborted due to safe mode" ))
561572 try :
562- response = self ._session .send (request )
573+ async with aiohttp .ClientSession (
574+ headers = self ._headers , middlewares = (self ._http_middleware ,)
575+ ) as session :
576+ async with session .request (
577+ method ,
578+ url ,
579+ params = params ,
580+ headers = _headers ,
581+ data = data ,
582+ # files=files,
583+ ssl = self .ssl_context ,
584+ ) as response :
585+ response .raise_for_status ()
586+ response_body = await response .read ()
563587 except requests .TooManyRedirects as e :
564588 assert e .response is not None
565589 raise OpenAPIError (
566- _ ("Received redirect to '{url}'. Please check your CLI configuration." ).format (
590+ _ ("Received redirect to '{url}'. Please check your configuration." ).format (
567591 url = e .response .headers ["location" ]
568592 )
569593 )
570594 except requests .RequestException as e :
571595 raise OpenAPIError (str (e ))
572- self ._debug_callback (
573- 1 , _ ("Response: {status_code}" ).format (status_code = response .status_code )
574- )
596+ self ._debug_callback (1 , _ ("Response: {status_code}" ).format (status_code = response .status ))
575597 for key , value in response .headers .items ():
576598 self ._debug_callback (2 , f" { key } : { value } " )
577599 if response .text :
578600 self ._debug_callback (3 , f"{ response .text } " )
579601 if "Correlation-Id" in response .headers :
580602 self ._set_correlation_id (response .headers ["Correlation-Id" ])
581- if response .status_code == 401 :
603+ if response .status == 401 :
582604 raise PulpAuthenticationFailed (method_spec ["operationId" ])
583- if response .status_code == 403 :
605+ if response .status == 403 :
584606 raise PulpNotAutorized (method_spec ["operationId" ])
585607 try :
586608 response .raise_for_status ()
@@ -589,9 +611,7 @@ def _send_request(
589611 raise PulpHTTPError (str (e .response .text ), e .response .status_code )
590612 else :
591613 raise PulpException (str (e ))
592- return _Response (
593- status_code = response .status_code , headers = response .headers , body = response .content
594- )
614+ return _Response (status_code = response .status , headers = response .headers , body = response_body )
595615
596616 def _parse_response (self , method_spec : t .Dict [str , t .Any ], response : _Response ) -> t .Any :
597617 if response .status_code == 204 :
@@ -676,14 +696,16 @@ def call(
676696 2 , "\n " .join ([f" { key } =={ value } " for key , value in query_params .items ()])
677697 )
678698
679- response = self ._send_request (
680- path_spec ,
681- method ,
682- url ,
683- query_params ,
684- headers ,
685- body ,
686- validate_body = validate_body ,
699+ response = asyncio .run (
700+ self ._send_request (
701+ path_spec ,
702+ method ,
703+ url ,
704+ query_params ,
705+ headers ,
706+ body ,
707+ validate_body = validate_body ,
708+ )
687709 )
688710
689711 return self ._parse_response (method_spec , response )
0 commit comments