diff --git a/Makefile b/Makefile index 9ee4463d..f2cc2cbb 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ test-all: install .PHONY: lint lint: #! Run type analysis and linting checks lint: install + @mkdir -p .mypy_cache @poetry run mypy ldclient @poetry run isort --check --atomic ldclient contract-tests @poetry run pycodestyle ldclient contract-tests diff --git a/ldclient/impl/datasourcev2/polling.py b/ldclient/impl/datasourcev2/polling.py index a1a67702..e5415039 100644 --- a/ldclient/impl/datasourcev2/polling.py +++ b/ldclient/impl/datasourcev2/polling.py @@ -32,6 +32,8 @@ from ldclient.impl.http import _http_factory from ldclient.impl.repeating_task import RepeatingTask from ldclient.impl.util import ( + _LD_ENVID_HEADER, + _LD_FD_FALLBACK_HEADER, UnsuccessfulResponseException, _Fail, _headers, @@ -117,6 +119,13 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: while self._stop.is_set() is False: result = self._requester.fetch(ss.selector()) if isinstance(result, _Fail): + fallback = None + envid = None + + if result.headers is not None: + fallback = result.headers.get(_LD_FD_FALLBACK_HEADER) == 'true' + envid = result.headers.get(_LD_ENVID_HEADER) + if isinstance(result.exception, UnsuccessfulResponseException): error_info = DataSourceErrorInfo( kind=DataSourceErrorKind.ERROR_RESPONSE, @@ -127,28 +136,28 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: ), ) - fallback = result.exception.headers.get("X-LD-FD-Fallback") == 'true' if fallback: yield Update( state=DataSourceState.OFF, error=error_info, - revert_to_fdv1=True + revert_to_fdv1=True, + environment_id=envid, ) break status_code = result.exception.status if is_http_error_recoverable(status_code): - # TODO(fdv2): Add support for environment ID yield Update( state=DataSourceState.INTERRUPTED, error=error_info, + environment_id=envid, ) continue - # TODO(fdv2): Add support for environment ID yield Update( state=DataSourceState.OFF, error=error_info, + environment_id=envid, ) break @@ -159,19 +168,18 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: message=result.error, ) - # TODO(fdv2): Go has a designation here to handle JSON decoding separately. - # TODO(fdv2): Add support for environment ID yield Update( state=DataSourceState.INTERRUPTED, error=error_info, + environment_id=envid, ) else: (change_set, headers) = result.value yield Update( state=DataSourceState.VALID, change_set=change_set, - environment_id=headers.get("X-LD-EnvID"), - revert_to_fdv1=headers.get('X-LD-FD-Fallback') == 'true' + environment_id=headers.get(_LD_ENVID_HEADER), + revert_to_fdv1=headers.get(_LD_FD_FALLBACK_HEADER) == 'true' ) if self._event.wait(self._poll_interval): @@ -208,7 +216,7 @@ def _poll(self, ss: SelectorStore) -> BasisResult: (change_set, headers) = result.value - env_id = headers.get("X-LD-EnvID") + env_id = headers.get(_LD_ENVID_HEADER) if not isinstance(env_id, str): env_id = None @@ -273,14 +281,14 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: ), retries=1, ) + headers = response.headers if response.status >= 400: return _Fail( - f"HTTP error {response}", UnsuccessfulResponseException(response.status, response.headers) + f"HTTP error {response}", UnsuccessfulResponseException(response.status), + headers=headers, ) - headers = response.headers - if response.status == 304: return _Success(value=(ChangeSetBuilder.no_changes(), headers)) @@ -304,6 +312,7 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: return _Fail( error=changeset_result.error, exception=changeset_result.exception, + headers=headers, # type: ignore ) @@ -436,13 +445,13 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: retries=1, ) + headers = response.headers if response.status >= 400: return _Fail( - f"HTTP error {response}", UnsuccessfulResponseException(response.status, response.headers) + f"HTTP error {response}", UnsuccessfulResponseException(response.status), + headers=headers ) - headers = response.headers - if response.status == 304: return _Success(value=(ChangeSetBuilder.no_changes(), headers)) @@ -466,6 +475,7 @@ def fetch(self, selector: Optional[Selector]) -> PollingResult: return _Fail( error=changeset_result.error, exception=changeset_result.exception, + headers=headers, ) diff --git a/ldclient/impl/datasourcev2/streaming.py b/ldclient/impl/datasourcev2/streaming.py index e8637174..eab7fa8d 100644 --- a/ldclient/impl/datasourcev2/streaming.py +++ b/ldclient/impl/datasourcev2/streaming.py @@ -38,6 +38,8 @@ ) from ldclient.impl.http import HTTPFactory, _http_factory from ldclient.impl.util import ( + _LD_ENVID_HEADER, + _LD_FD_FALLBACK_HEADER, http_error_message, is_http_error_recoverable, log @@ -58,7 +60,6 @@ STREAMING_ENDPOINT = "/sdk/stream" - SseClientBuilder = Callable[[Config, SelectorStore], SSEClient] @@ -146,6 +147,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: self._running = True self._connection_attempt_start_time = time() + envid = None for action in self._sse.all: if isinstance(action, Fault): # If the SSE client detects the stream has closed, then it will @@ -154,7 +156,10 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: if action.error is None: continue - (update, should_continue) = self._handle_error(action.error) + if action.headers is not None: + envid = action.headers.get(_LD_ENVID_HEADER, envid) + + (update, should_continue) = self._handle_error(action.error, envid) if update is not None: yield update @@ -163,12 +168,15 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: continue if isinstance(action, Start) and action.headers is not None: - fallback = action.headers.get('X-LD-FD-Fallback') == 'true' + fallback = action.headers.get(_LD_FD_FALLBACK_HEADER) == 'true' + envid = action.headers.get(_LD_ENVID_HEADER, envid) + if fallback: self._record_stream_init(True) yield Update( state=DataSourceState.OFF, - revert_to_fdv1=True + revert_to_fdv1=True, + environment_id=envid, ) break @@ -176,7 +184,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: continue try: - update = self._process_message(action, change_set_builder) + update = self._process_message(action, change_set_builder, envid) if update is not None: self._record_stream_init(False) self._connection_attempt_start_time = None @@ -187,7 +195,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: ) self._sse.interrupt() - (update, should_continue) = self._handle_error(e) + (update, should_continue) = self._handle_error(e, envid) if update is not None: yield update if not should_continue: @@ -204,7 +212,7 @@ def sync(self, ss: SelectorStore) -> Generator[Update, None, None]: DataSourceErrorKind.UNKNOWN, 0, time(), str(e) ), revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) self._sse.close() @@ -226,7 +234,7 @@ def _record_stream_init(self, failed: bool): # pylint: disable=too-many-return-statements def _process_message( - self, msg: Event, change_set_builder: ChangeSetBuilder + self, msg: Event, change_set_builder: ChangeSetBuilder, envid: Optional[str] ) -> Optional[Update]: """ Processes a single message from the SSE stream and returns an Update @@ -247,7 +255,7 @@ def _process_message( change_set_builder.expect_changes() return Update( state=DataSourceState.VALID, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) return None @@ -293,13 +301,13 @@ def _process_message( return Update( state=DataSourceState.VALID, change_set=change_set, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) log.info("Unexpected event found in stream: %s", msg.event) return None - def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: + def _handle_error(self, error: Exception, envid: Optional[str]) -> Tuple[Optional[Update], bool]: """ This method handles errors that occur during the streaming process. @@ -328,7 +336,7 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: DataSourceErrorKind.INVALID_DATA, 0, time(), str(error) ), revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) return (update, True) @@ -344,11 +352,15 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: str(error), ) - if error.headers is not None and error.headers.get("X-LD-FD-Fallback") == 'true': + if envid is None and error.headers is not None: + envid = error.headers.get(_LD_ENVID_HEADER) + + if error.headers is not None and error.headers.get(_LD_FD_FALLBACK_HEADER) == 'true': update = Update( state=DataSourceState.OFF, error=error_info, - revert_to_fdv1=True + revert_to_fdv1=True, + environment_id=envid, ) return (update, False) @@ -364,7 +376,7 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: ), error=error_info, revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) if not is_recoverable: @@ -386,7 +398,7 @@ def _handle_error(self, error: Exception) -> Tuple[Optional[Update], bool]: DataSourceErrorKind.UNKNOWN, 0, time(), str(error) ), revert_to_fdv1=False, - environment_id=None, # TODO(sdk-1410) + environment_id=envid, ) # no stacktrace here because, for a typical connection error, it'll # just be a lengthy tour of urllib3 internals @@ -411,5 +423,4 @@ def __init__(self, config: Config): def build(self) -> StreamingDataSource: """Builds a StreamingDataSource instance with the configured parameters.""" - # TODO(fdv2): Add in the other controls here. return StreamingDataSource(self._config) diff --git a/ldclient/impl/datasystem/config.py b/ldclient/impl/datasystem/config.py index d3b34a7a..eadc6f0e 100644 --- a/ldclient/impl/datasystem/config.py +++ b/ldclient/impl/datasystem/config.py @@ -210,18 +210,3 @@ def persistent_store(store: FeatureStore) -> ConfigBuilder: although it will keep it up-to-date. """ return default().data_store(store, DataStoreMode.READ_WRITE) - - -# TODO(fdv2): Implement these methods -# -# WithEndpoints configures the data system with custom endpoints for -# LaunchDarkly's streaming and polling synchronizers. This method is not -# necessary for most use-cases, but can be useful for testing or custom -# network configurations. -# -# Any endpoint that is not specified (empty string) will be treated as the -# default LaunchDarkly SaaS endpoint for that service. - -# WithRelayProxyEndpoints configures the data system with a single endpoint -# for LaunchDarkly's streaming and polling synchronizers. The endpoint -# should be Relay Proxy's base URI, for example http://localhost:8123. diff --git a/ldclient/impl/util.py b/ldclient/impl/util.py index 81054f4b..54caf9de 100644 --- a/ldclient/impl/util.py +++ b/ldclient/impl/util.py @@ -4,7 +4,7 @@ import time from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Generic, Optional, TypeVar, Union +from typing import Any, Dict, Generic, Mapping, Optional, TypeVar, Union from urllib.parse import urlparse, urlunparse from ldclient.impl.http import _base_headers @@ -35,6 +35,9 @@ def timedelta_millis(delta: timedelta) -> float: # Compiled regex pattern for valid characters in application values and SDK keys _VALID_CHARACTERS_REGEX = re.compile(r"[^a-zA-Z0-9._-]") +_LD_ENVID_HEADER = 'X-LD-EnvID' +_LD_FD_FALLBACK_HEADER = 'X-LD-FD-Fallback' + def validate_application_info(application: dict, logger: logging.Logger) -> dict: return { @@ -117,23 +120,18 @@ def __str__(self, *args, **kwargs): class UnsuccessfulResponseException(Exception): - def __init__(self, status, headers={}): + def __init__(self, status): super(UnsuccessfulResponseException, self).__init__("HTTP error %d" % status) self._status = status - self._headers = headers @property def status(self): return self._status - @property - def headers(self): - return self._headers - def throw_if_unsuccessful_response(resp): if resp.status >= 400: - raise UnsuccessfulResponseException(resp.status, resp.headers) + raise UnsuccessfulResponseException(resp.status) def is_http_error_recoverable(status): @@ -290,6 +288,7 @@ class _Success(Generic[T]): class _Fail(Generic[E]): error: E exception: Optional[Exception] = None + headers: Optional[Mapping[str, Any]] = None # TODO(breaking): Replace the above Result class with an improved generic diff --git a/ldclient/integrations/test_datav2.py b/ldclient/integrations/test_datav2.py index 744264f2..a2da52db 100644 --- a/ldclient/integrations/test_datav2.py +++ b/ldclient/integrations/test_datav2.py @@ -551,17 +551,21 @@ class TestDataV2: :: from ldclient.impl.datasystem import config as datasystem_config + from ldclient.integrations.test_datav2 import TestDataV2 + td = TestDataV2.data_source() td.update(td.flag('flag-key-1').variation_for_all(True)) # Configure the data system with TestDataV2 as both initializer and synchronizer data_config = datasystem_config.custom() - data_config.initializers([lambda: td.build_initializer()]) - data_config.synchronizers(lambda: td.build_synchronizer()) + data_config.initializers([td.build_initializer]) + data_config.synchronizers(td.build_synchronizer) - # TODO(fdv2): This will be integrated with the main Config in a future version - # For now, TestDataV2 is primarily intended for unit testing scenarios + config = Config( + sdk_key, + datasystem_config=data_config.build(), + ) # flags can be updated at any time: td.update(td.flag('flag-key-1'). diff --git a/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py b/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py index 3410a1e6..7aa3686e 100644 --- a/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py +++ b/ldclient/testing/impl/datasourcev2/test_polling_synchronizer.py @@ -20,7 +20,13 @@ Selector, ServerIntent ) -from ldclient.impl.util import UnsuccessfulResponseException, _Fail, _Success +from ldclient.impl.util import ( + _LD_ENVID_HEADER, + _LD_FD_FALLBACK_HEADER, + UnsuccessfulResponseException, + _Fail, + _Success +) from ldclient.interfaces import DataSourceErrorKind, DataSourceState from ldclient.testing.mock_components import MockSelectorStore @@ -304,3 +310,169 @@ def test_unrecoverable_error_shuts_down(): assert False, "Expected StopIteration" except StopIteration: pass + + +def test_envid_from_success_headers(): + """Test that environment ID is captured from successful polling response headers""" + change_set = ChangeSetBuilder.no_changes() + headers = {_LD_ENVID_HEADER: 'test-env-polling-123'} + polling_result: PollingResult = _Success(value=(change_set, headers)) + + synchronizer = PollingDataSource( + poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) + ) + + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert valid.state == DataSourceState.VALID + assert valid.error is None + assert valid.revert_to_fdv1 is False + assert valid.environment_id == 'test-env-polling-123' + + +def test_envid_from_success_with_changeset(): + """Test that environment ID is captured from polling response with actual changes""" + builder = ChangeSetBuilder() + builder.start(intent=IntentCode.TRANSFER_FULL) + builder.add_put( + version=100, kind=ObjectKind.FLAG, key="flag-key", obj={"key": "flag-key"} + ) + change_set = builder.finish(selector=Selector(state="p:SOMETHING:300", version=300)) + headers = {_LD_ENVID_HEADER: 'test-env-456'} + polling_result: PollingResult = _Success(value=(change_set, headers)) + + synchronizer = PollingDataSource( + poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) + ) + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert valid.state == DataSourceState.VALID + assert valid.environment_id == 'test-env-456' + assert valid.change_set is not None + assert len(valid.change_set.changes) == 1 + + +def test_envid_from_fallback_headers(): + """Test that environment ID is captured when fallback header is present on success""" + change_set = ChangeSetBuilder.no_changes() + headers = { + _LD_ENVID_HEADER: 'test-env-fallback', + _LD_FD_FALLBACK_HEADER: 'true' + } + polling_result: PollingResult = _Success(value=(change_set, headers)) + + synchronizer = PollingDataSource( + poll_interval=0.01, requester=ListBasedRequester(results=iter([polling_result])) + ) + + valid = next(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert valid.state == DataSourceState.VALID + assert valid.revert_to_fdv1 is True + assert valid.environment_id == 'test-env-fallback' + + +def test_envid_from_error_headers_recoverable(): + """Test that environment ID is captured from error response headers for recoverable errors""" + builder = ChangeSetBuilder() + builder.start(intent=IntentCode.TRANSFER_FULL) + builder.add_delete(version=101, kind=ObjectKind.FLAG, key="flag-key") + change_set = builder.finish(selector=Selector(state="p:SOMETHING:300", version=300)) + headers_success = {_LD_ENVID_HEADER: 'test-env-success'} + polling_result: PollingResult = _Success(value=(change_set, headers_success)) + + headers_error = {_LD_ENVID_HEADER: 'test-env-408'} + _failure = _Fail( + error="error for test", + exception=UnsuccessfulResponseException(status=408), + headers=headers_error + ) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure, polling_result])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + interrupted = next(sync) + valid = next(sync) + + assert interrupted.state == DataSourceState.INTERRUPTED + assert interrupted.environment_id == 'test-env-408' + assert interrupted.error is not None + assert interrupted.error.status_code == 408 + + assert valid.state == DataSourceState.VALID + assert valid.environment_id == 'test-env-success' + + +def test_envid_from_error_headers_unrecoverable(): + """Test that environment ID is captured from error response headers for unrecoverable errors""" + headers_error = {_LD_ENVID_HEADER: 'test-env-401'} + _failure = _Fail( + error="error for test", + exception=UnsuccessfulResponseException(status=401), + headers=headers_error + ) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + off = next(sync) + + assert off.state == DataSourceState.OFF + assert off.environment_id == 'test-env-401' + assert off.error is not None + assert off.error.status_code == 401 + + +def test_envid_from_error_with_fallback(): + """Test that environment ID and fallback are captured from error response""" + headers_error = { + _LD_ENVID_HEADER: 'test-env-503', + _LD_FD_FALLBACK_HEADER: 'true' + } + _failure = _Fail( + error="error for test", + exception=UnsuccessfulResponseException(status=503), + headers=headers_error + ) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + off = next(sync) + + assert off.state == DataSourceState.OFF + assert off.revert_to_fdv1 is True + assert off.environment_id == 'test-env-503' + + +def test_envid_from_generic_error_with_headers(): + """Test that environment ID is captured from generic errors with headers""" + builder = ChangeSetBuilder() + builder.start(intent=IntentCode.TRANSFER_FULL) + change_set = builder.finish(selector=Selector(state="p:SOMETHING:300", version=300)) + headers_success = {} + polling_result: PollingResult = _Success(value=(change_set, headers_success)) + + headers_error = {_LD_ENVID_HEADER: 'test-env-generic'} + _failure = _Fail(error="generic error for test", headers=headers_error) + + synchronizer = PollingDataSource( + poll_interval=0.01, + requester=ListBasedRequester(results=iter([_failure, polling_result])), + ) + sync = synchronizer.sync(MockSelectorStore(Selector.no_selector())) + interrupted = next(sync) + valid = next(sync) + + assert interrupted.state == DataSourceState.INTERRUPTED + assert interrupted.environment_id == 'test-env-generic' + assert interrupted.error is not None + assert interrupted.error.kind == DataSourceErrorKind.NETWORK_ERROR + + assert valid.state == DataSourceState.VALID diff --git a/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py b/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py index 90c7037e..c581e785 100644 --- a/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py +++ b/ldclient/testing/impl/datasourcev2/test_streaming_synchronizer.py @@ -6,7 +6,7 @@ from typing import Iterable, List, Optional import pytest -from ld_eventsource.actions import Action +from ld_eventsource.actions import Action, Start from ld_eventsource.http import HTTPStatusError from ld_eventsource.sse_client import Event, Fault @@ -30,6 +30,7 @@ Selector, ServerIntent ) +from ldclient.impl.util import _LD_ENVID_HEADER, _LD_FD_FALLBACK_HEADER from ldclient.interfaces import DataSourceErrorKind, DataSourceState from ldclient.testing.mock_components import MockSelectorStore @@ -416,10 +417,12 @@ def test_invalid_json_decoding(events): # pylint: disable=redefined-outer-name def test_stops_on_unrecoverable_status_code( events, ): # pylint: disable=redefined-outer-name + error = HTTPStatusError(401) + fault = Fault(error=error) builder = list_sse_client( [ # This will generate an error but the stream should continue - Fault(error=HTTPStatusError(401)), + fault, # We send these valid combinations to ensure the stream is NOT # being processed after the 401. events[EventName.SERVER_INTENT], @@ -445,12 +448,18 @@ def test_stops_on_unrecoverable_status_code( def test_continues_on_recoverable_status_code( events, ): # pylint: disable=redefined-outer-name + error1 = HTTPStatusError(400) + fault1 = Fault(error=error1) + + error2 = HTTPStatusError(408) + fault2 = Fault(error=error2) + builder = list_sse_client( [ # This will generate an error but the stream should continue - Fault(error=HTTPStatusError(400)), + fault1, events[EventName.SERVER_INTENT], - Fault(error=HTTPStatusError(408)), + fault2, # We send these valid combinations to ensure the stream will # continue to be processed. events[EventName.SERVER_INTENT], @@ -478,3 +487,207 @@ def test_continues_on_recoverable_status_code( assert updates[2].change_set.selector.version == 300 assert updates[2].change_set.selector.state == "p:SOMETHING:300" assert updates[2].change_set.intent_code == IntentCode.TRANSFER_FULL + + +def test_envid_from_start_action(events): # pylint: disable=redefined-outer-name + """Test that environment ID is captured from Start action headers""" + start_action = Start(headers={_LD_ENVID_HEADER: 'test-env-123'}) + + builder = list_sse_client( + [ + start_action, + events[EventName.SERVER_INTENT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id == 'test-env-123' + + +def test_envid_not_cleared_from_next_start(events): # pylint: disable=redefined-outer-name + """Test that environment ID is captured from Start action headers""" + start_action_with_headers = Start(headers={_LD_ENVID_HEADER: 'test-env-123'}) + start_action_without_headers = Start() + + builder = list_sse_client( + [ + start_action_with_headers, + events[EventName.SERVER_INTENT], + events[EventName.PAYLOAD_TRANSFERRED], + start_action_without_headers, + events[EventName.SERVER_INTENT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 2 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id == 'test-env-123' + + assert updates[1].state == DataSourceState.VALID + assert updates[1].environment_id == 'test-env-123' + + +def test_envid_preserved_across_events(events): # pylint: disable=redefined-outer-name + """Test that environment ID is preserved across multiple events after being set on Start""" + start_action = Start(headers={_LD_ENVID_HEADER: 'test-env-456'}) + + builder = list_sse_client( + [ + start_action, + events[EventName.SERVER_INTENT], + events[EventName.PUT_OBJECT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id == 'test-env-456' + assert updates[0].change_set is not None + assert len(updates[0].change_set.changes) == 1 + + +def test_envid_from_fallback_header(): + """Test that environment ID is captured when fallback header is present""" + start_action = Start(headers={_LD_ENVID_HEADER: 'test-env-fallback', _LD_FD_FALLBACK_HEADER: 'true'}) + + builder = list_sse_client([start_action]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.OFF + assert updates[0].revert_to_fdv1 is True + assert updates[0].environment_id == 'test-env-fallback' + + +def test_envid_from_fault_action(): + """Test that environment ID is captured from Fault action headers""" + error = HTTPStatusError(401, headers={_LD_ENVID_HEADER: 'test-env-fault'}) + fault_action = Fault(error=error) + + builder = list_sse_client([fault_action]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.OFF + assert updates[0].environment_id == 'test-env-fault' + assert updates[0].error is not None + assert updates[0].error.status_code == 401 + + +def test_envid_not_cleared_from_next_error(): + """Test that environment ID is captured from Fault action headers""" + error_with_headers_ = HTTPStatusError(408, headers={_LD_ENVID_HEADER: 'test-env-fault'}) + error_without_headers_ = HTTPStatusError(401) + fault_action_with_headers = Fault(error=error_with_headers_) + fault_action_without_headers = Fault(error=error_without_headers_) + + builder = list_sse_client([fault_action_with_headers, fault_action_without_headers]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 2 + assert updates[0].state == DataSourceState.INTERRUPTED + assert updates[0].environment_id == 'test-env-fault' + assert updates[0].error is not None + assert updates[0].error.status_code == 408 + + assert updates[1].state == DataSourceState.OFF + assert updates[1].environment_id == 'test-env-fault' + assert updates[1].error is not None + assert updates[1].error.status_code == 401 + + +def test_envid_from_fault_with_fallback(): + """Test that environment ID and fallback are captured from Fault action""" + error = HTTPStatusError(503, headers={_LD_ENVID_HEADER: 'test-env-503', _LD_FD_FALLBACK_HEADER: 'true'}) + fault_action = Fault(error=error) + + builder = list_sse_client([fault_action]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.OFF + assert updates[0].revert_to_fdv1 is True + assert updates[0].environment_id == 'test-env-503' + + +def test_envid_from_recoverable_fault(events): # pylint: disable=redefined-outer-name + """Test that environment ID is captured from recoverable Fault and preserved in subsequent events""" + error = HTTPStatusError(400, headers={_LD_ENVID_HEADER: 'test-env-400'}) + fault_action = Fault(error=error) + + builder = list_sse_client( + [ + fault_action, + events[EventName.SERVER_INTENT], + events[EventName.PAYLOAD_TRANSFERRED], + ] + ) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 2 + # First update from the fault + assert updates[0].state == DataSourceState.INTERRUPTED + assert updates[0].environment_id == 'test-env-400' + + # Second update should preserve the envid + assert updates[1].state == DataSourceState.VALID + assert updates[1].environment_id == 'test-env-400' + + +def test_envid_missing_when_no_headers(): + """Test that environment ID is None when no headers are present""" + start_action = Start() + + server_intent = ServerIntent( + payload=Payload( + id="id", + target=300, + code=IntentCode.TRANSFER_NONE, + reason="up-to-date", + ) + ) + intent_event = Event( + event=EventName.SERVER_INTENT, + data=json.dumps(server_intent.to_dict()), + ) + + builder = list_sse_client([start_action, intent_event]) + + synchronizer = StreamingDataSource(Config(sdk_key="key")) + synchronizer._sse_client_builder = builder + updates = list(synchronizer.sync(MockSelectorStore(Selector.no_selector()))) + + assert len(updates) == 1 + assert updates[0].state == DataSourceState.VALID + assert updates[0].environment_id is None diff --git a/pyproject.toml b/pyproject.toml index 118c3336..54a5eaab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ test-filesource = ["pyyaml", "watchdog"] [tool.poetry.group.dev.dependencies] mock = ">=2.0.0" -pytest = ">=2.8" +pytest = "^8.0.0" redis = ">=2.10.5,<5.0.0" boto3 = ">=1.9.71,<2.0.0" coverage = ">=4.4"