|
| 1 | +import functools |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +from requests import Session |
| 5 | + |
| 6 | +from msgraphcore.constants import CONNECTION_TIMEOUT, REQUEST_TIMEOUT |
| 7 | +from msgraphcore.enums import APIVersion, NationalClouds |
| 8 | +from msgraphcore.middleware.abc_token_credential import TokenCredential |
| 9 | +from msgraphcore.middleware.authorization import AuthorizationHandler |
| 10 | +from msgraphcore.middleware.middleware import BaseMiddleware, MiddlewarePipeline |
| 11 | + |
| 12 | + |
| 13 | +class HTTPClientFactory: |
| 14 | + """ |
| 15 | + Constructs HTTP Client(session) instances configured with either custom or default |
| 16 | + pipeline of middleware. |
| 17 | + """ |
| 18 | + def __init__(self, session: Optional[Session], **kwargs): |
| 19 | + """Class constructor that accepts a user provided session object and kwargs |
| 20 | + to configure the request handling behaviour of the client""" |
| 21 | + self.api_version = kwargs.get("api_version", APIVersion.v1) |
| 22 | + self.endpoint = kwargs.get('cloud', NationalClouds.Global) |
| 23 | + self.timeout = kwargs.get('timeout', (CONNECTION_TIMEOUT, REQUEST_TIMEOUT)) |
| 24 | + self.base_url = self._get_base_url() |
| 25 | + if session: |
| 26 | + self.session = session |
| 27 | + else: |
| 28 | + self.session = Session() |
| 29 | + self._set_default_timeout() |
| 30 | + |
| 31 | + # should this be a class method |
| 32 | + def create_with_default_middleware(self, credential: TokenCredential, **kwargs) -> Session: |
| 33 | + """Applies the default middleware chain to the HTTP Client""" |
| 34 | + middleware = [ |
| 35 | + AuthorizationHandler(credential, **kwargs), |
| 36 | + ] |
| 37 | + self._register(middleware) |
| 38 | + return self.session |
| 39 | + |
| 40 | + def create_with_custom_middleware(self, middleware: [BaseMiddleware]) -> Session: |
| 41 | + """Applies a custom middleware chain to the HTTP Client """ |
| 42 | + if not middleware: |
| 43 | + raise ValueError("Please provide a list of custom middleware") |
| 44 | + self._register(middleware) |
| 45 | + return self.session |
| 46 | + |
| 47 | + def _get_base_url(self): |
| 48 | + """Helper method to get the base url""" |
| 49 | + return self.endpoint + '/' + self.api_version |
| 50 | + |
| 51 | + def _set_default_timeout(self): |
| 52 | + """Helper method to set a default timeout for the session |
| 53 | + Reference: https://github.com/psf/requests/issues/2011 |
| 54 | + """ |
| 55 | + self.session.request = functools.partial(self.session.request, timeout=self.timeout) |
| 56 | + |
| 57 | + def _register(self, middleware: [BaseMiddleware]) -> None: |
| 58 | + """ |
| 59 | + Helper method that constructs a middleware_pipeline with the specified middleware |
| 60 | + """ |
| 61 | + if middleware: |
| 62 | + middleware_pipeline = MiddlewarePipeline() |
| 63 | + for ware in middleware: |
| 64 | + middleware_pipeline.add_middleware(ware) |
| 65 | + |
| 66 | + self.session.mount('https://', middleware_pipeline) |
0 commit comments