1+ from __future__ import annotations
2+
3+ import asyncio
14import json
25import logging
36import re
7+ import time
48from typing import Dict , Optional
59
610from conductor .asyncio_client .adapters .models import GenerateTokenRequest
1519
1620
1721class ApiClientAdapter (ApiClient ):
22+ def __init__ (self , * args , ** kwargs ):
23+ self ._token_lock = asyncio .Lock ()
24+ super ().__init__ (* args , ** kwargs )
25+
1826 async def call_api (
1927 self ,
2028 method ,
@@ -37,7 +45,9 @@ async def call_api(
3745 """
3846
3947 try :
40- logger .debug ("HTTP request method: %s; url: %s; header_params: %s" , method , url , header_params )
48+ logger .debug (
49+ "HTTP request method: %s; url: %s; header_params: %s" , method , url , header_params
50+ )
4151 response_data = await self .rest_client .request (
4252 method ,
4353 url ,
@@ -46,9 +56,29 @@ async def call_api(
4656 post_params = post_params ,
4757 _request_timeout = _request_timeout ,
4858 )
49- if response_data .status == 401 and url != self .configuration .host + "/token" : # noqa: PLR2004 (Unauthorized status code)
50- logger .warning ("HTTP response from: %s; status code: 401 - obtaining new token" , url )
51- token = await self .refresh_authorization_token ()
59+ if (
60+ response_data .status == 401 # noqa: PLR2004 (Unauthorized status code)
61+ and url != self .configuration .host + "/token"
62+ ):
63+ logger .warning (
64+ "HTTP response from: %s; status code: 401 - obtaining new token" , url
65+ )
66+ async with self ._token_lock :
67+ # The lock is intentionally broad (covers the whole block including the token state)
68+ # to avoid race conditions: without it, other coroutines could mis-evaluate
69+ # token state during a context switch and trigger redundant refreshes
70+ token_expired = (
71+ self .configuration .token_update_time > 0
72+ and time .time ()
73+ >= self .configuration .token_update_time
74+ + self .configuration .auth_token_ttl_sec
75+ )
76+ invalid_token = not self .configuration ._http_config .api_key .get ("api_key" )
77+
78+ if invalid_token or token_expired :
79+ token = await self .refresh_authorization_token ()
80+ else :
81+ token = self .configuration ._http_config .api_key ["api_key" ]
5282 header_params ["X-Authorization" ] = token
5383 response_data = await self .rest_client .request (
5484 method ,
@@ -59,7 +89,9 @@ async def call_api(
5989 _request_timeout = _request_timeout ,
6090 )
6191 except ApiException as e :
62- logger .error ("HTTP request failed url: %s status: %s; reason: %s" , url , e .status , e .reason )
92+ logger .error (
93+ "HTTP request failed url: %s status: %s; reason: %s" , url , e .status , e .reason
94+ )
6395 raise e
6496
6597 return response_data
@@ -82,12 +114,10 @@ def response_deserialize(
82114 if (
83115 not response_type
84116 and isinstance (response_data .status , int )
85- and 100 <= response_data .status <= 599
117+ and 100 <= response_data .status <= 599 # noqa: PLR2004
86118 ):
87119 # if not found, look for '1XX', '2XX', etc.
88- response_type = response_types_map .get (
89- str (response_data .status )[0 ] + "XX" , None
90- )
120+ response_type = response_types_map .get (str (response_data .status )[0 ] + "XX" , None )
91121
92122 # deserialize response data
93123 response_text = None
@@ -104,12 +134,10 @@ def response_deserialize(
104134 match = re .search (r"charset=([a-zA-Z\-\d]+)[\s;]?" , content_type )
105135 encoding = match .group (1 ) if match else "utf-8"
106136 response_text = response_data .data .decode (encoding )
107- return_data = self .deserialize (
108- response_text , response_type , content_type
109- )
137+ return_data = self .deserialize (response_text , response_type , content_type )
110138 finally :
111- if not 200 <= response_data .status <= 299 :
112- logger .error (f "Unexpected response status code: { response_data .status } " )
139+ if not 200 <= response_data .status <= 299 : # noqa: PLR2004
140+ logger .error ("Unexpected response status code: %s" , response_data .status )
113141 raise ApiException .from_response (
114142 http_resp = response_data ,
115143 body = response_text ,
@@ -126,8 +154,9 @@ def response_deserialize(
126154 async def refresh_authorization_token (self ):
127155 obtain_new_token_response = await self .obtain_new_token ()
128156 token = obtain_new_token_response .get ("token" )
129- self .configuration .api_key ["api_key" ] = token
130- logger .debug (f"New auth token been set" )
157+ self .configuration ._http_config .api_key ["api_key" ] = token
158+ self .configuration .token_update_time = time .time ()
159+ logger .debug ("New auth token been set" )
131160 return token
132161
133162 async def obtain_new_token (self ):
0 commit comments