diff --git a/.azdo/ci-pr.yaml b/.azdo/ci-pr.yaml index 75426096..f1b1f10a 100644 --- a/.azdo/ci-pr.yaml +++ b/.azdo/ci-pr.yaml @@ -51,8 +51,9 @@ steps: python -m pip install ./dist/microsoft_agents_hosting_aiohttp*.whl python -m pip install ./dist/microsoft_agents_hosting_teams*.whl python -m pip install ./dist/microsoft_agents_storage_blob*.whl + python -m pip install ./dist/microsoft_agents_storage_cosmos*.whl displayName: 'Install wheels' - script: | pytest - displayName: 'Test with pytest' + displayName: 'Test with pytest' \ No newline at end of file diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index c9508469..93898f71 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -57,6 +57,7 @@ jobs: python -m pip install ./dist/microsoft_agents_hosting_aiohttp*.whl python -m pip install ./dist/microsoft_agents_hosting_teams*.whl python -m pip install ./dist/microsoft_agents_storage_blob*.whl + python -m pip install ./dist/microsoft_agents_storage_cosmos*.whl - name: Test with pytest run: | pytest diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/__init__.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/__init__.py index be2e6079..1d54743e 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/__init__.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/__init__.py @@ -1,5 +1,5 @@ from .store_item import StoreItem -from .storage import Storage +from .storage import Storage, AsyncStorageBase from .memory_storage import MemoryStorage -__all__ = ["StoreItem", "Storage", "MemoryStorage"] +__all__ = ["StoreItem", "Storage", "AsyncStorageBase", "MemoryStorage"] diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/error_handling.py 
b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/error_handling.py index 396c8d2f..40a62d75 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/error_handling.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/error_handling.py @@ -7,8 +7,13 @@ async def ignore_error(promise: Awaitable, ignore_error_filter: error_filter): """ Ignores errors based on the provided filter function. + promise: the awaitable to execute ignore_error_filter: a function that takes an Exception and returns True if the error should be + ignored, False otherwise. + + Returns the result of the promise if successful, or None if the error is ignored. + Raises the error if it is not ignored. """ try: return await promise @@ -21,6 +26,9 @@ async def ignore_error(promise: Awaitable, ignore_error_filter: error_filter): def is_status_code_error(*ignored_codes: list[int]) -> error_filter: """ Creates an error filter function that ignores errors with specific status codes. + + ignored_codes: a list of status codes to ignore + Returns a function that takes an Exception and returns True if the error's status code is in ignored_codes. """ def func(err: Exception) -> bool: diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage.py index e9b0fcdd..4a71d939 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage.py @@ -40,7 +40,10 @@ async def delete(self, keys: list[str]) -> None: class AsyncStorageBase(Storage): - """Base class for asynchronous storage implementations.""" + """Base class for asynchronous storage implementations with operations + that work on single items. 
The bulk operations are implemented in terms + of the single-item operations. + """ async def initialize(self) -> None: """Initializes the storage container""" diff --git a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py index 9f805dae..e095cbd6 100644 --- a/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py +++ b/libraries/microsoft-agents-hosting-core/microsoft/agents/hosting/core/storage/storage_test_utils.py @@ -92,22 +92,6 @@ def subsets(lst, n=-1): return subsets -class StorageMock(ABC): - """A mock wrapper around a Storage implementation to be used in tests.""" - - def get_backing_store(self) -> Storage: - raise NotImplementedError("Subclasses must implement this") - - async def read(self, *args, **kwargs): - return await self.get_backing_store().read(*args, **kwargs) - - async def write(self, *args, **kwargs): - return await self.get_backing_store().write(*args, **kwargs) - - async def delete(self, *args, **kwargs): - return await self.get_backing_store().delete(*args, **kwargs) - - # bootstrapping class to compare against # if this class is correct, then the tests are correct class StorageBaseline(Storage): @@ -133,10 +117,14 @@ def delete(self, keys: list[str]) -> None: async def equals(self, other) -> bool: """ - Compare the items for all keys seenby this mock instance. + Compare the items for all keys seen by this mock instance. + + Note: This is an extra safety measure, and I've made the executive decision to not test this method itself - as it is not the main focus of the test suite. + because passing tests with calls to this method + is also dependent on the correctness of other + aspects, based on the other assertions in the tests. 
""" for key in self._key_history: if key not in self._memory: @@ -155,6 +143,7 @@ async def equals(self, other) -> bool: class StorageTestsCommon(ABC): + """Common fixtures for Storage implementations.""" KEY_LIST = [ "f", @@ -211,8 +200,16 @@ def changes(self, request): class CRUDStorageTests(StorageTestsCommon): + """Tests for Storage implementations that support CRUD operations. + + To use, subclass and implement the `storage` method. + """ - async def storage(self, initial_data=None, existing=False): + async def storage(self, initial_data=None, existing=False) -> Storage: + """Return a Storage instance to be tested. + :param initial_data: The initial data to populate the storage with. + :param existing: If True, the storage instance should connect to an existing store. + """ raise NotImplementedError("Subclasses must implement this") @pytest.mark.asyncio @@ -446,9 +443,64 @@ async def test_flow(self): await storage.read(["key_b"], target_cls=MockStoreItemB) assert await baseline_storage.equals(storage) - if not isinstance(storage.get_backing_store(), MemoryStorage): + if not isinstance(storage, MemoryStorage): # if not memory storage, then items should persist del storage gc.collect() storage_alt = await self.storage(existing=True) assert await baseline_storage.equals(storage_alt) + + +class QuickCRUDStorageTests(CRUDStorageTests): + """Reduced set of permutations for quicker tests. 
Useful for debugging.""" + + KEY_LIST = ["\\?/#\t\n\r*", "test.txt"] + + READ_KEY_LIST = KEY_LIST + ["nonexistent_key"] + + STATE_LIST = [ + {key: MockStoreItem({"id": key, "value": f"value{key}"}) for key in KEY_LIST} + ] + + @pytest.fixture(params=STATE_LIST) + def initial_state(self, request): + return request.param + + @pytest.fixture(params=KEY_LIST) + def key(self, request): + return request.param + + @pytest.fixture(params=[KEY_LIST]) + def keys(self, request): + return request.param + + @pytest.fixture(params=subsets(KEY_LIST, 2)) + def changes(self, request): + changes_obj = {} + keys = request.param + changes_obj["new_key"] = MockStoreItemB( + {"field": "new_value_for_new_key"}, True + ) + for i, key in enumerate(keys): + if i % 2 == 0: + changes_obj[key] = MockStoreItemB( + {"data": f"value{key}"}, (i // 2) % 2 == 0 + ) + else: + changes_obj[key] = MockStoreItem( + {"id": key, "value": f"new_value_for_{key}"} + ) + changes_obj["new_key_2"] = MockStoreItem({"field": "new_value_for_new_key_2"}) + return changes_obj + + +def debug_print(*args): + """Print debug information clearly separated in the console.""" + print("\n" * 2) + print("--- DEBUG ---") + for arg in args: + print("\n" * 2) + print(arg) + print("\n" * 2) + print("--- ----- ---") + print("\n" * 2) diff --git a/libraries/microsoft-agents-hosting-core/tests/test_memory_storage.py b/libraries/microsoft-agents-hosting-core/tests/test_memory_storage.py index 11a1e6ee..421af72a 100644 --- a/libraries/microsoft-agents-hosting-core/tests/test_memory_storage.py +++ b/libraries/microsoft-agents-hosting-core/tests/test_memory_storage.py @@ -1,25 +1,11 @@ from microsoft.agents.hosting.core.storage.memory_storage import MemoryStorage -from microsoft.agents.hosting.core.storage.storage_test_utils import ( - CRUDStorageTests, - StorageMock, -) +from microsoft.agents.hosting.core.storage.storage_test_utils import CRUDStorageTests -class MemoryStorageMock(StorageMock): - - def __init__(self, initial_data: dict 
= None): - +class TestMemoryStorage(CRUDStorageTests): + async def storage(self, initial_data=None): data = { key: value.store_item_to_json() for key, value in (initial_data or {}).items() } - self.storage = MemoryStorage(data) - - def get_backing_store(self): - return self.storage - - -class TestMemoryStorage(CRUDStorageTests): - - async def storage(self, initial_state=None): - return MemoryStorageMock(initial_state) + return MemoryStorage(data) diff --git a/libraries/microsoft-agents-storage-blob/tests/test_blob_storage.py b/libraries/microsoft-agents-storage-blob/tests/test_blob_storage.py index 4e48d5e8..4a80b768 100644 --- a/libraries/microsoft-agents-storage-blob/tests/test_blob_storage.py +++ b/libraries/microsoft-agents-storage-blob/tests/test_blob_storage.py @@ -12,7 +12,6 @@ from microsoft.agents.hosting.core.storage.storage_test_utils import ( CRUDStorageTests, - StorageMock, StorageBaseline, MockStoreItem, MockStoreItemB, @@ -69,15 +68,6 @@ async def blob_storage(): await container_client.delete_container() -class BlobStorageMock(StorageMock): - - def __init__(self, blob_storage): - self.storage = blob_storage - - def get_backing_store(self): - return self.storage - - @pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.") class TestBlobStorage(CRUDStorageTests): @@ -90,7 +80,7 @@ async def storage(self, initial_data=None, existing=False): value_rep = json.dumps(value.store_item_to_json()) await container_client.upload_blob(name=key, data=value_rep, overwrite=True) - return BlobStorageMock(storage) + return storage @pytest.mark.asyncio async def test_initialize(self, blob_storage): @@ -104,6 +94,26 @@ async def test_initialize(self, blob_storage): "key": MockStoreItem({"id": "item", "value": "data"}) } + @pytest.mark.asyncio + async def test_external_change_is_visible(self): + blob_storage, container_client = await blob_storage_instance() + assert (await blob_storage.read(["key"], target_cls=MockStoreItem)) == {} + assert (await 
blob_storage.read(["key2"], target_cls=MockStoreItem)) == {} + await container_client.upload_blob( + name="key", data=json.dumps({"id": "item", "value": "data"}), overwrite=True + ) + await container_client.upload_blob( + name="key2", + data=json.dumps({"id": "another_item", "value": "new_val"}), + overwrite=True, + ) + assert (await blob_storage.read(["key"], target_cls=MockStoreItem))[ + "key" + ] == MockStoreItem({"id": "item", "value": "data"}) + assert (await blob_storage.read(["key2"], target_cls=MockStoreItem))[ + "key2" + ] == MockStoreItem({"id": "another_item", "value": "new_val"}) + @pytest.mark.asyncio async def test_blob_storage_flow_existing_container_and_persistence(self): @@ -183,5 +193,4 @@ async def test_blob_storage_flow_existing_container_and_persistence(self): == initial_data["1230"] ) - # teardown await container_client.delete_container() diff --git a/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/__init__.py b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/__init__.py new file mode 100644 index 00000000..00d54854 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/__init__.py @@ -0,0 +1,7 @@ +from .cosmos_db_storage import CosmosDBStorage +from .cosmos_db_storage_config import CosmosDBStorageConfig + +__all__ = [ + "CosmosDBStorage", + "CosmosDBStorageConfig", +] diff --git a/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/cosmos_db_storage.py b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/cosmos_db_storage.py new file mode 100644 index 00000000..ea589c62 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/cosmos_db_storage.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +from typing import TypeVar, Union +from threading import Lock + +from azure.cosmos import ( + documents, + http_constants, + CosmosDict, +) +from azure.cosmos.aio import ( + ContainerProxy, + CosmosClient, + DatabaseProxy, +) +import azure.cosmos.exceptions as cosmos_exceptions +from azure.cosmos.partition_key import NonePartitionKeyValue + +from microsoft.agents.hosting.core.storage import AsyncStorageBase, StoreItem +from microsoft.agents.hosting.core.storage._type_aliases import JSON +from microsoft.agents.hosting.core.storage.error_handling import ignore_error + +from .cosmos_db_storage_config import CosmosDBStorageConfig +from .key_ops import sanitize_key + +StoreItemT = TypeVar("StoreItemT", bound=StoreItem) + +cosmos_resource_not_found = lambda err: isinstance( + err, cosmos_exceptions.CosmosResourceNotFoundError +) + + +class CosmosDBStorage(AsyncStorageBase): + """A CosmosDB based storage provider using partitioning""" + + def __init__(self, config: CosmosDBStorageConfig): + """Create the storage object. + + :param config: + """ + super().__init__() + + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + self._config: CosmosDBStorageConfig = config + self._client: CosmosClient = self._create_client() + self._database: DatabaseProxy = None + self._container: ContainerProxy = None + self._compatability_mode_partition_key: bool = False + # Lock used for synchronizing container creation + self._lock: Lock = Lock() + + def _create_client(self) -> CosmosClient: + if self._config.url: + if not self._config.credential: + raise ValueError( + "CosmosDBStorage: Credential is required when using a custom service URL." + ) + return CosmosClient( + account_url=self._config.url, credential=self._config.credential + ) + + connection_policy = self._config.cosmos_client_options.get( + "connection_policy", documents.ConnectionPolicy() + ) + + # kwargs 'connection_verify' is to handle CosmosClient overwriting the + # ConnectionPolicy.DisableSSLVerification value. 
+ return CosmosClient( + self._config.cosmos_db_endpoint, + self._config.auth_key, + consistency_level=self._config.cosmos_client_options.get( + "consistency_level", None + ), + **{ + "connection_policy": connection_policy, + "connection_verify": not connection_policy.DisableSSLVerification, + }, + ) + + def _sanitize(self, key: str) -> str: + return sanitize_key( + key, self._config.key_suffix, self._config.compatibility_mode + ) + + async def _read_item( + self, key: str, *, target_cls: StoreItemT = None, **kwargs + ) -> tuple[Union[str, None], Union[StoreItemT, None]]: + + if key == "": + raise ValueError("CosmosDBStorage: Key cannot be empty.") + + escaped_key: str = self._sanitize(key) + read_item_response: CosmosDict = await ignore_error( + self._container.read_item( + escaped_key, self._get_partition_key(escaped_key) + ), + cosmos_resource_not_found, + ) + if read_item_response is None: + return None, None + + doc: JSON = read_item_response.get("document") + return read_item_response["realId"], target_cls.from_json_to_store_item(doc) + + async def _write_item(self, key: str, item: StoreItem) -> None: + if key == "": + raise ValueError("CosmosDBStorage: Key cannot be empty.") + + escaped_key: str = self._sanitize(key) + + doc = { + "id": escaped_key, + "realId": key, # to retrieve the raw key later + "document": item.store_item_to_json(), + } + await self._container.upsert_item(body=doc) + + async def _delete_item(self, key: str) -> None: + if key == "": + raise ValueError("CosmosDBStorage: Key cannot be empty.") + + escaped_key: str = self._sanitize(key) + + await ignore_error( + self._container.delete_item( + escaped_key, self._get_partition_key(escaped_key) + ), + cosmos_resource_not_found, + ) + + async def _create_container(self) -> None: + partition_key = { + "paths": ["/id"], + "kind": documents.PartitionKind.Hash, + } + try: + self._container = await self._database.create_container( + self._config.container_id, + partition_key, + 
offer_throughput=self._config.container_throughput, + ) + except cosmos_exceptions.CosmosHttpResponseError as err: + if err.status_code == http_constants.StatusCodes.CONFLICT: + self._container = self._database.get_container_client( + self._config.container_id + ) + properties = await self._container.read() + # if "partitionKey" not in properties: + # self._compatability_mode_partition_key = True + # else: + # containers created had no partition key, so the default was "/_partitionKey" + paths = properties["partitionKey"]["paths"] + if "/_partitionKey" in paths: + self._compatability_mode_partition_key = True + elif "/id" not in paths: + raise Exception( + f"Custom Partition Key Paths are not supported. {self._config.container_id} " + "has a custom Partition Key Path of {paths[0]}." + ) + else: + raise err + + async def initialize(self) -> None: + if not self._container: + with self._lock: + # in case another thread attempted to initialize just before acquiring the lock + if self._container: + return + + if not self._database: + self._database = await self._client.create_database_if_not_exists( + self._config.database_id + ) + + await self._create_container() + + def _get_partition_key(self, key: str): + return NonePartitionKeyValue if self._compatability_mode_partition_key else key diff --git a/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/cosmos_db_storage_config.py b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/cosmos_db_storage_config.py new file mode 100644 index 00000000..e70ec138 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/cosmos_db_storage_config.py @@ -0,0 +1,94 @@ +import json +from typing import Union + +from azure.core.credentials import TokenCredential + +from .key_ops import sanitize_key + + +class CosmosDBStorageConfig: + """The class for partitioned CosmosDB configuration for the Azure Bot Framework.""" + + def __init__( + self, + 
cosmos_db_endpoint: str = "", + auth_key: str = "", + database_id: str = "", + container_id: str = "", + cosmos_client_options: dict = None, + container_throughput: int = 0, + key_suffix: str = "", + compatibility_mode: bool = False, + url: str = "", + credential: Union[TokenCredential, None] = None, + **kwargs, + ): + """Create the Config object. + + :param cosmos_db_endpoint: The CosmosDB endpoint. + :param auth_key: The authentication key for Cosmos DB. + :param database_id: The database identifier for Cosmos DB instance. + :param container_id: The container identifier. + :param cosmos_client_options: The options for the CosmosClient. Currently only supports connection_policy and + consistency_level + :param container_throughput: The throughput set when creating the Container. Defaults to 400. + :param key_suffix: The suffix to be added to every key. The keySuffix must contain only valid ComosDb + key characters. (e.g. not: '\\', '?', '/', '#', '*') + :param compatibility_mode: True if keys should be truncated in order to support previous CosmosDb + max key length of 255. + :param url: The URL to the CosmosDB resource. + :param credential: The TokenCredential to use for authentication. 
+ :return CosmosDBConfig: + """ + config_file: str = kwargs.get("filename", "") + if config_file: + kwargs = json.load(open(config_file)) + self.cosmos_db_endpoint: str = cosmos_db_endpoint or kwargs.get( + "cosmos_db_endpoint", "" + ) + self.auth_key: str = auth_key or kwargs.get("auth_key", "") + self.database_id: str = database_id or kwargs.get("database_id", "") + self.container_id: str = container_id or kwargs.get("container_id", "") + self.cosmos_client_options: dict = cosmos_client_options or kwargs.get( + "cosmos_client_options", {} + ) + self.container_throughput: int = container_throughput or kwargs.get( + "container_throughput", 400 + ) + self.key_suffix: str = key_suffix or kwargs.get("key_suffix", "") + self.compatibility_mode: bool = compatibility_mode or kwargs.get( + "compatibility_mode", False + ) + self.url = url or kwargs.get("url", "") + self.credential: Union[TokenCredential, None] = credential + + @staticmethod + def validate_cosmos_db_config(config: "CosmosDBStorageConfig") -> None: + """Validate the CosmosDBConfig object. + + This is used prior to the creation of the CosmosDBStorage object.""" + if not config: + raise ValueError("CosmosDBStorage: CosmosDBConfig is required.") + if not config.cosmos_db_endpoint: + raise ValueError("CosmosDBStorage: cosmos_db_endpoint is required.") + if not config.auth_key: + raise ValueError("CosmosDBStorage: auth_key is required.") + if not config.database_id: + raise ValueError("CosmosDBStorage: database_id is required.") + if not config.container_id: + raise ValueError("CosmosDBStorage: container_id is required.") + + CosmosDBStorageConfig._validate_suffix(config) + + @staticmethod + def _validate_suffix(config: "CosmosDBStorageConfig") -> None: + if config.key_suffix: + if config.compatibility_mode: + raise ValueError( + "compatibilityMode cannot be true while using a keySuffix." 
+ ) + suffix_escaped: str = sanitize_key(config.key_suffix) + if suffix_escaped != config.key_suffix: + raise ValueError( + f"Cannot use invalid Row Key characters: {config.key_suffix} in keySuffix." + ) diff --git a/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/key_ops.py b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/key_ops.py new file mode 100644 index 00000000..8e91a598 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/microsoft/agents/storage/cosmos/key_ops.py @@ -0,0 +1,45 @@ +from hashlib import sha256 + + +def sanitize_key( + key: str, key_suffix: str = "", compatibility_mode: bool = True +) -> str: + """Return the sanitized key. + + Replace characters that are not allowed in keys in Cosmos. + + :param key: The provided key to be escaped. + :param key_suffix: The string to add a the end of all RowKeys. + :param compatibility_mode: True if keys should be truncated in order to support previous CosmosDb + max key length of 255. This behavior can be overridden by setting + cosmosdb_config.compatibility_mode to False. + :return str: + """ + # forbidden characters + bad_chars: list[str] = ["\\", "?", "/", "#", "\t", "\n", "\r", "*"] + + # replace those with with '*' and the + # Unicode code point of the character and return the new string + key = "".join(map(lambda x: "*" + str(ord(x)) if x in bad_chars else x, key)) + return truncate_key(f"{key}{key_suffix}", compatibility_mode) + + +def truncate_key(key: str, compatibility_mode: bool = True) -> str: + """ + Truncate the key to 255 characters if compatibility_mode is True. If the key is longer than 255 characters, + it will be truncated and a SHA-256 hash of the original key will be appended to minimize collisions. 
+ """ + max_key_len: int = 255 + + if not compatibility_mode: + return key + + if len(key) > max_key_len: + # for now (and the foreseeable future), SHA-256 collisions are pretty infentesimally rare: + # https://stackoverflow.com/questions/4014090/is-it-safe-to-ignore-the-possibility-of-sha-collisions-in-practice + aux_hash = sha256(key.encode("utf-8")) + aux_hex = aux_hash.hexdigest() + + key = key[0 : max_key_len - len(aux_hex)] + aux_hex + + return key diff --git a/libraries/microsoft-agents-storage-cosmos/pyproject.toml b/libraries/microsoft-agents-storage-cosmos/pyproject.toml new file mode 100644 index 00000000..611d6e46 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/pyproject.toml @@ -0,0 +1,23 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "microsoft-agents-storage-cosmos" +version = "0.0.0a1" +description = "A Cosmos DB storage library for Microsoft Agents" +authors = [{name = "Microsoft Corporation"}] +requires-python = ">=3.9" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [ + "microsoft-agents-hosting-core", + "azure-core", + "azure-cosmos", +] + +[project.urls] +"Homepage" = "https://github.com/microsoft/Agents" diff --git a/libraries/microsoft-agents-storage-cosmos/tests/test_cosmos_db_config.py b/libraries/microsoft-agents-storage-cosmos/tests/test_cosmos_db_config.py new file mode 100644 index 00000000..02e47b30 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/tests/test_cosmos_db_config.py @@ -0,0 +1,245 @@ +import json +import pytest + +from microsoft.agents.storage.cosmos import CosmosDBStorageConfig + +# thank you AI, again + + +@pytest.fixture() +def valid_config(): + """Fixture providing a valid CosmosDBStorageConfig for tests""" + return CosmosDBStorageConfig( + cosmos_db_endpoint="https://localhost:8081", + auth_key=( + 
"C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGG" + "yPMbIZnqyMsEcaGQy67XIw/Jw==" + ), + database_id="test-db", + container_id="bot-storage", + ) + + +@pytest.fixture() +def minimal_config(): + """Fixture providing a minimal CosmosDBStorageConfig for tests""" + return CosmosDBStorageConfig() + + +@pytest.fixture() +def config_with_options(): + """Fixture providing a CosmosDBStorageConfig with all options for tests""" + return CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + container_id="test_container", + cosmos_client_options={"connection_policy": "test"}, + container_throughput=800, + key_suffix="_test", + compatibility_mode=False, + ) + + +class TestCosmosDBStorageConfig: + + def test_constructor_with_parameters(self): + """Test creating config with direct parameters""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + container_id="test_container", + container_throughput=800, + key_suffix="_test", + compatibility_mode=False, + ) + + assert config.cosmos_db_endpoint == "https://test.documents.azure.com:443/" + assert config.auth_key == "test_key" + assert config.database_id == "test_db" + assert config.container_id == "test_container" + assert config.container_throughput == 800 + assert config.key_suffix == "_test" + assert config.compatibility_mode is False + assert config.cosmos_client_options == {} + assert config.credential is None + + def test_constructor_with_defaults(self): + """Test creating config with default values""" + config = CosmosDBStorageConfig() + + assert config.cosmos_db_endpoint == "" + assert config.auth_key == "" + assert config.database_id == "" + assert config.container_id == "" + assert config.container_throughput == 400 # Default value + assert config.key_suffix == "" + assert config.compatibility_mode is False + assert 
config.cosmos_client_options == {} + assert config.credential is None + + def test_from_file(self, tmp_path): + """Test creating config from JSON file""" + config_file_path = tmp_path / "cosmos_config.json" + + config_data = { + "cosmos_db_endpoint": "https://localhost:8081", + "auth_key": "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==", + "database_id": "test-db", + "container_id": "bot-storage", + "container_throughput": 600, + "key_suffix": "_file", + "compatibility_mode": True, + "cosmos_client_options": {"connection_policy": "test"}, + } + + with open(config_file_path, "w") as f: + json.dump(config_data, f) + + config = CosmosDBStorageConfig(filename=str(config_file_path)) + + assert config.cosmos_db_endpoint == "https://localhost:8081" + assert ( + config.auth_key + == "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==" + ) + assert config.database_id == "test-db" + assert config.container_id == "bot-storage" + assert config.container_throughput == 600 + assert config.key_suffix == "_file" + assert config.compatibility_mode is True + assert config.cosmos_client_options == {"connection_policy": "test"} + + def test_parameter_override_file(self, tmp_path): + """Test that constructor parameters override file values""" + config_file_path = tmp_path / "cosmos_config.json" + + with open(config_file_path, "w") as f: + json.dump( + { + "cosmos_db_endpoint": "https://file-endpoint.com", + "auth_key": "file_key", + "database_id": "file_db", + }, + f, + ) + + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://param-endpoint.com", + auth_key="param_key", + filename=str(config_file_path), + ) + + # Parameters should override file values + assert config.cosmos_db_endpoint == "https://param-endpoint.com" + assert config.auth_key == "param_key" + # File value should be used when parameter not provided + assert config.database_id == "file_db" + + def test_validation_success(self): + 
"""Test successful validation with all required fields""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + container_id="test_container", + ) + + # Should not raise any exception + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_missing_config(self): + """Test validation with None config""" + with pytest.raises(ValueError): + CosmosDBStorageConfig.validate_cosmos_db_config(None) + + def test_validation_missing_endpoint(self): + """Test validation with missing cosmos_db_endpoint""" + config = CosmosDBStorageConfig( + auth_key="test_key", database_id="test_db", container_id="test_container" + ) + with pytest.raises(ValueError): + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_missing_auth_key(self): + """Test validation with missing auth_key""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + database_id="test_db", + container_id="test_container", + ) + with pytest.raises(ValueError): + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_missing_database_id(self): + """Test validation with missing database_id""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + container_id="test_container", + ) + with pytest.raises(ValueError): + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_missing_container_id(self): + """Test validation with missing container_id""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + ) + with pytest.raises(ValueError): + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_suffix_with_compatibility_mode(self): + """Test validation fails when using suffix with compatibility mode""" + config = 
CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + container_id="test_container", + key_suffix="_test", + compatibility_mode=True, + ) + with pytest.raises(ValueError): + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_invalid_suffix_characters(self): + """Test validation fails with invalid characters in suffix""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + container_id="test_container", + key_suffix="invalid/suffix\\with?bad#chars", + compatibility_mode=False, + ) + with pytest.raises(ValueError, match="Cannot use invalid Row Key characters"): + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_validation_valid_suffix(self): + """Test validation succeeds with valid suffix""" + config = CosmosDBStorageConfig( + cosmos_db_endpoint="https://test.documents.azure.com:443/", + auth_key="test_key", + database_id="test_db", + container_id="test_container", + key_suffix="valid_suffix_123", + compatibility_mode=False, + ) + # Should not raise any exception + CosmosDBStorageConfig.validate_cosmos_db_config(config) + + def test_cosmos_client_options(self): + """Test cosmos_client_options handling""" + options = {"connection_policy": "test", "consistency_level": "strong"} + config = CosmosDBStorageConfig(cosmos_client_options=options) + assert config.cosmos_client_options == options + + def test_credential_parameter(self): + """Test credential parameter handling""" + # Mock credential (in real usage this would be a TokenCredential instance) + mock_credential = object() # Placeholder for actual TokenCredential + config = CosmosDBStorageConfig(credential=mock_credential) + assert config.credential is mock_credential diff --git a/libraries/microsoft-agents-storage-cosmos/tests/test_cosmos_db_storage.py 
# (original diff header fragment, kept verbatim) b/libraries/microsoft-agents-storage-cosmos/tests/test_cosmos_db_storage.py new file mode 100644 index 00000000..cd528db1 --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/tests/test_cosmos_db_storage.py @@ -0,0 +1,299 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Tests for CosmosDBStorage.

Every test in this module is skipped unless EMULATOR_RUNNING is set to True;
the configuration built by create_config() targets https://localhost:8081.
"""

import gc

import pytest
import pytest_asyncio

from azure.cosmos import documents
from azure.cosmos.aio import CosmosClient
from azure.cosmos.exceptions import CosmosResourceNotFoundError

from microsoft.agents.storage.cosmos import CosmosDBStorage, CosmosDBStorageConfig
from microsoft.agents.storage.cosmos.key_ops import sanitize_key

from microsoft.agents.hosting.core.storage.storage_test_utils import (
    QuickCRUDStorageTests,
    MockStoreItem,
    MockStoreItemB,
    StorageBaseline,
)

# Flip to True locally to run the tests below against a running emulator.
EMULATOR_RUNNING = False


def create_config(compat_mode):
    """Return the CosmosDBStorageConfig shared by all tests in this module.

    compat_mode: value forwarded as the config's compatibility_mode.
    """
    return CosmosDBStorageConfig(
        cosmos_db_endpoint="https://localhost:8081",
        auth_key=(
            "C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGG"
            "yPMbIZnqyMsEcaGQy67XIw/Jw=="
        ),
        database_id="test-db",
        container_id="bot-storage",
        compatibility_mode=compat_mode,
        container_throughput=800,
    )


@pytest.fixture
def config():
    """Fresh non-compatibility-mode config that individual tests may mutate."""
    return create_config(compat_mode=False)


async def create_cosmos_env(
    config, compat_mode=False, existing=False, partition_key_path=None
):
    """Create the Cosmos DB environment for testing and return a container client.

    config: the CosmosDBStorageConfig naming the endpoint, key, database and
        container to use.
    compat_mode: selects the default partition key path of a newly created
        container ("/_partitionKey" when True, "/id" otherwise).
    existing: if False, any database with the configured id is deleted and a
        new database and container are created. If True, the database is
        created only if it does not already exist and a client for the
        configured container is returned WITHOUT creating the container.
    partition_key_path: optional explicit partition key path for the newly
        created container; overrides the compat_mode default. Ignored when
        existing is True.
    """
    # NOTE(review): the client is never closed here; the returned container
    # client is used by callers afterwards -- confirm whether cleanup is needed.
    cosmos_client = CosmosClient(
        config.cosmos_db_endpoint,
        config.auth_key,
    )

    if existing:
        database = await cosmos_client.create_database_if_not_exists(
            id=config.database_id
        )
        return database.get_container_client(config.container_id)

    # Only "does not exist yet" is an expected outcome of the cleanup calls;
    # any other failure means the environment is broken and must surface.
    try:
        await cosmos_client.delete_database(config.database_id)
    except CosmosResourceNotFoundError:
        pass
    database = await cosmos_client.create_database(id=config.database_id)

    try:
        await database.delete_container(config.container_id)
    except CosmosResourceNotFoundError:
        pass

    if partition_key_path is None:
        partition_key_path = "/_partitionKey" if compat_mode else "/id"
    partition_key = {
        "paths": [partition_key_path],
        "kind": documents.PartitionKind.Hash,
    }
    return await database.create_container(
        id=config.container_id,
        partition_key=partition_key,
        offer_throughput=config.container_throughput,
    )


async def cosmos_db_storage_instance(compat_mode=False, existing=False):
    """Set up the environment and return (CosmosDBStorage, container client)."""
    config = create_config(compat_mode)
    container_client = await create_cosmos_env(
        config, compat_mode=compat_mode, existing=existing
    )
    storage = CosmosDBStorage(config)
    return storage, container_client


@pytest_asyncio.fixture()
async def cosmos_db_storage():
    """CosmosDBStorage over a freshly created, non-compat environment."""
    storage, _ = await cosmos_db_storage_instance()
    return storage


@pytest.mark.asyncio
@pytest.mark.parametrize("test_require_compat", [True, False])
@pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.")
async def test_cosmos_db_storage_flow_existing_container_and_persistence(
    test_require_compat,
):
    """Seed a container directly, then check that CosmosDBStorage reads,
    writes and deletes consistently with an in-memory baseline and that the
    changes are visible through the raw container client."""

    config = create_config(compat_mode=test_require_compat)
    # NOTE(review): the container is created with create_cosmos_env's default
    # compat_mode (False, i.e. partition key path "/id") even when
    # test_require_compat is True -- confirm this is intended.
    container_client = await create_cosmos_env(config)

    initial_data = {
        "__some_key": MockStoreItem({"id": "item2", "value": "data2"}),
        "?test": MockStoreItem({"id": "?test", "value": "data1"}),
        "!another_key": MockStoreItem({"id": "item3", "value": "data3"}),
        "1230": MockStoreItemB({"id": "item8", "value": "data"}, False),
        "key-with-dash": MockStoreItem({"id": "item4", "value": "data"}),
        "key.with.dot": MockStoreItem({"id": "item5", "value": "data"}),
        "key/with/slash": MockStoreItem({"id": "item6", "value": "data"}),
        "another key": MockStoreItemB({"id": "item7", "value": "data"}, True),
    }

    baseline_storage = StorageBaseline(initial_data)

    # Write the documents through the raw client, bypassing CosmosDBStorage.
    for key, value in initial_data.items():
        doc = {
            "id": sanitize_key(
                key,
                config.key_suffix,
                test_require_compat,
            ),
            "realId": key,
            "document": value.store_item_to_json(),
        }
        await container_client.upsert_item(body=doc)

    storage = CosmosDBStorage(config)
    assert await baseline_storage.equals(storage)
    assert (
        await storage.read(["1230", "another key"], target_cls=MockStoreItemB)
    ) == baseline_storage.read(["1230", "another key"])

    changes = {
        "?test": MockStoreItem({"id": "?test", "value": "data1_changed"}),
        "__some_key": MockStoreItem({"id": "item2", "value": "data2_changed"}),
        "new_item": MockStoreItem({"id": "new_item", "value": "new_data"}),
    }

    baseline_storage.write(changes)
    await storage.write(changes)

    baseline_storage.delete(["!another_key", "?test"])
    await storage.delete(["!another_key", "?test"])
    assert await baseline_storage.equals(storage)

    # Drop the storage object and build a new one over the same config.
    del storage
    gc.collect()
    storage = CosmosDBStorage(config)

    # The deleted item must be gone from the container itself.
    escaped_key = storage._sanitize("?test")
    with pytest.raises(CosmosResourceNotFoundError):
        await container_client.read_item(
            escaped_key, storage._get_partition_key(escaped_key)
        )

    # An untouched item must still be stored with its original content.
    escaped_key = storage._sanitize("1230")
    item = (
        await container_client.read_item(
            escaped_key, storage._get_partition_key(escaped_key)
        )
    ).get("document")
    assert MockStoreItemB.from_json_to_store_item(item) == initial_data["1230"]


@pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.")
class TestCosmosDBStorage(QuickCRUDStorageTests):
    """QuickCRUDStorageTests run against CosmosDBStorage (non-compat mode)."""

    def get_compat_mode(self):
        """Compatibility mode used by storage(); overridden by subclasses."""
        return False

    async def storage(self, initial_data=None, existing=False):
        """Build a storage for this class's compat mode, optionally writing
        initial_data into it first."""
        storage, _ = await cosmos_db_storage_instance(
            compat_mode=self.get_compat_mode(), existing=existing
        )
        if initial_data:
            await storage.write(initial_data)
        return storage

    @pytest.mark.asyncio
    async def test_initialize(self, cosmos_db_storage):
        """Calling initialize() repeatedly, before and after a write, must not
        disturb stored data."""
        await cosmos_db_storage.initialize()
        await cosmos_db_storage.initialize()
        await cosmos_db_storage.write(
            {"some_Key": MockStoreItem({"id": "123", "data": "value"})}
        )
        await cosmos_db_storage.initialize()
        assert (
            await cosmos_db_storage.read(["some_Key"], target_cls=MockStoreItem)
        ) == {"some_Key": MockStoreItem({"id": "123", "data": "value"})}

    @pytest.mark.asyncio
    async def test_external_change_is_visible(self):
        """Items upserted through the raw container client must be returned by
        subsequent storage reads."""
        cosmos_storage, container_client = await cosmos_db_storage_instance()
        assert (await cosmos_storage.read(["key"], target_cls=MockStoreItem)) == {}
        assert (await cosmos_storage.read(["key2"], target_cls=MockStoreItem)) == {}
        await container_client.upsert_item(
            {
                "id": "key",
                "realId": "key",
                "document": {"id": "key", "value": "data"},
                "partitionKey": "",
            }
        )
        await container_client.upsert_item(
            {
                "id": "key2",
                "realId": "key2",
                "document": {"id": "key2", "value": "new_val"},
                "partitionKey": "",
            }
        )
        assert (await cosmos_storage.read(["key"], target_cls=MockStoreItem))[
            "key"
        ] == MockStoreItem({"id": "key", "value": "data"})
        assert (await cosmos_storage.read(["key2"], target_cls=MockStoreItem))[
            "key2"
        ] == MockStoreItem({"id": "key2", "value": "new_val"})


@pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.")
class TestCosmosDBStorageWithCompat(TestCosmosDBStorage):
    """Same suite as TestCosmosDBStorage with compatibility mode enabled in
    storage()."""

    def get_compat_mode(self):
        return True


@pytest.mark.skipif(not EMULATOR_RUNNING, reason="Needs the emulator to run.")
class TestCosmosDBStorageInit:
    """Construction and initialization failures of CosmosDBStorage."""

    def test_raises_error_when_no_endpoint_provided(self, config):
        config.cosmos_db_endpoint = None
        with pytest.raises(ValueError):
            CosmosDBStorage(config)

    def test_raises_error_when_no_auth_key_provided(self, config):
        config.auth_key = None
        with pytest.raises(ValueError):
            CosmosDBStorage(config)

    def test_raises_error_when_suffix_provided_but_compat(self, config):
        # Every required field stays valid so that the suffix/compatibility
        # conflict is the only possible cause of the ValueError.
        config.key_suffix = "_test"
        config.compatibility_mode = True
        with pytest.raises(ValueError):
            CosmosDBStorage(config)

    def test_raises_error_when_no_database_id_provided(self, config):
        config.database_id = None
        with pytest.raises(ValueError):
            CosmosDBStorage(config)

    def test_raises_error_when_no_container_id_provided(self, config):
        config.container_id = None
        with pytest.raises(ValueError):
            CosmosDBStorage(config)

    @pytest.mark.asyncio
    @pytest.mark.parametrize("compat_mode", [True, False])
    async def test_raises_error_different_partition_key(self, compat_mode):
        """initialize() must fail when the existing container uses a partition
        key path the storage does not expect."""
        config = create_config(compat_mode=compat_mode)
        # Environment setup stays outside pytest.raises so that a setup
        # failure cannot be mistaken for the error under test.
        await create_cosmos_env(
            config, compat_mode=compat_mode, partition_key_path="/fake_part_key"
        )
        storage = CosmosDBStorage(config)

        with pytest.raises(Exception):
            await storage.initialize()

# (original diff header fragment, kept verbatim) diff --git a/libraries/microsoft-agents-storage-cosmos/tests/test_key_ops.py b/libraries/microsoft-agents-storage-cosmos/tests/test_key_ops.py new file mode 100644 index 00000000..e554370f --- /dev/null +++ b/libraries/microsoft-agents-storage-cosmos/tests/test_key_ops.py @@ -0,0
# (original diff header fragment, kept verbatim) +1,250 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Unit tests for the key_ops helpers sanitize_key and truncate_key."""

import hashlib
import pytest
from microsoft.agents.storage.cosmos.key_ops import truncate_key, sanitize_key

# Maximum key length these tests expect after truncation.
MAX_KEY_LENGTH = 255


@pytest.mark.parametrize(
    "input_key,expected",
    [
        ("validKey123", "validKey123"),
        ("simple", "simple"),
        ("CamelCase", "CamelCase"),
        ("under_score", "under_score"),
        ("with-dash", "with-dash"),
        ("with.dot", "with.dot"),
    ],
)
def test_sanitize_key_simple(input_key, expected):
    """Keys without forbidden characters are returned unchanged."""
    assert sanitize_key(input_key) == expected


@pytest.mark.parametrize(
    "input_key,expected",
    [
        ("key\\value", "key*92value"),
        ("key?value", "key*63value"),
        ("key/value", "key*47value"),
        ("key#value", "key*35value"),
        ("key\tvalue", "key*9value"),
        ("key\nvalue", "key*10value"),
        ("key\rvalue", "key*13value"),
        ("key*value", "key*42value"),
    ],
)
def test_sanitize_key_forbidden_chars(input_key, expected):
    """Each forbidden character is replaced by '*' plus its code point."""
    assert sanitize_key(input_key) == expected


@pytest.mark.parametrize(
    "input_key,expected",
    [
        ("key/with\\many?bad#chars", "key*47with*92many*63bad*35chars"),
        ("a\\b/c?d#e\tf\ng\rh*i", "a*92b*47c*63d*35e*9f*10g*13h*42i"),
    ],
)
def test_sanitize_key_multiple_forbidden_chars(input_key, expected):
    """Several forbidden characters in one key are all replaced."""
    assert sanitize_key(input_key) == expected


def test_sanitize_key_with_long_key_with_forbidden_chars():
    """A long key with forbidden characters is both escaped and shortened."""
    long_key = "a?2/!@\t3." * 100  # Create a long key
    sanitized = sanitize_key(long_key)
    assert len(sanitized) <= MAX_KEY_LENGTH  # Should be truncated
    # Ensure forbidden characters are replaced
    assert "?" not in sanitized
    assert "/" not in sanitized
    assert "\t" not in sanitized


def test_sanitize_key_with_long_key_with_forbidden_chars_with_suffix():
    """Same as above, with a suffix that itself holds forbidden characters."""
    long_key = "a?2/!@\t3." * 100  # Create a long key
    sanitized = sanitize_key(long_key, key_suffix="_suff?#*")
    assert len(sanitized) <= MAX_KEY_LENGTH  # Should be truncated
    # Ensure forbidden characters are replaced
    assert "?" not in sanitized
    assert "/" not in sanitized
    assert "#" not in sanitized


def test_sanitize_key_with_long_key_with_forbidden_chars_with_suffix_compat_mode():
    """Same as above, with compatibility mode requested explicitly."""
    long_key = "a?2/!@\t3." * 100  # Create a long key
    sanitized = sanitize_key(long_key, key_suffix="_suff?#*", compatibility_mode=True)
    assert len(sanitized) <= MAX_KEY_LENGTH  # Should be truncated
    # Ensure forbidden characters are replaced
    assert "?" not in sanitized
    assert "/" not in sanitized
    assert "#" not in sanitized


@pytest.mark.parametrize(
    "input_key,expected",
    [
        ("", ""),
        (" ", " "),
    ],
)
def test_sanitize_key_empty_and_whitespace(input_key, expected):
    """Empty and whitespace-only keys pass through unchanged."""
    assert sanitize_key(input_key) == expected


@pytest.mark.parametrize(
    "input_key,suffix,expected",
    [
        ("key", "_suffix", "key_suffix"),
        ("test", "123", "test123"),
        ("key/value", "_clean", "key*47value_clean"),
        ("", "_suffix", "_suffix"),
    ],
)
def test_sanitize_key_with_suffix(input_key, suffix, expected):
    """The suffix is appended after the (escaped) key."""
    assert sanitize_key(input_key, key_suffix=suffix) == expected


def test_sanitize_key_suffix_with_truncation():
    """Key plus suffix exceeding the limit still yields a bounded result."""
    long_key = "a" * 250
    suffix = "_suffix"
    result = sanitize_key(long_key, key_suffix=suffix, compatibility_mode=True)
    assert len(result) <= MAX_KEY_LENGTH
    assert (
        result.endswith(suffix) or len(result) == MAX_KEY_LENGTH
    )  # Either has suffix or was truncated


def test_sanitize_key_truncation_compatibility_mode():
    """In compatibility mode over-long keys are cut to the maximum length."""
    long_key = "a" * 300
    result = sanitize_key(long_key, compatibility_mode=True)
    assert len(result) <= MAX_KEY_LENGTH

    # A key far over the limit ends up at exactly the maximum length.
    very_long_key = "b" * 500
    result2 = sanitize_key(very_long_key, compatibility_mode=True)
    assert len(result2) == MAX_KEY_LENGTH


def test_sanitize_key_no_truncation():
    """Without compatibility mode a long key is not shortened."""
    long_key = "a" * 300
    result = sanitize_key(long_key, compatibility_mode=False)
    assert result == long_key  # Should be unchanged
    assert len(result) == 300


@pytest.mark.parametrize(
    "input_key,expected",
    [
        ("short", "short"),
        ("a" * 254, "a" * 254),
        ("a" * 255, "a" * 255),
    ],
)
def test_truncate_key_short_strings(input_key, expected):
    """Keys up to and including the limit are returned unchanged."""
    assert truncate_key(input_key) == expected


def test_truncate_key_long_strings():
    """An over-long key becomes <prefix of the key><SHA-256 hex of the key>."""
    long_key = "a" * 300
    result = truncate_key(long_key)
    assert len(result) == MAX_KEY_LENGTH

    # Result should end with SHA256 hash
    expected_hash = hashlib.sha256(long_key.encode("utf-8")).hexdigest()
    assert result.endswith(expected_hash)

    # First part should be original key truncated
    expected_prefix_len = MAX_KEY_LENGTH - len(expected_hash)
    assert result.startswith("a" * expected_prefix_len)


@pytest.mark.parametrize(
    "input_key",
    [
        "a" * 300,
        "x" * 1000,
        "key/with\\special?chars#and\ttabs\nand\rmore*",
    ],
)
def test_truncate_key_compatibility_mode_disabled(input_key):
    """With compatibility mode off, truncate_key returns its input unchanged."""
    assert truncate_key(input_key, compatibility_mode=False) == input_key


@pytest.mark.parametrize(
    "input_key,expected_length",
    [
        ("a" * 255, 255),
        ("a" * 256, 255),
    ],
)
def test_truncate_key_exact_and_over_limit(input_key, expected_length):
    """A key at the limit is untouched; one character more triggers truncation."""
    result = truncate_key(input_key)
    assert len(result) == expected_length

    if len(input_key) == MAX_KEY_LENGTH:
        assert result == input_key
    else:
        assert result != input_key


def test_truncate_key_hash_consistency():
    """Truncating the same key twice gives the same result."""
    long_key = "consistent_test_key_" * 20  # > 255 chars
    result1 = truncate_key(long_key)
    result2 = truncate_key(long_key)
    assert result1 == result2
    assert len(result1) == MAX_KEY_LENGTH


@pytest.mark.parametrize(
    "key1,key2",
    [
        ("a" * 300, "b" * 300),
        ("consistent_test_key_" * 20, "different_test_key_" * 20),
    ],
)
def test_truncate_key_different_inputs_different_outputs(key1, key2):
    """Different over-long keys truncate to different results of equal length."""
    result1 = truncate_key(key1)
    result2 = truncate_key(key2)
    assert result1 != result2
    assert len(result1) == len(result2) == MAX_KEY_LENGTH


def test_sanitize_key_integration():
    """Escaping, suffixing and truncation working together."""
    # Key with forbidden chars that will be long after sanitization + suffix
    base_key = "test/key\\with?many#forbidden\tchars\nand\rmore*" * 10
    suffix = "_integration_test"

    result = sanitize_key(base_key, key_suffix=suffix, compatibility_mode=True)

    # Should be sanitized and truncated
    assert len(result) <= MAX_KEY_LENGTH
    assert "*47" in result or "*92" in result  # Contains sanitized chars

    # Test without truncation
    result_no_trunc = sanitize_key(
        base_key, key_suffix=suffix, compatibility_mode=False
    )
    assert (
        "*47" in result_no_trunc or "*92" in result_no_trunc
    )  # Contains sanitized chars
    assert result_no_trunc.endswith(suffix)


@pytest.mark.parametrize(
    "input_key,expected",
    [
        ("key_ñ_测试", "key_ñ_测试"),
        ("123456789", "123456789"),
        ("MyKey/WithSlash", "MyKey*47WithSlash"),
    ],
)
def test_edge_cases(input_key, expected):
    """Non-ASCII, digit-only and mixed-case keys."""
    result = sanitize_key(input_key)
    assert result == expected