Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import abc
import base64
import json
import logging
import time
from uuid import uuid4
from typing import TYPE_CHECKING, List, Any, Iterable, Optional, Union, Dict, cast
Expand Down Expand Up @@ -32,6 +33,8 @@

JWT_BEARER_ASSERTION = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"

_LOGGER = logging.getLogger(__name__)


class AadClientBase(abc.ABC):
_POST = ["POST"]
Expand All @@ -45,7 +48,7 @@ def __init__(
cae_cache: Optional[TokenCache] = None,
*,
additionally_allowed_tenants: Optional[List[str]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self._authority = normalize_authority(authority) if authority else get_default_authority()

Expand Down Expand Up @@ -91,14 +94,25 @@ def get_cached_access_token(self, scopes: Iterable[str], **kwargs: Any) -> Optio
)

cache = self._get_cache(**kwargs)
now = int(time.time())
for token in cache.search(
TokenCache.CredentialType.ACCESS_TOKEN,
target=list(scopes),
query={"client_id": self._client_id, "realm": tenant},
):
expires_on = int(token["expires_on"])
if expires_on > int(time.time()):
if expires_on > now:
refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None
expires_in = expires_on - now
refresh_on_msg = f", refresh in {refresh_on - now}s" if refresh_on else ""
_LOGGER.debug(
"Access token found in cache for scopes %s (tenant: %s, expires in %ss%s, cache ID: %s)",
list(scopes),
tenant,
expires_in,
refresh_on_msg,
id(cache),
)
return AccessTokenInfo(
token["secret"], expires_on, token_type=token.get("token_type", "Bearer"), refresh_on=refresh_on
)
Expand Down Expand Up @@ -301,7 +315,7 @@ def _get_on_behalf_of_request(
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
user_assertion: str,
**kwargs: Any
**kwargs: Any,
) -> HttpRequest:
data = {
"assertion": user_assertion,
Expand Down Expand Up @@ -356,7 +370,7 @@ def _get_refresh_token_on_behalf_of_request(
scopes: Iterable[str],
client_credential: Union[str, AadClientCertificate, Dict[str, Any]],
refresh_token: str,
**kwargs: Any
**kwargs: Any,
) -> HttpRequest:
data = {
"grant_type": "refresh_token",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import logging
import time
from typing import Any, Callable, Dict, Optional

Expand All @@ -17,14 +18,16 @@
from .._internal import _scopes_to_resource
from .._internal.pipeline import build_pipeline

_LOGGER = logging.getLogger(__name__)


class ManagedIdentityClientBase(abc.ABC):
def __init__(
self,
request_factory: Callable[[str, dict], HttpRequest],
client_id: Optional[str] = None,
identity_config: Optional[Dict] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
self._custom_cache = False
self._cache = kwargs.pop("_cache", None)
Expand Down Expand Up @@ -101,6 +104,15 @@ def get_cached_token(self, *scopes: str, **kwargs: Any) -> Optional[AccessTokenI
expires_on = int(token["expires_on"])
refresh_on = int(token["refresh_on"]) if "refresh_on" in token else None
if expires_on > now and (not refresh_on or refresh_on > now):
expires_in = expires_on - int(now)
refresh_on_msg = f", refresh in {refresh_on - int(now)}s" if refresh_on else ""
_LOGGER.debug(
"Access token found in cache for resource %s (expires in %ss%s, cache ID: %s)",
resource,
expires_in,
refresh_on_msg,
id(self._cache),
)
return AccessTokenInfo(
token["secret"], expires_on, token_type=token.get("token_type", "Bearer"), refresh_on=refresh_on
)
Expand Down