Skip to content

Commit 64e7763

Browse files
Prevent race condition during JWT obtaining (#329)
1 parent 7e44454 commit 64e7763

File tree

7 files changed

+309
-50
lines changed

7 files changed

+309
-50
lines changed

.github/workflows/pull_request.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ jobs:
8181
/bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.integration coverage run -m pytest -m v4 tests/integration -v"
8282
8383
- name: Run asyncio integration tests
84-
id: integration_tests
84+
id: asyncio_integration_tests
8585
continue-on-error: true
8686
run: |
8787
docker run --rm \
@@ -90,7 +90,7 @@ jobs:
9090
-e CONDUCTOR_SERVER_URL=${{ env.CONDUCTOR_SERVER_URL }} \
9191
-v ${{ github.workspace }}/${{ env.COVERAGE_DIR }}:/package/${{ env.COVERAGE_DIR }}:rw \
9292
conductor-sdk-test:latest \
93-
/bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.integration coverage run -m pytest -m v4 tests/integration -v"
93+
/bin/sh -c "cd /package && COVERAGE_FILE=/package/${{ env.COVERAGE_DIR }}/.coverage.asyncio_integration coverage run -m pytest -m v4 tests/integration/async -v"
9494
9595
- name: Generate coverage report
9696
id: coverage_report
@@ -124,4 +124,4 @@ jobs:
124124

125125
- name: Check test results
126126
if: steps.unit_tests.outcome == 'failure' || steps.bc_tests.outcome == 'failure' || steps.serdeser_tests.outcome == 'failure'
127-
run: exit 1
127+
run: exit 1

src/conductor/asyncio_client/adapters/api_client_adapter.py

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from __future__ import annotations
2+
3+
import asyncio
14
import json
25
import logging
36
import re
7+
import time
48
from typing import Dict, Optional
59

610
from conductor.asyncio_client.adapters.models import GenerateTokenRequest
@@ -15,6 +19,10 @@
1519

1620

1721
class ApiClientAdapter(ApiClient):
22+
def __init__(self, *args, **kwargs):
23+
self._token_lock = asyncio.Lock()
24+
super().__init__(*args, **kwargs)
25+
1826
async def call_api(
1927
self,
2028
method,
@@ -37,7 +45,9 @@ async def call_api(
3745
"""
3846

3947
try:
40-
logger.debug("HTTP request method: %s; url: %s; header_params: %s", method, url, header_params)
48+
logger.debug(
49+
"HTTP request method: %s; url: %s; header_params: %s", method, url, header_params
50+
)
4151
response_data = await self.rest_client.request(
4252
method,
4353
url,
@@ -46,9 +56,29 @@ async def call_api(
4656
post_params=post_params,
4757
_request_timeout=_request_timeout,
4858
)
49-
if response_data.status == 401 and url != self.configuration.host + "/token": # noqa: PLR2004 (Unauthorized status code)
50-
logger.warning("HTTP response from: %s; status code: 401 - obtaining new token", url)
51-
token = await self.refresh_authorization_token()
59+
if (
60+
response_data.status == 401 # noqa: PLR2004 (Unauthorized status code)
61+
and url != self.configuration.host + "/token"
62+
):
63+
logger.warning(
64+
"HTTP response from: %s; status code: 401 - obtaining new token", url
65+
)
66+
async with self._token_lock:
67+
# The lock is intentionally broad (covers the whole block including the token state)
68+
# to avoid race conditions: without it, other coroutines could mis-evaluate
69+
# token state during a context switch and trigger redundant refreshes
70+
token_expired = (
71+
self.configuration.token_update_time > 0
72+
and time.time()
73+
>= self.configuration.token_update_time
74+
+ self.configuration.auth_token_ttl_sec
75+
)
76+
invalid_token = not self.configuration._http_config.api_key.get("api_key")
77+
78+
if invalid_token or token_expired:
79+
token = await self.refresh_authorization_token()
80+
else:
81+
token = self.configuration._http_config.api_key["api_key"]
5282
header_params["X-Authorization"] = token
5383
response_data = await self.rest_client.request(
5484
method,
@@ -59,7 +89,9 @@ async def call_api(
5989
_request_timeout=_request_timeout,
6090
)
6191
except ApiException as e:
62-
logger.error("HTTP request failed url: %s status: %s; reason: %s", url, e.status, e.reason)
92+
logger.error(
93+
"HTTP request failed url: %s status: %s; reason: %s", url, e.status, e.reason
94+
)
6395
raise e
6496

6597
return response_data
@@ -82,12 +114,10 @@ def response_deserialize(
82114
if (
83115
not response_type
84116
and isinstance(response_data.status, int)
85-
and 100 <= response_data.status <= 599
117+
and 100 <= response_data.status <= 599 # noqa: PLR2004
86118
):
87119
# if not found, look for '1XX', '2XX', etc.
88-
response_type = response_types_map.get(
89-
str(response_data.status)[0] + "XX", None
90-
)
120+
response_type = response_types_map.get(str(response_data.status)[0] + "XX", None)
91121

92122
# deserialize response data
93123
response_text = None
@@ -104,12 +134,10 @@ def response_deserialize(
104134
match = re.search(r"charset=([a-zA-Z\-\d]+)[\s;]?", content_type)
105135
encoding = match.group(1) if match else "utf-8"
106136
response_text = response_data.data.decode(encoding)
107-
return_data = self.deserialize(
108-
response_text, response_type, content_type
109-
)
137+
return_data = self.deserialize(response_text, response_type, content_type)
110138
finally:
111-
if not 200 <= response_data.status <= 299:
112-
logger.error(f"Unexpected response status code: {response_data.status}")
139+
if not 200 <= response_data.status <= 299: # noqa: PLR2004
140+
logger.error("Unexpected response status code: %s", response_data.status)
113141
raise ApiException.from_response(
114142
http_resp=response_data,
115143
body=response_text,
@@ -126,8 +154,9 @@ def response_deserialize(
126154
async def refresh_authorization_token(self):
127155
obtain_new_token_response = await self.obtain_new_token()
128156
token = obtain_new_token_response.get("token")
129-
self.configuration.api_key["api_key"] = token
130-
logger.debug(f"New auth token been set")
157+
self.configuration._http_config.api_key["api_key"] = token
158+
self.configuration.token_update_time = time.time()
159+
logger.debug("New auth token been set")
131160
return token
132161

133162
async def obtain_new_token(self):

src/conductor/asyncio_client/configuration/configuration.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def __init__(
5757
auth_key: Optional[str] = None,
5858
auth_secret: Optional[str] = None,
5959
debug: bool = False,
60+
auth_token_ttl_min: int = 45,
6061
# Worker properties
6162
polling_interval: Optional[int] = None,
6263
domain: Optional[str] = None,
@@ -136,10 +137,6 @@ def __init__(
136137
if api_key is None:
137138
api_key = {}
138139

139-
if self.auth_key and self.auth_secret:
140-
# Use the auth_key as the API key for X-Authorization header
141-
api_key["api_key"] = self.auth_key
142-
143140
self.__ui_host = os.getenv("CONDUCTOR_UI_SERVER_URL")
144141
if self.__ui_host is None:
145142
self.__ui_host = self.server_url.replace("/api", "")
@@ -182,6 +179,10 @@ def __init__(
182179

183180
self.is_logger_config_applied = False
184181

182+
# Orkes Conductor auth token properties
183+
self.token_update_time = 0
184+
self.auth_token_ttl_sec = auth_token_ttl_min * 60
185+
185186
def _get_env_float(self, env_var: str, default: float) -> float:
186187
"""Get float value from environment variable with default fallback."""
187188
try:
@@ -268,9 +269,7 @@ def _convert_property_value(self, property_name: str, value: str) -> Any:
268269
# For other properties, return as string
269270
return value
270271

271-
def set_worker_property(
272-
self, task_type: str, property_name: str, value: Any
273-
) -> None:
272+
def set_worker_property(self, task_type: str, property_name: str, value: Any) -> None:
274273
"""
275274
Set worker property for a specific task type.
276275
@@ -523,7 +522,5 @@ def ui_host(self):
523522
def __getattr__(self, name: str) -> Any:
524523
"""Delegate attribute access to underlying HTTP configuration."""
525524
if "_http_config" not in self.__dict__ or self._http_config is None:
526-
raise AttributeError(
527-
f"'{self.__class__.__name__}' object has no attribute '{name}'"
528-
)
525+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
529526
return getattr(self._http_config, name)

0 commit comments

Comments
 (0)