|
| 1 | +import logging |
| 2 | +import jwt |
| 3 | +from datetime import datetime, timedelta |
| 4 | +from typing import Optional, Dict, Tuple |
| 5 | +from urllib.parse import urlparse |
| 6 | + |
| 7 | +logger = logging.getLogger(__name__) |
| 8 | + |
| 9 | + |
| 10 | +def parse_hostname(hostname: str) -> str: |
| 11 | + """ |
| 12 | + Normalize the hostname to include scheme and trailing slash. |
| 13 | +
|
| 14 | + Args: |
| 15 | + hostname: The hostname to normalize |
| 16 | +
|
| 17 | + Returns: |
| 18 | + Normalized hostname with scheme and trailing slash |
| 19 | + """ |
| 20 | + if not hostname.startswith("http://") and not hostname.startswith("https://"): |
| 21 | + hostname = f"https://{hostname}" |
| 22 | + if not hostname.endswith("/"): |
| 23 | + hostname = f"{hostname}/" |
| 24 | + return hostname |
| 25 | + |
| 26 | + |
| 27 | +def decode_token(access_token: str) -> Optional[Dict]: |
| 28 | + """ |
| 29 | + Decode a JWT token without verification to extract claims. |
| 30 | +
|
| 31 | + Args: |
| 32 | + access_token: The JWT access token to decode |
| 33 | +
|
| 34 | + Returns: |
| 35 | + Decoded token claims or None if decoding fails |
| 36 | + """ |
| 37 | + try: |
| 38 | + return jwt.decode(access_token, options={"verify_signature": False}) |
| 39 | + except Exception as e: |
| 40 | + logger.debug("Failed to decode JWT token: %s", e) |
| 41 | + return None |
| 42 | + |
| 43 | + |
| 44 | +def is_same_host(url1: str, url2: str) -> bool: |
| 45 | + """ |
| 46 | + Check if two URLs have the same host. |
| 47 | +
|
| 48 | + Args: |
| 49 | + url1: First URL |
| 50 | + url2: Second URL |
| 51 | +
|
| 52 | + Returns: |
| 53 | + True if hosts are the same, False otherwise |
| 54 | + """ |
| 55 | + try: |
| 56 | + host1 = urlparse(url1).netloc |
| 57 | + host2 = urlparse(url2).netloc |
| 58 | + # Handle port differences (e.g., example.com vs example.com:443) |
| 59 | + host1_without_port = host1.split(":")[0] |
| 60 | + host2_without_port = host2.split(":")[0] |
| 61 | + return host1_without_port == host2_without_port |
| 62 | + except Exception as e: |
| 63 | + logger.debug("Failed to parse URLs: %s", e) |
| 64 | + return False |
| 65 | + |
| 66 | + |
| 67 | +class Token: |
| 68 | + """ |
| 69 | + Represents an OAuth token with expiration management. |
| 70 | + """ |
| 71 | + |
| 72 | + def __init__(self, access_token: str, token_type: str = "Bearer"): |
| 73 | + """ |
| 74 | + Initialize a token. |
| 75 | +
|
| 76 | + Args: |
| 77 | + access_token: The access token string |
| 78 | + token_type: The token type (default: Bearer) |
| 79 | + """ |
| 80 | + self.access_token = access_token |
| 81 | + self.token_type = token_type |
| 82 | + self.expiry_time = self._calculate_expiry() |
| 83 | + |
| 84 | + def _calculate_expiry(self) -> datetime: |
| 85 | + """ |
| 86 | + Calculate the token expiry time from JWT claims. |
| 87 | +
|
| 88 | + Returns: |
| 89 | + The token expiry datetime |
| 90 | + """ |
| 91 | + decoded = decode_token(self.access_token) |
| 92 | + if decoded and "exp" in decoded: |
| 93 | + # Use JWT exp claim with 1 minute buffer |
| 94 | + return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1) |
| 95 | + # Default to 1 hour if no expiry info |
| 96 | + return datetime.now() + timedelta(hours=1) |
| 97 | + |
| 98 | + def is_expired(self) -> bool: |
| 99 | + """ |
| 100 | + Check if the token is expired. |
| 101 | +
|
| 102 | + Returns: |
| 103 | + True if token is expired, False otherwise |
| 104 | + """ |
| 105 | + return datetime.now() >= self.expiry_time |
| 106 | + |
| 107 | + def to_dict(self) -> Dict[str, str]: |
| 108 | + """ |
| 109 | + Convert token to dictionary format. |
| 110 | +
|
| 111 | + Returns: |
| 112 | + Dictionary with access_token and token_type |
| 113 | + """ |
| 114 | + return { |
| 115 | + "access_token": self.access_token, |
| 116 | + "token_type": self.token_type, |
| 117 | + } |
0 commit comments