From 5df96281a9d07b576c929e39770e0b3f942e36d4 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Thu, 30 Oct 2025 18:29:34 +0000 Subject: [PATCH 01/35] sdks/python: replace the deprecated testcontainer max tries --- .../ml/rag/enrichment/milvus_search_it_test.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py index 2df9af2f1144..5094d9076e93 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py @@ -53,7 +53,6 @@ MilvusClient, RRFRanker) from pymilvus.milvus_client import IndexParams - from testcontainers.core.config import MAX_TRIES as TC_MAX_TRIES from testcontainers.core.config import testcontainers_config from testcontainers.core.generic import DbContainer from testcontainers.milvus import MilvusContainer @@ -306,13 +305,15 @@ def start_db_container( image="milvusdb/milvus:v2.5.10", max_vec_fields=5, vector_client_max_retries=3, - tc_max_retries=TC_MAX_TRIES) -> Optional[MilvusDBContainerInfo]: + tc_max_retries=None) -> Optional[MilvusDBContainerInfo]: service_container_port = MilvusEnrichmentTestHelper.find_free_port() healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port() user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml with user_yaml_creator(service_container_port, max_vec_fields) as cfg: info = None - testcontainers_config.max_tries = tc_max_retries + original_tc_max_tries = testcontainers_config.max_tries + if not testcontainers_config.max_tries: + testcontainers_config.max_tries = tc_max_retries for i in range(vector_client_max_retries): try: vector_db_container = CustomMilvusContainer( @@ -325,7 +326,7 @@ def start_db_container( host = vector_db_container.get_container_host_ip() port = vector_db_container.get_exposed_port(service_container_port) info = MilvusDBContainerInfo(vector_db_container, host, port) - testcontainers_config.max_tries = TC_MAX_TRIES + testcontainers_config.max_tries = original_tc_max_tries _LOGGER.info( "milvus db container started successfully on %s.", info.uri) break From 91266a7db291dde1d0536f90be6ecea223b6ced4 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Thu, 30 Oct 2025 19:00:03 +0000 Subject: [PATCH 02/35] sdks/python: handle transient testcontainer startup/teardown errors --- .../transforms/elementwise/enrichment_test.py | 30 +++++++++++-------- .../rag/enrichment/milvus_search_it_test.py | 15 ++++------ 2 files changed, 24 insertions(+), 21 deletions(-) diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index c8e988a52c5d..083d246a439a 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -68,6 +68,9 @@ class TestContainerStartupError(Exception): """Raised when any test container fails to start.""" pass +class TestContainerTeardownError(Exception): + """Raised when any test container fails to teardown.""" + pass def validate_enrichment_with_bigtable(): expected = '''[START enrichment_with_bigtable] @@ -186,7 +189,7 @@ def test_enrichment_with_external_pg(self, mock_stdout): output = mock_stdout.getvalue().splitlines() expected = validate_enrichment_with_external_pg() 
self.assertEqual(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -199,7 +202,7 @@ def test_enrichment_with_external_mysql(self, mock_stdout): output = mock_stdout.getvalue().splitlines() expected = validate_enrichment_with_external_mysql() self.assertEqual(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -212,7 +215,7 @@ def test_enrichment_with_external_sqlserver(self, mock_stdout): output = mock_stdout.getvalue().splitlines() expected = validate_enrichment_with_external_sqlserver() self.assertEqual(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -227,7 +230,7 @@ def test_enrichment_with_milvus(self, mock_stdout): output = parse_chunk_strings(output) expected = parse_chunk_strings(expected) assert_chunks_equivalent(output, expected) - except TestContainerStartupError as e: + except (TestContainerStartupError, TestContainerTeardownError) as e: raise unittest.SkipTest(str(e)) except Exception as e: self.fail(f"Test failed with unexpected error: {e}") @@ -373,19 +376,17 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct): def pre_milvus_enrichment() -> MilvusDBContainerInfo: try: db = MilvusEnrichmentTestHelper.start_db_container() - except Exception as e: - raise TestContainerStartupError( - f"Milvus container failed to start: {str(e)}") - - connection_params = MilvusConnectionParameters( + connection_params = MilvusConnectionParameters( uri=db.uri, user=db.user, password=db.password, db_id=db.id, token=db.token) - - collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( + collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( connection_params) + except Exception as e: + raise TestContainerStartupError( + f"Milvus container failed to start: {str(e)}") # Setup environment variables for db and collection configuration. This will # be used downstream by the milvus enrichment handler. 
@@ -400,7 +401,12 @@ def pre_milvus_enrichment() -> MilvusDBContainerInfo:
 
   @staticmethod
   def post_milvus_enrichment(db: MilvusDBContainerInfo):
-    MilvusEnrichmentTestHelper.stop_db_container(db)
+    try:
+      MilvusEnrichmentTestHelper.stop_db_container(db)
+    except Exception as e:
+      raise TestContainerTeardownError(
+          f"Milvus container failed to tear down: {str(e)}")
+
     os.environ.pop('MILVUS_VECTOR_DB_URI', None)
     os.environ.pop('MILVUS_VECTOR_DB_USER', None)
     os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None)
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
index 5094d9076e93..4184aca0bfe9 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
@@ -312,7 +312,7 @@ def start_db_container(
     with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
       info = None
       original_tc_max_tries = testcontainers_config.max_tries
-      if not testcontainers_config.max_tries:
+      if tc_max_retries is not None:
         testcontainers_config.max_tries = tc_max_retries
       for i in range(vector_client_max_retries):
         try:
@@ -326,7 +326,6 @@ def start_db_container(
           host = vector_db_container.get_container_host_ip()
           port = vector_db_container.get_exposed_port(service_container_port)
           info = MilvusDBContainerInfo(vector_db_container, host, port)
-          testcontainers_config.max_tries = original_tc_max_tries
           _LOGGER.info(
               "milvus db container started successfully on %s.", info.uri)
           break
@@ -351,6 +350,8 @@ def start_db_container(
                 stdout_logs,
                 stderr_logs)
             raise e
+      finally:
+        testcontainers_config.max_tries = original_tc_max_tries
     return info
 
   @staticmethod
   def stop_db_container(db_info: MilvusDBContainerInfo):
     if db_info is None:
       _LOGGER.warning("Milvus db info is None. 
Skipping stop operation.") return - try: - _LOGGER.debug("Stopping milvus db container.") - db_info.container.stop() - _LOGGER.info("milvus db container stopped successfully.") - except Exception as e: - _LOGGER.warning( - "Error encountered while stopping milvus db container: %s", e) + _LOGGER.debug("Stopping milvus db container.") + db_info.container.stop() + _LOGGER.info("milvus db container stopped successfully.") @staticmethod def initialize_db_with_data(connc_params: MilvusConnectionParameters): From fa6d2f06994e6ed14269c39b500b66bf336e3f60 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 31 Oct 2025 14:42:30 +0000 Subject: [PATCH 03/35] sdks/python: bump `testcontainers` py pkg version --- sdks/python/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 9ed2a124e94d..d7afb0a2f112 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -463,7 +463,7 @@ def get_portability_package_data(): 'sqlalchemy>=1.3,<3.0', 'psycopg2-binary>=2.8.5,<2.9.10; python_version <= "3.9"', 'psycopg2-binary>=2.8.5,<3.0; python_version >= "3.10"', - 'testcontainers[mysql,kafka,milvus]>=4.0.0,<5.0.0', + 'testcontainers[mysql,kafka,milvus]>=4.13.2,<5.0.0', 'cryptography>=41.0.2', 'hypothesis>5.0.0,<7.0.0', 'virtualenv-clone>=0.5,<1.0', From 9445aaad9bfb2650081e14348d30df5b1f626a97 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 31 Oct 2025 15:19:07 +0000 Subject: [PATCH 04/35] sdks/python: integrate milvus sink I/O --- .../transforms/elementwise/enrichment_test.py | 18 +- .../ml/rag/enrichment/milvus_search.py | 49 +- .../rag/enrichment/milvus_search_it_test.py | 343 ++-------- .../ml/rag/ingestion/milvus_search.py | 340 ++++++++++ .../ml/rag/ingestion/milvus_search_it_test.py | 616 ++++++++++++++++++ .../ml/rag/ingestion/milvus_search_test.py | 122 ++++ .../ml/rag/ingestion/postgres_common.py | 38 +- sdks/python/apache_beam/ml/rag/test_utils.py | 304 +++++++++ sdks/python/apache_beam/ml/rag/utils.py | 129 ++++ 9 files changed, 1632 insertions(+), 327 deletions(-) create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py create mode 100644 sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py create mode 100644 sdks/python/apache_beam/ml/rag/test_utils.py create mode 100644 sdks/python/apache_beam/ml/rag/utils.py diff --git a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py index 083d246a439a..f303b4a670a2 100644 --- a/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py +++ b/sdks/python/apache_beam/examples/snippets/transforms/elementwise/enrichment_test.py @@ -57,8 +57,8 @@ from apache_beam.ml.rag.enrichment.milvus_search_it_test import ( MilvusEnrichmentTestHelper, MilvusDBContainerInfo, - parse_chunk_strings, assert_chunks_equivalent) + from apache_beam.ml.rag.utils import parse_chunk_strings from apache_beam.io.requestresponse import RequestResponseIO except ImportError as e: raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}') @@ -68,10 +68,12 @@ class TestContainerStartupError(Exception): """Raised when any test container fails to start.""" pass + class TestContainerTeardownError(Exception): """Raised when any test container fails to teardown.""" pass + def validate_enrichment_with_bigtable(): expected = '''[START 
enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
@@ -377,13 +379,13 @@ def pre_milvus_enrichment() -> MilvusDBContainerInfo:
     try:
       db = MilvusEnrichmentTestHelper.start_db_container()
       connection_params = MilvusConnectionParameters(
-        uri=db.uri,
-        user=db.user,
-        password=db.password,
-        db_id=db.id,
-        token=db.token)
+          uri=db.uri,
+          user=db.user,
+          password=db.password,
+          db_id=db.id,
+          token=db.token)
       collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
-        connection_params)
+          connection_params)
     except Exception as e:
       raise TestContainerStartupError(
           f"Milvus container failed to start: {str(e)}")
@@ -405,7 +407,7 @@ def post_milvus_enrichment(db: MilvusDBContainerInfo):
       MilvusEnrichmentTestHelper.stop_db_container(db)
     except Exception as e:
       raise TestContainerTeardownError(
-          f"Milvus container failed to tear down: {str(e)}")
+        f"Milvus container failed to tear down: {str(e)}")
 
     os.environ.pop('MILVUS_VECTOR_DB_URI', None)
     os.environ.pop('MILVUS_VECTOR_DB_USER', None)
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
index 431c0db3f416..d488c8d3d80d 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
@@ -25,6 +25,7 @@
 from typing import Optional
 from typing import Tuple
 from typing import Union
+import uuid
 
 from google.protobuf.json_format import MessageToDict
 from pymilvus import AnnSearchRequest
@@ -35,6 +36,7 @@
 
 from apache_beam.ml.rag.types import Chunk
 from apache_beam.ml.rag.types import Embedding
+from apache_beam.ml.rag.utils import MilvusHelpers, MilvusConnectionParameters
 from apache_beam.transforms.enrichment import EnrichmentSourceHandler
 
 
@@ -104,44 +106,6 @@ def __str__(self):
     return self.dict().__str__()
 
 
-@dataclass
-class MilvusConnectionParameters:
-  """Parameters for establishing connections to Milvus servers.
-
-  Args:
-    uri: URI endpoint for connecting to Milvus server in the format
-      "http(s)://hostname:port".
-    user: Username for authentication. Required if authentication is enabled and
-      not using token authentication.
-    password: Password for authentication. Required if authentication is enabled
-      and not using token authentication.
-    db_id: Database ID to connect to. Specifies which Milvus database to use.
-      Defaults to 'default'.
-    token: Authentication token as an alternative to username/password.
-    timeout: Connection timeout in seconds. Uses client default if None.
-    max_retries: Maximum number of connection retry attempts. Defaults to 3.
-    retry_delay: Initial delay between retries in seconds. Defaults to 1.0.
-    retry_backoff_factor: Multiplier for retry delay after each attempt.
-      Defaults to 2.0 (exponential backoff).
-    kwargs: Optional keyword arguments for additional connection parameters.
-      Enables forward compatibility.
- """ - uri: str - user: str = field(default_factory=str) - password: str = field(default_factory=str) - db_id: str = "default" - token: str = field(default_factory=str) - timeout: Optional[float] = None - max_retries: int = 3 - retry_delay: float = 1.0 - retry_backoff_factor: float = 2.0 - kwargs: Dict[str, Any] = field(default_factory=dict) - - def __post_init__(self): - if not self.uri: - raise ValueError("URI must be provided for Milvus connection") - - @dataclass class BaseSearchParameters: """Base parameters for both vector and keyword search operations. @@ -361,7 +325,7 @@ def __init__( **kwargs): """ Example Usage: - connection_paramters = MilvusConnectionParameters( + connection_parameters = MilvusConnectionParameters( uri="http://localhost:19530") search_parameters = MilvusSearchParameters( collection_name="my_collection", @@ -369,7 +333,7 @@ def __init__( collection_load_parameters = MilvusCollectionLoadParameters( load_fields=["embedding", "metadata"]), milvus_handler = MilvusSearchEnrichmentHandler( - connection_paramters, + connection_parameters, search_parameters, collection_load_parameters=collection_load_parameters, min_batch_size=10, @@ -534,10 +498,7 @@ def _get_keyword_search_data(self, chunk: Chunk): raise ValueError( f"Chunk {chunk.id} missing both text content and sparse embedding " "required for keyword search") - - sparse_embedding = self.convert_sparse_embedding_to_milvus_format( - chunk.sparse_embedding) - + sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding) return chunk.content.text or sparse_embedding def _get_call_response( diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py index 4184aca0bfe9..094788664bdb 100644 --- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py @@ -57,6 +57,8 @@ from testcontainers.core.generic import DbContainer from testcontainers.milvus import MilvusContainer from apache_beam.transforms.enrichment import Enrichment + from apache_beam.ml.rag.test_utils import ( + MilvusTestHelpers, VectorDBContainerInfo) from apache_beam.ml.rag.enrichment.milvus_search import ( MilvusSearchEnrichmentHandler, MilvusConnectionParameters, @@ -243,241 +245,67 @@ def __getitem__(self, key): } -@dataclass -class MilvusDBContainerInfo: - container: DbContainer - host: str - port: int - user: Optional[str] = "" - password: Optional[str] = "" - token: Optional[str] = "" - id: Optional[str] = "default" - - @property - def uri(self) -> str: - return f"http://{self.host}:{self.port}" - - -class CustomMilvusContainer(MilvusContainer): - def __init__( - self, - image: str, - service_container_port, - healthcheck_container_port, - **kwargs, - ) -> None: - # Skip the parent class's constructor and go straight to - # GenericContainer. - super(MilvusContainer, self).__init__(image=image, **kwargs) - self.port = service_container_port - self.healthcheck_port = healthcheck_container_port - self.with_exposed_ports(service_container_port, healthcheck_container_port) - - # Get free host ports. - service_host_port = MilvusEnrichmentTestHelper.find_free_port() - healthcheck_host_port = MilvusEnrichmentTestHelper.find_free_port() - - # Bind container and host ports. 
- self.with_bind_ports(service_container_port, service_host_port) - self.with_bind_ports(healthcheck_container_port, healthcheck_host_port) - self.cmd = "milvus run standalone" - - # Set environment variables needed for Milvus. - envs = { - "ETCD_USE_EMBED": "true", - "ETCD_DATA_DIR": "/var/lib/milvus/etcd", - "COMMON_STORAGETYPE": "local", - "METRICS_PORT": str(healthcheck_container_port) - } - for env, value in envs.items(): - self.with_env(env, value) - - -class MilvusEnrichmentTestHelper: - # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus - # Python SDK client in setup.py is updated to match. Referring to the Milvus - # release notes compatibility matrix at - # https://milvus.io/docs/release_notes.md or PyPI at - # https://pypi.org/project/pymilvus/ for version compatibility. - # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required). - @staticmethod - def start_db_container( - image="milvusdb/milvus:v2.5.10", - max_vec_fields=5, - vector_client_max_retries=3, - tc_max_retries=None) -> Optional[MilvusDBContainerInfo]: - service_container_port = MilvusEnrichmentTestHelper.find_free_port() - healthcheck_container_port = MilvusEnrichmentTestHelper.find_free_port() - user_yaml_creator = MilvusEnrichmentTestHelper.create_user_yaml - with user_yaml_creator(service_container_port, max_vec_fields) as cfg: - info = None - original_tc_max_tries = testcontainers_config.max_tries - if testcontainers_config.max_tries is not None: - testcontainers_config.max_tries = tc_max_retries - for i in range(vector_client_max_retries): - try: - vector_db_container = CustomMilvusContainer( - image=image, - service_container_port=service_container_port, - healthcheck_container_port=healthcheck_container_port) - vector_db_container = vector_db_container.with_volume_mapping( - cfg, "/milvus/configs/user.yaml") - vector_db_container.start() - host = vector_db_container.get_container_host_ip() - port = vector_db_container.get_exposed_port(service_container_port) - info = MilvusDBContainerInfo(vector_db_container, host, port) - _LOGGER.info( - "milvus db container started successfully on %s.", info.uri) - break - except Exception as e: - stdout_logs, stderr_logs = vector_db_container.get_logs() - stdout_logs = stdout_logs.decode("utf-8") - stderr_logs = stderr_logs.decode("utf-8") - _LOGGER.warning( - "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. " - "STDOUT logs:\n%s\nSTDERR logs:\n%s", - i + 1, - vector_client_max_retries, - e, - stdout_logs, - stderr_logs) - if i == vector_client_max_retries - 1: - _LOGGER.error( - "Unable to start milvus db container for I/O tests after %d " - "retries. Tests cannot proceed. STDOUT logs:\n%s\n" - "STDERR logs:\n%s", - vector_client_max_retries, - stdout_logs, - stderr_logs) - raise e - finally: - testcontainers_config.max_tries = original_tc_max_tries - return info - - @staticmethod - def stop_db_container(db_info: MilvusDBContainerInfo): - if db_info is None: - _LOGGER.warning("Milvus db info is None. Skipping stop operation.") - return - _LOGGER.debug("Stopping milvus db container.") - db_info.container.stop() - _LOGGER.info("milvus db container stopped successfully.") - - @staticmethod - def initialize_db_with_data(connc_params: MilvusConnectionParameters): - # Open the connection to the milvus db. - client = MilvusClient(**connc_params.__dict__) - - # Configure schema. 
- field_schemas: List[FieldSchema] = cast( - List[FieldSchema], MILVUS_IT_CONFIG["fields"]) - schema = CollectionSchema( - fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"]) - - # Create collection with the schema. - collection_name = MILVUS_IT_CONFIG["collection_name"] - index_function: Callable[[], IndexParams] = cast( - Callable[[], IndexParams], MILVUS_IT_CONFIG["index"]) - client.create_collection( - collection_name=collection_name, - schema=schema, - index_params=index_function()) - - # Assert that collection was created. - collection_error = f"Expected collection '{collection_name}' to be created." - assert client.has_collection(collection_name), collection_error - - # Gather all fields we have excluding 'sparse_embedding_bm25' special field. - fields = list(map(lambda field: field.name, field_schemas)) - - # Prep data for indexing. Currently we can't insert sparse vectors for BM25 - # sparse embedding field as it would be automatically generated by Milvus - # through the registered BM25 function. - data_ready_to_index = [] - for doc in MILVUS_IT_CONFIG["corpus"]: - item = {} - for field in fields: - if field.startswith("dense_embedding"): - item[field] = doc["dense_embedding"] - elif field == "sparse_embedding_inner_product": - item[field] = doc["sparse_embedding"] - elif field == "sparse_embedding_bm25": - # It is automatically generated by Milvus from the content field. - continue - else: - item[field] = doc[field] - data_ready_to_index.append(item) - - # Index data. - result = client.insert( - collection_name=collection_name, data=data_ready_to_index) - - # Assert that the intended data has been properly indexed. - insertion_err = f'failed to insert the {result["insert_count"]} data points' - assert result["insert_count"] == len(data_ready_to_index), insertion_err - - # Release the collection from memory. It will be loaded lazily when the - # enrichment handler is invoked. - client.release_collection(collection_name) - - # Close the connection to the Milvus database, as no further preparation - # operations are needed before executing the enrichment handler. - client.close() - - return collection_name - - @staticmethod - def find_free_port(): - """Find a free port on the local machine.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - # Bind to port 0, which asks OS to assign a free port. - s.bind(('', 0)) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # Return the port number assigned by OS. - return s.getsockname()[1] - - @staticmethod - @contextlib.contextmanager - def create_user_yaml(service_port: int, max_vector_field_num=5): - """Creates a temporary user.yaml file for Milvus configuration. - - This user yaml file overrides Milvus default configurations. It sets - the Milvus service port to the specified container service port. The - default for maxVectorFieldNum is 4, but we need 5 - (one unique field for each metric). - - Args: - service_port: Port number for the Milvus service. - max_vector_field_num: Max number of vec fields allowed per collection. - - Yields: - str: Path to the created temporary yaml file. - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', - delete=False) as temp_file: - # Define the content for user.yaml. - user_config = { - 'proxy': { - 'maxVectorFieldNum': max_vector_field_num, 'port': service_port - }, - 'etcd': { - 'use': { - 'embed': True - }, 'data': { - 'dir': '/var/lib/milvus/etcd' - } - } - } - - # Write the content to the file. 
- yaml.dump(user_config, temp_file, default_flow_style=False) - path = temp_file.name - - try: - yield path - finally: - if os.path.exists(path): - os.remove(path) +def initialize_db_with_data(connc_params: MilvusConnectionParameters): + # Open the connection to the milvus db. + client = MilvusClient(**connc_params.__dict__) + + # Configure schema. + field_schemas: List[FieldSchema] = cast( + List[FieldSchema], MILVUS_IT_CONFIG["fields"]) + schema = CollectionSchema( + fields=field_schemas, functions=MILVUS_IT_CONFIG["functions"]) + + # Create collection with the schema. + collection_name = MILVUS_IT_CONFIG["collection_name"] + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_IT_CONFIG["index"]) + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Assert that collection was created. + collection_error = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), collection_error + + # Gather all fields we have excluding 'sparse_embedding_bm25' special field. + fields = list(map(lambda field: field.name, field_schemas)) + + # Prep data for indexing. Currently we can't insert sparse vectors for BM25 + # sparse embedding field as it would be automatically generated by Milvus + # through the registered BM25 function. + data_ready_to_index = [] + for doc in MILVUS_IT_CONFIG["corpus"]: + item = {} + for field in fields: + if field.startswith("dense_embedding"): + item[field] = doc["dense_embedding"] + elif field == "sparse_embedding_inner_product": + item[field] = doc["sparse_embedding"] + elif field == "sparse_embedding_bm25": + # It is automatically generated by Milvus from the content field. + continue + else: + item[field] = doc[field] + data_ready_to_index.append(item) + + # Index data. + result = client.insert( + collection_name=collection_name, data=data_ready_to_index) + + # Assert that the intended data has been properly indexed. + insertion_err = f'failed to insert the {result["insert_count"]} data points' + assert result["insert_count"] == len(data_ready_to_index), insertion_err + + # Release the collection from memory. It will be loaded lazily when the + # enrichment handler is invoked. + client.release_collection(collection_name) + + # Close the connection to the Milvus database, as no further preparation + # operations are needed before executing the enrichment handler. 
+ client.close() + + return collection_name @pytest.mark.require_docker_in_docker @@ -491,25 +319,23 @@ def create_user_yaml(service_port: int, max_vector_field_num=5): class TestMilvusSearchEnrichment(unittest.TestCase): """Tests for search functionality across all search strategies""" - _db: MilvusDBContainerInfo + _db: VectorDBContainerInfo @classmethod def setUpClass(cls): - cls._db = MilvusEnrichmentTestHelper.start_db_container() + cls._db = MilvusTestHelpers.start_db_container() cls._connection_params = MilvusConnectionParameters( uri=cls._db.uri, user=cls._db.user, password=cls._db.password, - db_id=cls._db.id, - token=cls._db.token, - timeout=60.0) # Increase timeout to 60s for container startup + db_name=cls._db.id, + token=cls._db.token) cls._collection_load_params = MilvusCollectionLoadParameters() - cls._collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data( - cls._connection_params) + cls._collection_name = initialize_db_with_data(cls._connection_params) @classmethod def tearDownClass(cls): - MilvusEnrichmentTestHelper.stop_db_container(cls._db) + MilvusTestHelpers.stop_db_container(cls._db) cls._db = None def test_invalid_query_on_non_existent_collection(self): @@ -1244,37 +1070,6 @@ def test_hybrid_search(self): lambda actual: assert_chunks_equivalent(actual, expected_chunks)) -def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]: - parsed_chunks = [] - - # Define safe globals and disable built-in functions for safety. - safe_globals = { - 'Chunk': Chunk, - 'Content': Content, - 'Embedding': Embedding, - 'defaultdict': defaultdict, - 'list': list, - '__builtins__': {} - } - - for raw_str in chunk_str_list: - try: - # replace "" with actual list reference. - cleaned_str = re.sub( - r"defaultdict\(", "defaultdict(list", raw_str) - - # Evaluate string in restricted environment. - chunk = eval(cleaned_str, safe_globals) # pylint: disable=eval-used - if isinstance(chunk, Chunk): - parsed_chunks.append(chunk) - else: - raise ValueError("Parsed object is not a Chunk instance") - except Exception as e: - raise ValueError(f"Error parsing string:\n{raw_str}\n{e}") - - return parsed_chunks - - def assert_chunks_equivalent( actual_chunks: List[Chunk], expected_chunks: List[Chunk]): """assert_chunks_equivalent checks for presence rather than exact match""" diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py new file mode 100644 index 000000000000..041349efeb77 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py @@ -0,0 +1,340 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
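+
+"""Milvus vector sink for Apache Beam RAG ingestion pipelines.
+
+A minimal usage sketch, assuming a reachable Milvus server at the given URI
+and an existing collection whose schema matches the default column specs
+(`chunks` below stands for a list of apache_beam.ml.rag.types.Chunk objects
+supplied by the caller):
+
+  config = MilvusVectorWriterConfig(
+      connection_params=MilvusConnectionParameters(
+          uri="http://localhost:19530"),
+      write_config=MilvusWriteConfig(collection_name="my_collection"))
+  with beam.Pipeline() as p:
+    _ = (p | beam.Create(chunks) | config.create_write_transform())
+"""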
+
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional
+
+from pymilvus import MilvusClient
+
+import apache_beam as beam
+from apache_beam.ml.rag.ingestion.base import VectorDatabaseWriteConfig
+from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig
+from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec
+from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpecsBuilder
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.utils import (
+    DEFAULT_WRITE_BATCH_SIZE,
+    MilvusConnectionParameters,
+    MilvusHelpers,
+    unpack_dataclass_with_kwargs)
+from apache_beam.transforms import DoFn
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class MilvusWriteConfig:
+  """Configuration parameters for writing data to Milvus collections.
+
+  This class defines the parameters needed to write data to a Milvus
+  collection, including collection targeting, batching behavior, and
+  operation timeouts.
+
+  Args:
+    collection_name: Name of the target Milvus collection to write data to.
+      Must be a non-empty string.
+    partition_name: Name of the specific partition within the collection to
+      write to. If empty, writes to the default partition.
+    timeout: Maximum time in seconds to wait for write operations to complete.
+      If None, uses the client's default timeout.
+    write_config: Configuration for write operations including batch size and
+      other write-specific settings.
+    kwargs: Additional keyword arguments for write operations. Enables forward
+      compatibility with future Milvus client parameters.
+  """
+  collection_name: str
+  partition_name: str = ""
+  timeout: Optional[float] = None
+  write_config: WriteConfig = field(default_factory=WriteConfig)
+  kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.collection_name:
+      raise ValueError("Collection name must be provided")
+
+  @property
+  def write_batch_size(self):
+    """Returns the batch size for write operations.
+
+    Returns:
+      The configured batch size, or DEFAULT_WRITE_BATCH_SIZE if not specified.
+    """
+    return self.write_config.write_batch_size or DEFAULT_WRITE_BATCH_SIZE
+
+
+@dataclass
+class MilvusVectorWriterConfig(VectorDatabaseWriteConfig):
+  """Configuration for writing vector data to Milvus collections.
+
+  This class extends VectorDatabaseWriteConfig to provide Milvus-specific
+  configuration for ingesting vector embeddings and associated metadata.
+  It defines how Apache Beam chunks are converted to Milvus records and
+  handles the write operation parameters.
+
+  The configuration includes connection parameters, write settings, and
+  column specifications that determine how chunk data is mapped to Milvus
+  fields.
+
+  Args:
+    connection_params: Configuration for connecting to the Milvus server,
+      including URI, credentials, and connection options.
+    write_config: Configuration for write operations including collection name,
+      partition, batch size, and timeouts.
+    column_specs: List of column specifications defining how chunk fields are
+      mapped to Milvus collection fields. Defaults to standard RAG fields
+      (id, embedding, sparse_embedding, content, metadata).
+
+  Example:
+    config = MilvusVectorWriterConfig(
+        connection_params=MilvusConnectionParameters(
+            uri="http://localhost:19530"),
+        write_config=MilvusWriteConfig(collection_name="my_collection"),
+        column_specs=MilvusVectorWriterConfig.default_column_specs())
+  """
+  connection_params: MilvusConnectionParameters
+  write_config: MilvusWriteConfig
+  column_specs: List[ColumnSpec] = field(
+      default_factory=lambda: MilvusVectorWriterConfig.default_column_specs())
+
+  def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]:
+    """Creates a function to convert Apache Beam Chunks to Milvus records.
+
+    Returns:
+      A function that takes a Chunk and returns a dictionary representing
+      a Milvus record with fields mapped according to column_specs.
+    """
+    def convert(chunk: Chunk) -> Dict[str, Any]:
+      result = {}
+      for col in self.column_specs:
+        result[col.column_name] = col.value_fn(chunk)
+      return result
+
+    return convert
+
+  def create_write_transform(self) -> beam.PTransform:
+    """Creates the Apache Beam transform for writing to Milvus.
+
+    Returns:
+      A PTransform that can be applied to a PCollection of Chunks to write
+      them to the configured Milvus collection.
+    """
+    return _WriteToMilvusVectorDatabase(self)
+
+  @staticmethod
+  def default_column_specs() -> List[ColumnSpec]:
+    """Returns default column specifications for RAG use cases.
+
+    Creates column mappings for standard RAG fields: id, dense embedding,
+    sparse embedding, content text, and metadata. These specifications
+    define how Chunk fields are converted to Milvus-compatible formats.
+
+    Returns:
+      List of ColumnSpec objects defining the default field mappings.
+    """
+    column_specs = ColumnSpecsBuilder()
+    return column_specs\
+        .with_id_spec()\
+        .with_embedding_spec(convert_fn=lambda values: list(values))\
+        .with_sparse_embedding_spec(conv_fn=MilvusHelpers.sparse_embedding)\
+        .with_content_spec()\
+        .with_metadata_spec(convert_fn=lambda values: dict(values))\
+        .build()
+
+
+class _WriteToMilvusVectorDatabase(beam.PTransform):
+  """Apache Beam PTransform for writing vector data to Milvus.
+
+  This transform handles the conversion of Apache Beam Chunks to Milvus
+  records and coordinates the write operations. It applies the configured
+  converter function and uses a DoFn for batched writes to optimize
+  performance.
+
+  Args:
+    config: MilvusVectorWriterConfig containing all necessary parameters for
+      the write operation.
+  """
+  def __init__(self, config: MilvusVectorWriterConfig):
+    self.config = config
+
+  def expand(self, pcoll: beam.PCollection[Chunk]):
+    """Expands the PTransform to convert chunks and write to Milvus.
+
+    Args:
+      pcoll: PCollection of Chunk objects to write to Milvus.
+
+    Returns:
+      PCollection of the same Chunk objects after writing to Milvus.
+    """
+    return (
+        pcoll
+        | "Convert to Records" >> beam.Map(self.config.create_converter())
+        | beam.ParDo(
+            _WriteMilvusFn(
+                self.config.connection_params, self.config.write_config)))
+
+
+class _WriteMilvusFn(DoFn):
+  """DoFn that handles batched writes to Milvus.
+
+  This DoFn accumulates records in batches and flushes them to Milvus when
+  the batch size is reached or when the bundle finishes. This approach
+  optimizes performance by reducing the number of individual write operations.
+
+  Args:
+    connection_params: Configuration for connecting to the Milvus server.
+    write_config: Configuration for write operations including batch size
+      and collection details.
+ """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self.batch = [] + + def process(self, element, *args, **kwargs): + """Processes individual records, batching them for efficient writes. + + Args: + element: A dictionary representing a Milvus record to write. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Yields: + The original element after adding it to the batch. + """ + _ = args, kwargs # Unused parameters + self.batch.append(element) + if len(self.batch) >= self._write_config.write_batch_size: + self._flush() + yield element + + def finish_bundle(self): + """Called when a bundle finishes processing. + + Flushes any remaining records in the batch to ensure all data is written. + """ + self._flush() + + def _flush(self): + """Flushes the current batch of records to Milvus. + + Creates a MilvusSink connection and writes all batched records, + then clears the batch for the next set of records. + """ + if len(self.batch) == 0: + return + with _MilvusSink(self._connection_params, self._write_config) as sink: + sink.write(self.batch) + self.batch = [] + + def display_data(self): + """Returns display data for monitoring and debugging. + + Returns: + Dictionary containing database, collection, and batch size information + for display in the Apache Beam monitoring UI. + """ + res = super().display_data() + res["database"] = self._connection_params.db_name + res["collection"] = self._write_config.collection_name + res["batch_size"] = self._write_config.write_batch_size + return res + + +class _MilvusSink: + """Low-level sink for writing data directly to Milvus. + + This class handles the direct interaction with the Milvus client for + upsert operations. It manages the connection lifecycle and provides + context manager support for proper resource cleanup. + + Args: + connection_params: Configuration for connecting to the Milvus server. + write_config: Configuration for write operations including collection + and partition targeting. + """ + def __init__( + self, + connection_params: MilvusConnectionParameters, + write_config: MilvusWriteConfig): + self._connection_params = connection_params + self._write_config = write_config + self._client = None + + def write(self, documents): + """Writes a batch of documents to the Milvus collection. + + Performs an upsert operation to insert new documents or update existing + ones based on primary key. After the upsert, flushes the collection to + ensure data persistence. + + Args: + documents: List of dictionaries representing Milvus records to write. + Each dictionary should contain fields matching the collection schema. + """ + if not self._client: + self._client = MilvusClient( + **unpack_dataclass_with_kwargs(self._connection_params)) + + try: + resp = self._client.upsert( + collection_name=self._write_config.collection_name, + partition_name=self._write_config.partition_name, + data=documents, + timeout=self._write_config.timeout, + **self._write_config.kwargs) + + # Try to flush, but handle connection issues gracefully. + try: + self._client.flush(self._write_config.collection_name) + except Exception as e: + # If flush fails due to connection issues, log but don't fail the write. 
+ _LOGGER.warning( + "Flush operation failed, but upsert was successful: %s", e) + + _LOGGER.debug( + "Upserted into Milvus: upsert_count=%d, cost=%d", + resp.get("upsert_count", 0), + resp.get("cost", 0)) + except Exception as e: + _LOGGER.error("Failed to write to Milvus: %s", e) + raise + + def __enter__(self): + """Enters the context manager and establishes Milvus connection. + + Returns: + Self, enabling use in 'with' statements. + """ + if not self._client: + self._client = MilvusClient( + **unpack_dataclass_with_kwargs(self._connection_params)) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exits the context manager and closes the Milvus connection. + + Args: + exc_type: Exception type if an exception was raised. + exc_val: Exception value if an exception was raised. + exc_tb: Exception traceback if an exception was raised. + """ + _ = exc_type, exc_val, exc_tb # Unused parameters + if self._client: + self._client.close() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py new file mode 100644 index 000000000000..f8f01d9d5964 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -0,0 +1,616 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import platform +from typing import Callable, cast +import unittest +import uuid + +import pytest +from pymilvus import CollectionSchema, DataType, MilvusClient +from pymilvus import FieldSchema +from pymilvus.milvus_client import IndexParams + +import apache_beam as beam + +from apache_beam.ml.rag.types import Chunk +from apache_beam.ml.rag.types import Content +from apache_beam.ml.rag.types import Embedding +from apache_beam.ml.rag.utils import MilvusConnectionParameters +from apache_beam.ml.rag.test_utils import ( + VectorDBContainerInfo, MilvusTestHelpers) +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.ml.rag.ingestion.jdbc_common import WriteConfig +from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs + +try: + from apache_beam.ml.rag.ingestion.milvus_search import ( + MilvusWriteConfig, MilvusVectorWriterConfig) +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +def _construct_index_params(): + index_params = IndexParams() + + # Dense vector index for dense embeddings. + index_params.add_index( + field_name="embedding", + index_name="embedding_ivf_flat", + index_type="IVF_FLAT", + metric_type="COSINE", + params={"nlist": 1}) + + # Sparse vector index for sparse embeddings. 
+ index_params.add_index( + field_name="sparse_embedding", + index_name="sparse_embedding_inverted_index", + index_type="SPARSE_INVERTED_INDEX", + metric_type="IP", + params={"inverted_index_algo": "TAAT_NAIVE"}) + + return index_params + + +MILVUS_INGESTION_IT_CONFIG = { + "fields": [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=False), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ], + "index": _construct_index_params, + "corpus": [ + Chunk( + id=1, + content=Content(text="Test document one"), + metadata={"source": "test1"}, + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + ), + Chunk( + id=2, + content=Content(text="Test document two"), + metadata={"source": "test2"}, + embedding=Embedding( + dense_embedding=[0.2, 0.3, 0.4], + sparse_embedding=([2, 3], [0.3, 0.1]), + ), + ), + Chunk( + id=3, + content=Content(text="Test document three"), + metadata={"source": "test3"}, + embedding=Embedding( + dense_embedding=[0.3, 0.4, 0.5], + sparse_embedding=([3, 4], [0.4, 0.2]), + ), + ) + ] +} + + +def create_collection_with_partition( + client: MilvusClient, + collection_name: str, + partition_name: str = '', + fields=MILVUS_INGESTION_IT_CONFIG["fields"]): + # Configure schema. + schema = CollectionSchema(fields=fields) + + # Configure index. + index_function: Callable[[], IndexParams] = cast( + Callable[[], IndexParams], MILVUS_INGESTION_IT_CONFIG["index"]) + + # Create collection with schema. + client.create_collection( + collection_name=collection_name, + schema=schema, + index_params=index_function()) + + # Create partition within the collection. + client.create_partition( + collection_name=collection_name, partition_name=partition_name) + + msg = f"Expected collection '{collection_name}' to be created." + assert client.has_collection(collection_name), msg + + msg = f"Expected partition '{partition_name}' to be created." + assert client.has_partition(collection_name, partition_name), msg + + # Release the collection from memory. We don't need that on pure writing. + client.release_collection(collection_name) + + +def drop_collection(client: MilvusClient, collection_name: str): + try: + client.drop_collection(collection_name) + assert not client.has_collection(collection_name) + except Exception: + # Silently ignore connection errors during cleanup. + pass + + +@pytest.mark.uses_testcontainer +@unittest.skipUnless( + platform.system() == "Linux", + "Test runs only on Linux due to lack of support, as yet, for nested " + "virtualization in CI environments on Windows/macOS. 
Many CI providers run " + "tests in virtualized environments, and nested virtualization " + "(Docker inside a VM) is either unavailable or has several issues on " + "non-Linux platforms.") +class TestMilvusVectorWriterConfig(unittest.TestCase): + """Integration tests for Milvus vector database ingestion functionality""" + + _db: VectorDBContainerInfo + _version = "milvusdb/milvus:v2.5.10" + + @classmethod + def setUpClass(cls): + cls._db = MilvusTestHelpers.start_db_container( + cls._version, vector_client_max_retries=3) + cls._connection_config = MilvusConnectionParameters( + uri=cls._db.uri, + user=cls._db.user, + password=cls._db.password, + db_name=cls._db.id, + token=cls._db.token) + + @classmethod + def tearDownClass(cls): + MilvusTestHelpers.stop_db_container(cls._db) + cls._db = None + + def setUp(self): + self.write_test_pipeline = TestPipeline() + self.write_test_pipeline.not_use_test_runner_api = True + self._collection_name = f"test_collection_{self._testMethodName}" + self._partition_name = f"test_partition_{self._testMethodName}" + config = unpack_dataclass_with_kwargs(self._connection_config) + config["alias"] = f"milvus_conn_{uuid.uuid4().hex[:8]}" + self._test_client = MilvusClient(**config) + create_collection_with_partition( + self._test_client, self._collection_name, self._partition_name) + + def tearDown(self): + drop_collection(self._test_client, self._collection_name) + self._test_client.close() + + def test_invalid_write_on_non_existent_collection(self): + non_existent_collection = "nonexistent_collection" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=non_existent_collection, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + ) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("can't find collection", str(context.exception).lower()) + + def test_invalid_write_on_non_existent_partition(self): + non_existent_partition = "nonexistent_partition" + + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=non_existent_partition, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn("partition not found", str(context.exception).lower()) + + def test_invalid_write_on_missing_primary_key_in_entity(self): + test_chunks = [ + Chunk( + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + + # Deliberately remove id primary key from the entity. 
+ specs = MilvusVectorWriterConfig.default_column_specs() + for i, spec in enumerate(specs): + if spec.column_name == "id": + del specs[i] + break + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=specs) + + # Write pipeline. + with self.assertRaises(Exception) as context: + with TestPipeline() as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Assert on what should happen. + self.assertIn( + "insert missed an field `id` to collection", + str(context.exception).lower()) + + def test_write_on_auto_id_primary_key(self): + auto_id_collection = f"auto_id_collection_{self._testMethodName}" + auto_id_partition = f"auto_id_partition_{self._testMethodName}" + auto_id_fields = [ + FieldSchema( + name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), + FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=1000), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=3), + FieldSchema( + name="sparse_embedding", dtype=DataType.SPARSE_FLOAT_VECTOR) + ] + + # Create collection with an auto id field. + create_collection_with_partition( + client=self._test_client, + collection_name=auto_id_collection, + partition_name=auto_id_partition, + fields=auto_id_fields) + + test_chunks = [ + Chunk( + id=1, + content=Content(text="Test content without ID"), + embedding=Embedding( + dense_embedding=[0.1, 0.2, 0.3], + sparse_embedding=([1, 2], [0.1, 0.2])), + metadata={"source": "test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=auto_id_collection, + partition_name=auto_id_partition, + write_config=WriteConfig(write_batch_size=1)) + + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + self._test_client.load_collection(auto_id_collection) + result = self._test_client.query( + collection_name=auto_id_collection, + partition_names=[auto_id_partition], + limit=3) + + # Test there is only one item in the result and the ID is not equal to one. + self.assertEqual(len(result), len(test_chunks)) + result_item = dict(result[0]) + self.assertNotEqual(result_item["id"], 1) + + def test_write_on_existent_collection_with_default_schema(self): + test_chunks = MILVUS_INGESTION_IT_CONFIG["corpus"] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=3)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each chunk was written correctly. 
+ result_by_id = {item["id"]: item for item in result} + for chunk in test_chunks: + self.assertIn(chunk.id, result_by_id) + result_item = result_by_id[chunk.id] + self.assertEqual( + result_item["content"], + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content) + self.assertEqual(result_item["metadata"], chunk.metadata) + + # Verify embedding is present and has correct length. + expected_embedding = chunk.embedding.dense_embedding + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + def test_write_with_custom_column_specifications(self): + from apache_beam.ml.rag.ingestion.postgres_common import ColumnSpec + from apache_beam.ml.rag.utils import MilvusHelpers + + custom_column_specs = [ + ColumnSpec("id", int, lambda chunk: int(chunk.id) if chunk.id else 0), + ColumnSpec( + "content", + str, lambda chunk: ( + chunk.content.text + if hasattr(chunk.content, 'text') else chunk.content)), + ColumnSpec("metadata", dict, lambda chunk: chunk.metadata or {}), + ColumnSpec( + "embedding", + list, lambda chunk: chunk.embedding.dense_embedding or []), + ColumnSpec( + "sparse_embedding", + dict, lambda chunk: ( + MilvusHelpers.sparse_embedding( + chunk.embedding.sparse_embedding) if chunk.embedding and + chunk.embedding.sparse_embedding else {})) + ] + + test_chunks = [ + Chunk( + id=10, + content=Content(text="Custom column spec test"), + embedding=Embedding( + dense_embedding=[0.8, 0.9, 1.0], + sparse_embedding=([1, 3, 5], [0.8, 0.9, 1.0])), + metadata={"custom": "spec_test"}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=1)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, + write_config=write_config, + column_specs=custom_column_specs) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify data was written successfully. + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + filter="id == 10", + limit=1) + + self.assertEqual(len(result), 1) + result_item = result[0] + + # Verify custom column specs worked correctly. + self.assertEqual(result_item["id"], 10) + self.assertEqual(result_item["content"], "Custom column spec test") + self.assertEqual(result_item["metadata"], {"custom": "spec_test"}) + + # Verify embedding is present and has correct length. + expected_embedding = [0.8, 0.9, 1.0] + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + # Verify sparse embedding was converted correctly - check keys are present. + expected_sparse_keys = {1, 3, 5} + actual_sparse = result_item["sparse_embedding"] + self.assertIsNotNone(actual_sparse) + self.assertEqual(set(actual_sparse.keys()), expected_sparse_keys) + + def test_write_with_batching(self): + test_chunks = [ + Chunk( + id=i, + content=Content(text=f"Batch test document {i}"), + embedding=Embedding( + dense_embedding=[0.1 * i, 0.2 * i, 0.3 * i], + sparse_embedding=([i, i + 1], [0.1 * i, 0.2 * i])), + metadata={"batch_id": i}) for i in range(1, 8) # 7 chunks + ] + + # Set small batch size to force batching (7 chunks with batch size 3). 
+ batch_write_config = WriteConfig(write_batch_size=3) + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=batch_write_config) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + with self.write_test_pipeline as p: + _ = (p | beam.Create(test_chunks) | config.create_write_transform()) + + # Verify all data was written successfully. + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), len(test_chunks)) + + # Verify each batch was written correctly. + result_by_id = {item["id"]: item for item in result} + for chunk in test_chunks: + self.assertIn(chunk.id, result_by_id) + result_item = result_by_id[chunk.id] + + # Verify content and metadata. + self.assertEqual(result_item["content"], chunk.content.text) + self.assertEqual(result_item["metadata"], chunk.metadata) + + # Verify embeddings are present and have correct length. + expected_embedding = chunk.embedding.dense_embedding + actual_embedding = result_item["embedding"] + self.assertIsNotNone(actual_embedding) + self.assertEqual(len(actual_embedding), len(expected_embedding)) + + # Verify sparse embedding keys are present. + expected_sparse_keys = {chunk.id, chunk.id + 1} + actual_sparse = result_item["sparse_embedding"] + self.assertIsNotNone(actual_sparse) + self.assertEqual(set(actual_sparse.keys()), expected_sparse_keys) + + def test_idempotent_write(self): + # Step 1: Insert initial data that doesn't exist. + initial_chunks = [ + Chunk( + id=100, + content=Content(text="Initial document"), + embedding=Embedding( + dense_embedding=[1.0, 2.0, 3.0], + sparse_embedding=([100, 101], [1.0, 2.0])), + metadata={"version": 1}), + Chunk( + id=200, + content=Content(text="Another initial document"), + embedding=Embedding( + dense_embedding=[2.0, 3.0, 4.0], + sparse_embedding=([200, 201], [2.0, 3.0])), + metadata={"version": 1}) + ] + + write_config = MilvusWriteConfig( + collection_name=self._collection_name, + partition_name=self._partition_name, + write_config=WriteConfig(write_batch_size=2)) + config = MilvusVectorWriterConfig( + connection_params=self._connection_config, write_config=write_config) + + # Insert initial data. + with TestPipeline() as p: + p.not_use_test_runner_api = True + _ = ( + p | "Create initial" >> beam.Create(initial_chunks) + | "Write initial" >> config.create_write_transform()) + + # Verify initial data was inserted (not existed before). + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), 2) + result_by_id = {item["id"]: item for item in result} + + # Verify initial state. + self.assertEqual(result_by_id[100]["content"], "Initial document") + self.assertEqual(result_by_id[100]["metadata"], {"version": 1}) + self.assertEqual(result_by_id[200]["content"], "Another initial document") + self.assertEqual(result_by_id[200]["metadata"], {"version": 1}) + + # Step 2: Update existing data (same IDs, different content). 
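+    # Milvus upserts replace rows that share a primary key, so re-writing the
+    # same IDs should update the rows in place rather than append new ones.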
+ updated_chunks = [ + Chunk( + id=100, + content=Content(text="Updated document"), + embedding=Embedding( + dense_embedding=[1.1, 2.1, 3.1], + sparse_embedding=([100, 102], [1.1, 2.1])), + metadata={"version": 2}), + Chunk( + id=200, + content=Content(text="Another updated document"), + embedding=Embedding( + dense_embedding=[2.1, 3.1, 4.1], + sparse_embedding=([200, 202], [2.1, 3.1])), + metadata={"version": 2}) + ] + + # Perform first update. + with TestPipeline() as p: + p.not_use_test_runner_api = True + _ = ( + p | "Create update1" >> beam.Create(updated_chunks) + | "Write update1" >> config.create_write_transform()) + + # Verify update worked. + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + self.assertEqual(len(result), 2) # Still only 2 records. + result_by_id = {item["id"]: item for item in result} + + # Verify updated state. + self.assertEqual(result_by_id[100]["content"], "Updated document") + self.assertEqual(result_by_id[100]["metadata"], {"version": 2}) + self.assertEqual(result_by_id[200]["content"], "Another updated document") + self.assertEqual(result_by_id[200]["metadata"], {"version": 2}) + + # Step 3: Repeat the same update operation 3 more times (idempotence test). + for i in range(3): + with TestPipeline() as p: + p.not_use_test_runner_api = True + _ = ( + p | f"Create repeat{i+2}" >> beam.Create(updated_chunks) + | f"Write repeat{i+2}" >> config.create_write_transform()) + + # Verify state hasn't changed after repeated updates. + self._test_client.load_collection(self._collection_name) + result = self._test_client.query( + collection_name=self._collection_name, + partition_names=[self._partition_name], + limit=10) + + # Still only 2 records. + self.assertEqual(len(result), 2) + result_by_id = {item["id"]: item for item in result} + + # Final state should remain unchanged. + self.assertEqual(result_by_id[100]["content"], "Updated document") + self.assertEqual(result_by_id[100]["metadata"], {"version": 2}) + self.assertEqual(result_by_id[200]["content"], "Another updated document") + self.assertEqual(result_by_id[200]["metadata"], {"version": 2}) + + # Verify embeddings are still correct. + self.assertIsNotNone(result_by_id[100]["embedding"]) + self.assertEqual(len(result_by_id[100]["embedding"]), 3) + self.assertIsNotNone(result_by_id[200]["embedding"]) + self.assertEqual(len(result_by_id[200]["embedding"]), 3) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py new file mode 100644 index 000000000000..37b05c2e2409 --- /dev/null +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_test.py @@ -0,0 +1,122 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest +from parameterized import parameterized + +try: + from apache_beam.ml.rag.ingestion.milvus_search import ( + MilvusWriteConfig, MilvusVectorWriterConfig) + from apache_beam.ml.rag.utils import MilvusConnectionParameters +except ImportError as e: + raise unittest.SkipTest(f'Milvus dependencies not installed: {str(e)}') + + +class TestMilvusWriteConfig(unittest.TestCase): + """Unit tests for MilvusWriteConfig validation errors.""" + def test_empty_collection_name_raises_error(self): + """Test that empty collection name raises ValueError.""" + with self.assertRaises(ValueError) as context: + MilvusWriteConfig(collection_name="") + + self.assertIn("Collection name must be provided", str(context.exception)) + + def test_none_collection_name_raises_error(self): + """Test that None collection name raises ValueError.""" + with self.assertRaises(ValueError) as context: + MilvusWriteConfig(collection_name=None) # type: ignore[arg-type] + + self.assertIn("Collection name must be provided", str(context.exception)) + + +class TestMilvusVectorWriterConfig(unittest.TestCase): + """Unit tests for MilvusVectorWriterConfig validation and functionality.""" + def test_valid_config_creation(self): + """Test creation of valid MilvusVectorWriterConfig.""" + connection_params = MilvusConnectionParameters(uri="http://localhost:19530") + write_config = MilvusWriteConfig(collection_name="test_collection") + + config = MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + self.assertEqual(config.connection_params, connection_params) + self.assertEqual(config.write_config, write_config) + self.assertIsNotNone(config.column_specs) + + def test_create_converter_returns_callable(self): + """Test that create_converter returns a callable function.""" + connection_params = MilvusConnectionParameters(uri="http://localhost:19530") + write_config = MilvusWriteConfig(collection_name="test_collection") + + config = MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + converter = config.create_converter() + self.assertTrue(callable(converter)) + + def test_create_write_transform_returns_ptransform(self): + """Test that create_write_transform returns a PTransform.""" + connection_params = MilvusConnectionParameters(uri="http://localhost:19530") + write_config = MilvusWriteConfig(collection_name="test_collection") + + config = MilvusVectorWriterConfig( + connection_params=connection_params, write_config=write_config) + + transform = config.create_write_transform() + self.assertIsNotNone(transform) + + def test_default_column_specs_has_expected_fields(self): + """Test that default column specs include expected fields.""" + column_specs = MilvusVectorWriterConfig.default_column_specs() + + self.assertIsInstance(column_specs, list) + self.assertGreater(len(column_specs), 0) + + column_names = [spec.column_name for spec in column_specs] + expected_fields = [ + "id", "embedding", "sparse_embedding", "content", "metadata" + ] + + for field in expected_fields: + self.assertIn(field, column_names) + + @parameterized.expand([ + # Invalid connection parameters - empty URI. + ( + lambda: ( + MilvusConnectionParameters(uri=""), MilvusWriteConfig( + collection_name="test_collection")), + "URI must be provided"), + # Invalid write config - empty collection name. 
+      (
+          lambda: (
+              MilvusConnectionParameters(uri="http://localhost:19530"),
+              MilvusWriteConfig(collection_name="")),
+          "Collection name must be provided"),
+  ])
+  def test_invalid_configuration_parameters(
+      self, create_params, expected_error_msg):
+    """Test validation errors for invalid configuration parameters."""
+    with self.assertRaises(ValueError) as context:
+      connection_params, write_config = create_params()
+      MilvusVectorWriterConfig(
+          connection_params=connection_params, write_config=write_config)
+
+    self.assertIn(expected_error_msg, str(context.exception))
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
index eca740a4e9c3..d07edff83928 100644
--- a/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
+++ b/sdks/python/apache_beam/ml/rag/ingestion/postgres_common.py
@@ -16,7 +16,7 @@
 
 import json
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, Tuple
 from typing import Callable
 from typing import Dict
 from typing import List
@@ -311,6 +311,42 @@ def value_fn(chunk: Chunk) -> Any:
         ColumnSpec.vector(column_name=column_name, value_fn=value_fn))
     return self
 
+  def with_sparse_embedding_spec(
+      self,
+      column_name: str = "sparse_embedding",
+      conv_fn: Optional[Callable[[Tuple[List[int], List[float]]], Any]] = None
+  ) -> 'ColumnSpecsBuilder':
+    """Add sparse embedding :class:`.ColumnSpec` with optional conversion.
+
+    Args:
+      column_name: Name for the sparse embedding column
+        (defaults to "sparse_embedding").
+      conv_fn: Optional function to convert the sparse embedding tuple.
+        If None, converts to a PostgreSQL-compatible JSON format.
+
+    Returns:
+      Self for method chaining.
+
+    Example:
+      >>> builder.with_sparse_embedding_spec(
+      ...     column_name="sparse_vector",
+      ...     conv_fn=lambda sparse: dict(zip(sparse[0], sparse[1]))
+      ... )
+    """
+    def value_fn(chunk: Chunk) -> Any:
+      if chunk.embedding is None or chunk.embedding.sparse_embedding is None:
+        raise ValueError(f'Expected chunk to contain sparse embedding. {chunk}')
+      sparse_embedding = chunk.embedding.sparse_embedding
+      if conv_fn:
+        return conv_fn(sparse_embedding)
+      # Default: convert to dict format for JSON storage.
+      indices, values = sparse_embedding
+      return json.dumps(dict(zip(indices, values)))
+
+    self._specs.append(
+        ColumnSpec.jsonb(column_name=column_name, value_fn=value_fn))
+    return self
+
   def add_metadata_field(
       self,
       field: str,
diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py
new file mode 100644
index 000000000000..9a46f46397eb
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/test_utils.py
@@ -0,0 +1,304 @@
+import contextlib
+from dataclasses import dataclass
+import os
+import socket
+import tempfile
+import logging
+from typing import Dict, List, Optional
+from testcontainers.core.config import testcontainers_config
+from testcontainers.core.generic import DbContainer
+from testcontainers.milvus import MilvusContainer
+import yaml
+
+from apache_beam.ml.rag.types import Chunk
+
+_LOGGER = logging.getLogger(__name__)
+
+
+@dataclass
+class VectorDBContainerInfo:
+  """Container information for vector database test instances.
+
+  Holds connection details and container reference for testing with
+  vector databases like Milvus in containerized environments.
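+
+  Example (illustrative values; assumes a started container):
+
+      info = VectorDBContainerInfo(container, host="localhost", port=19530)
+      info.uri  # -> "http://localhost:19530"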
+  """
+  container: DbContainer
+  host: str
+  port: int
+  user: str = ""
+  password: str = ""
+  token: str = ""
+  id: str = "default"
+
+  @property
+  def uri(self) -> str:
+    return f"http://{self.host}:{self.port}"
+
+
+class TestHelpers:
+  @staticmethod
+  def find_free_port():
+    """Find a free port on the local machine."""
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+      # Bind to port 0, which asks the OS to assign a free port.
+      s.bind(('', 0))
+      s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+      # Return the port number assigned by the OS.
+      return s.getsockname()[1]
+
+
+class CustomMilvusContainer(MilvusContainer):
+  """Custom Milvus container with configurable ports and environment setup.
+
+  Extends MilvusContainer to provide custom port binding and environment
+  configuration for testing with standalone Milvus instances.
+  """
+  def __init__(
+      self,
+      image: str,
+      service_container_port,
+      healthcheck_container_port,
+      **kwargs,
+  ) -> None:
+    # Skip the parent class's constructor and go straight to
+    # GenericContainer.
+    super(MilvusContainer, self).__init__(image=image, **kwargs)
+    self.port = service_container_port
+    self.healthcheck_port = healthcheck_container_port
+    self.with_exposed_ports(service_container_port, healthcheck_container_port)
+
+    # Get free host ports.
+    service_host_port = TestHelpers.find_free_port()
+    healthcheck_host_port = TestHelpers.find_free_port()
+
+    # Bind container and host ports.
+    self.with_bind_ports(service_container_port, service_host_port)
+    self.with_bind_ports(healthcheck_container_port, healthcheck_host_port)
+    self.cmd = "milvus run standalone"
+
+    # Set environment variables needed for Milvus.
+    envs = {
+        "ETCD_USE_EMBED": "true",
+        "ETCD_DATA_DIR": "/var/lib/milvus/etcd",
+        "COMMON_STORAGETYPE": "local",
+        "METRICS_PORT": str(healthcheck_container_port)
+    }
+    for env, value in envs.items():
+      self.with_env(env, value)
+
+
+class MilvusTestHelpers:
+  """Helper utilities for testing Milvus vector database operations.
+
+  Provides static methods for managing test containers, configuration files,
+  and chunk comparison utilities for Milvus-based integration tests.
+  """
+  # IMPORTANT: When upgrading the Milvus server version, ensure the pymilvus
+  # Python SDK client in setup.py is updated to match. Refer to the Milvus
+  # release notes compatibility matrix at
+  # https://milvus.io/docs/release_notes.md or PyPI at
+  # https://pypi.org/project/pymilvus/ for version compatibility.
+  # Example: Milvus v2.6.0 requires pymilvus==2.6.0 (exact match required).
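+  #
+  # Typical lifecycle (illustrative sketch):
+  #
+  #   info = MilvusTestHelpers.start_db_container()
+  #   try:
+  #     ...  # exercise the pipeline under test against info.uri
+  #   finally:
+  #     MilvusTestHelpers.stop_db_container(info)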
+  @staticmethod
+  def start_db_container(
+      image="milvusdb/milvus:v2.5.10",
+      max_vec_fields=5,
+      vector_client_max_retries=3,
+      tc_max_retries=None) -> Optional[VectorDBContainerInfo]:
+    service_container_port = TestHelpers.find_free_port()
+    healthcheck_container_port = TestHelpers.find_free_port()
+    user_yaml_creator = MilvusTestHelpers.create_user_yaml
+    with user_yaml_creator(service_container_port, max_vec_fields) as cfg:
+      info = None
+      original_tc_max_tries = testcontainers_config.max_tries
+      if tc_max_retries is not None:
+        testcontainers_config.max_tries = tc_max_retries
+      for i in range(vector_client_max_retries):
+        try:
+          vector_db_container = CustomMilvusContainer(
+              image=image,
+              service_container_port=service_container_port,
+              healthcheck_container_port=healthcheck_container_port)
+          vector_db_container = vector_db_container.with_volume_mapping(
+              cfg, "/milvus/configs/user.yaml")
+          vector_db_container.start()
+          host = vector_db_container.get_container_host_ip()
+          port = vector_db_container.get_exposed_port(service_container_port)
+          info = VectorDBContainerInfo(vector_db_container, host, port)
+          _LOGGER.info(
+              "milvus db container started successfully on %s.", info.uri)
+          break
+        except Exception as e:
+          stdout_logs, stderr_logs = vector_db_container.get_logs()
+          stdout_logs = stdout_logs.decode("utf-8")
+          stderr_logs = stderr_logs.decode("utf-8")
+          _LOGGER.warning(
+              "Retry %d/%d: Failed to start Milvus DB container. Reason: %s. "
+              "STDOUT logs:\n%s\nSTDERR logs:\n%s",
+              i + 1,
+              vector_client_max_retries,
+              e,
+              stdout_logs,
+              stderr_logs)
+          if i == vector_client_max_retries - 1:
+            _LOGGER.error(
+                "Unable to start milvus db container for I/O tests after %d "
+                "retries. Tests cannot proceed. STDOUT logs:\n%s\n"
+                "STDERR logs:\n%s",
+                vector_client_max_retries,
+                stdout_logs,
+                stderr_logs)
+            raise e
+        finally:
+          testcontainers_config.max_tries = original_tc_max_tries
+      return info
+
+  @staticmethod
+  def stop_db_container(db_info: VectorDBContainerInfo):
+    if db_info is None:
+      _LOGGER.warning("Milvus db info is None. Skipping stop operation.")
+      return
+    _LOGGER.debug("Stopping milvus db container.")
+    db_info.container.stop()
+    _LOGGER.info("milvus db container stopped successfully.")
+
+  @staticmethod
+  @contextlib.contextmanager
+  def create_user_yaml(service_port: int, max_vector_field_num=5):
+    """Creates a temporary user.yaml file for Milvus configuration.
+
+    This user yaml file overrides Milvus default configurations. It sets
+    the Milvus service port to the specified container service port. The
+    default for maxVectorFieldNum is 4, but we need 5
+    (one unique field for each metric).
+
+    Args:
+      service_port: Port number for the Milvus service.
+      max_vector_field_num: Max number of vec fields allowed per collection.
+
+    Yields:
+      str: Path to the created temporary yaml file.
+    """
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml',
+                                     delete=False) as temp_file:
+      # Define the content for user.yaml.
+      user_config = {
+          'proxy': {
+              'maxVectorFieldNum': max_vector_field_num, 'port': service_port
+          },
+          'etcd': {
+              'use': {
+                  'embed': True
+              }, 'data': {
+                  'dir': '/var/lib/milvus/etcd'
+              }
+          }
+      }
+
+      # Write the content to the file.
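+      # yaml.dump sorts keys alphabetically, so the rendered user.yaml looks
+      # roughly like this (illustrative):
+      #   etcd:
+      #     data:
+      #       dir: /var/lib/milvus/etcd
+      #     use:
+      #       embed: true
+      #   proxy:
+      #     maxVectorFieldNum: 5
+      #     port: <service_port>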
+      yaml.dump(user_config, temp_file, default_flow_style=False)
+      path = temp_file.name
+
+    try:
+      yield path
+    finally:
+      if os.path.exists(path):
+        os.remove(path)
+
+  @staticmethod
+  def assert_chunks_equivalent(
+      actual_chunks: List[Chunk], expected_chunks: List[Chunk]):
+    """Asserts chunk equivalence, checking presence rather than exact match."""
+    # Sort both lists by ID to ensure consistent ordering.
+    actual_sorted = sorted(actual_chunks, key=lambda c: c.id)
+    expected_sorted = sorted(expected_chunks, key=lambda c: c.id)
+
+    actual_len = len(actual_sorted)
+    expected_len = len(expected_sorted)
+    err_msg = (
+        f"Different number of chunks, actual: {actual_len}, "
+        f"expected: {expected_len}")
+    assert actual_len == expected_len, err_msg
+
+    for actual, expected in zip(actual_sorted, expected_sorted):
+      # Assert that IDs match.
+      assert actual.id == expected.id
+
+      # Assert that dense embeddings match.
+      err_msg = f"Dense embedding mismatch for chunk {actual.id}"
+      assert actual.dense_embedding == expected.dense_embedding, err_msg
+
+      # Assert that sparse embeddings match.
+      err_msg = f"Sparse embedding mismatch for chunk {actual.id}"
+      assert actual.sparse_embedding == expected.sparse_embedding, err_msg
+
+      # Assert that text content matches.
+      err_msg = f"Text content mismatch for chunk {actual.id}"
+      assert actual.content.text == expected.content.text, err_msg
+
+      # For enrichment_data, be more flexible.
+      # If "expected" has values for enrichment_data but actual doesn't, that's
+      # acceptable since vector search results can vary based on many factors
+      # including implementation details, vector database state, and slight
+      # variations in similarity calculations.
+
+      # First ensure the enrichment data key exists.
+      err_msg = f"Missing enrichment_data key in chunk {actual.id}"
+      assert 'enrichment_data' in actual.metadata, err_msg
+
+      # For enrichment_data, ensure consistent ordering of results.
+      actual_data = actual.metadata['enrichment_data']
+      expected_data = expected.metadata['enrichment_data']
+
+      # If actual has enrichment data, then perform detailed validation.
+      if actual_data:
+        # Ensure the id key exists.
+        err_msg = f"Missing id key in metadata {actual.id}"
+        assert 'id' in actual_data, err_msg
+
+        # Validate IDs have consistent ordering.
+        actual_ids = sorted(actual_data['id'])
+        expected_ids = sorted(expected_data['id'])
+        err_msg = f"IDs in enrichment_data don't match for chunk {actual.id}"
+        assert actual_ids == expected_ids, err_msg
+
+        # Ensure the distance key exists.
+        err_msg = f"Missing distance key in metadata {actual.id}"
+        assert 'distance' in actual_data, err_msg
+
+        # Validate distances are present with the expected cardinality.
+        actual_distances = actual_data['distance']
+        expected_distances = expected_data['distance']
+        err_msg = (
+            "Number of distances doesn't match expected count for "
+            f"chunk {actual.id}")
+        assert len(actual_distances) == len(expected_distances), err_msg
+
+        # Ensure the fields key exists.
+        err_msg = f"Missing fields key in metadata {actual.id}"
+        assert 'fields' in actual_data, err_msg
+
+        # Validate fields have consistent content.
+        # Sort fields by 'id' to ensure consistent ordering.
+        actual_fields_sorted = sorted(
+            actual_data['fields'], key=lambda f: f.get('id', 0))
+        expected_fields_sorted = sorted(
+            expected_data['fields'], key=lambda f: f.get('id', 0))
+
+        # Compare field IDs.
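+        # Field IDs are compared before per-field payloads so that an
+        # ordering or membership mismatch fails fast with a concise message.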
+        actual_field_ids = [f.get('id') for f in actual_fields_sorted]
+        expected_field_ids = [f.get('id') for f in expected_fields_sorted]
+        err_msg = f"Field IDs don't match for chunk {actual.id}"
+        assert actual_field_ids == expected_field_ids, err_msg
+
+        # Compare field content.
+        for a_f, e_f in zip(actual_fields_sorted, expected_fields_sorted):
+          # Ensure the id key exists.
+          err_msg = f"Missing id key in metadata.fields {actual.id}"
+          assert 'id' in a_f, err_msg
+
+          err_msg = f"Field ID mismatch for chunk {actual.id}"
+          assert a_f['id'] == e_f['id'], err_msg
+
+          # Validate field metadata.
+          err_msg = f"Field metadata doesn't match for chunk {actual.id}"
+          assert a_f['metadata'] == e_f['metadata'], err_msg
diff --git a/sdks/python/apache_beam/ml/rag/utils.py b/sdks/python/apache_beam/ml/rag/utils.py
new file mode 100644
index 000000000000..c9c39d074c4a
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/utils.py
@@ -0,0 +1,129 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import defaultdict
+from dataclasses import dataclass, field
+import re
+from typing import Any, Dict, List, Optional, Tuple
+import uuid
+
+from apache_beam.ml.rag.types import Chunk, Content, Embedding
+
+# Default batch size for writing data to Milvus, matching
+# JdbcIO.DEFAULT_BATCH_SIZE.
+DEFAULT_WRITE_BATCH_SIZE = 1000
+
+
+@dataclass
+class MilvusConnectionParameters:
+  """Configurations for establishing connections to Milvus servers.
+
+  Args:
+    uri: URI endpoint for connecting to Milvus server in the format
+      "http(s)://hostname:port".
+    user: Username for authentication. Required if authentication is enabled
+      and not using token authentication.
+    password: Password for authentication. Required if authentication is
+      enabled and not using token authentication.
+    db_name: Database name to connect to. Specifies which Milvus database to
+      use. Defaults to 'default'.
+    token: Authentication token as an alternative to username/password.
+    timeout: Connection timeout in seconds. Uses client default if None.
+    kwargs: Optional keyword arguments for additional connection parameters.
+      Enables forward compatibility.
+  """
+  uri: str
+  user: str = field(default_factory=str)
+  password: str = field(default_factory=str)
+  db_name: str = "default"
+  token: str = field(default_factory=str)
+  timeout: Optional[float] = None
+  kwargs: Dict[str, Any] = field(default_factory=dict)
+
+  def __post_init__(self):
+    if not self.uri:
+      raise ValueError("URI must be provided for Milvus connection")
+
+    # Generate unique alias if not provided. One-to-one mapping between alias
+    # and connection - each alias represents exactly one Milvus connection.
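+    # A generated alias looks like "milvus_conn_1a2b3c4d" (illustrative).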
+    if "alias" not in self.kwargs:
+      alias = f"milvus_conn_{uuid.uuid4().hex[:8]}"
+      self.kwargs["alias"] = alias
+
+
+class MilvusHelpers:
+  """Utility class providing helper methods for Milvus vector db operations."""
+  @staticmethod
+  def sparse_embedding(
+      sparse_vector: Tuple[List[int],
+                           List[float]]) -> Optional[Dict[int, float]]:
+    if not sparse_vector:
+      return None
+    # Converts sparse embedding from (indices, values) tuple format to
+    # Milvus-compatible values dict format {dimension_index: value, ...}.
+    indices, values = sparse_vector
+    return {int(idx): float(val) for idx, val in zip(indices, values)}
+
+
+def parse_chunk_strings(chunk_str_list: List[str]) -> List[Chunk]:
+  parsed_chunks = []
+
+  # Define safe globals and disable built-in functions for safety.
+  safe_globals = {
+      'Chunk': Chunk,
+      'Content': Content,
+      'Embedding': Embedding,
+      'defaultdict': defaultdict,
+      'list': list,
+      '__builtins__': {}
+  }
+
+  for raw_str in chunk_str_list:
+    try:
+      # Replace "defaultdict(<class 'list'>" with an actual list reference
+      # so the chunk repr can be evaluated.
+      cleaned_str = re.sub(
+          r"defaultdict\(<class 'list'>", "defaultdict(list", raw_str)
+
+      # Evaluate string in restricted environment.
+      chunk = eval(cleaned_str, safe_globals)  # pylint: disable=eval-used
+      if isinstance(chunk, Chunk):
+        parsed_chunks.append(chunk)
+      else:
+        raise ValueError("Parsed object is not a Chunk instance")
+    except Exception as e:
+      raise ValueError(f"Error parsing string:\n{raw_str}\n{e}") from e
+
+  return parsed_chunks
+
+
+def unpack_dataclass_with_kwargs(dataclass_instance):
+  """Unpacks dataclass fields into a flat dict, merging kwargs with precedence.
+
+  Args:
+    dataclass_instance: Dataclass instance to unpack.
+
+  Returns:
+    dict: Flattened dictionary with kwargs taking precedence over fields.
+  """
+  # Create a copy of the dataclass's __dict__.
+  params_dict: dict = dataclass_instance.__dict__.copy()
+
+  # Extract the nested kwargs dictionary.
+  nested_kwargs = params_dict.pop('kwargs', {})
+
+  # Merge the dictionaries, with nested_kwargs taking precedence
+  # in case of duplicate keys.
+  return {**params_dict, **nested_kwargs}

From f3a0b880a8769947f21877aa3ad69d5901e4b56b Mon Sep 17 00:00:00 2001
From: Mohamed Awnallah
Date: Fri, 31 Oct 2025 16:41:16 +0000
Subject: [PATCH 05/35] sdks/python: fix linting issues

---
 .../apache_beam/ml/rag/enrichment/milvus_search.py |  1 -
 .../ml/rag/enrichment/milvus_search_it_test.py     | 11 -----------
 .../apache_beam/ml/rag/ingestion/milvus_search.py  |  4 +---
 .../ml/rag/ingestion/milvus_search_it_test.py      |  6 +++++-
 sdks/python/apache_beam/ml/rag/test_utils.py       |  2 +-
 5 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
index d488c8d3d80d..7a0c38d6d90e 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
@@ -25,7 +25,6 @@
 from typing import Optional
 from typing import Tuple
 from typing import Union
-import uuid
 
 from google.protobuf.json_format import MessageToDict
 from pymilvus import AnnSearchRequest
diff --git a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
index 094788664bdb..ed6f52e004fa 100644
--- a/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
+++ b/sdks/python/apache_beam/ml/rag/enrichment/milvus_search_it_test.py
@@ -15,25 +15,17 @@
 # limitations under the License.
# -import contextlib import logging -import os import platform -import re -import socket -import tempfile import unittest -from collections import defaultdict from dataclasses import dataclass from dataclasses import field from typing import Callable from typing import Dict from typing import List -from typing import Optional from typing import cast import pytest -import yaml import apache_beam as beam from apache_beam.ml.rag.types import Chunk @@ -53,9 +45,6 @@ MilvusClient, RRFRanker) from pymilvus.milvus_client import IndexParams - from testcontainers.core.config import testcontainers_config - from testcontainers.core.generic import DbContainer - from testcontainers.milvus import MilvusContainer from apache_beam.transforms.enrichment import Enrichment from apache_beam.ml.rag.test_utils import ( MilvusTestHelpers, VectorDBContainerInfo) diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py index 041349efeb77..e93dbbef776f 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search.py @@ -15,7 +15,7 @@ # limitations under the License. from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List, NamedTuple, Optional +from typing import Any, Callable, Dict, List, Optional from pymilvus import MilvusClient @@ -29,7 +29,6 @@ from apache_beam.ml.rag.types import Chunk from apache_beam.ml.rag.utils import ( MilvusHelpers, unpack_dataclass_with_kwargs, DEFAULT_WRITE_BATCH_SIZE) -from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs from apache_beam.transforms import DoFn from apache_beam.ml.rag.utils import MilvusConnectionParameters @@ -117,7 +116,6 @@ def create_converter(self) -> Callable[[Chunk], Dict[str, Any]]: A function that takes a Chunk and returns a dictionary representing a Milvus record with fields mapped according to column_specs. """ - """Creates a function to convert Chunks to records.""" def convert(chunk: Chunk) -> Dict[str, Any]: result = {} for col in self.column_specs: diff --git a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py index f8f01d9d5964..b2871ce431ef 100644 --- a/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py +++ b/sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py @@ -112,7 +112,11 @@ def create_collection_with_partition( client: MilvusClient, collection_name: str, partition_name: str = '', - fields=MILVUS_INGESTION_IT_CONFIG["fields"]): + fields=None): + + if fields is None: + fields = MILVUS_INGESTION_IT_CONFIG["fields"] + # Configure schema. 
   schema = CollectionSchema(fields=fields)
diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py
index 9a46f46397eb..4e87f2e884a1 100644
--- a/sdks/python/apache_beam/ml/rag/test_utils.py
+++ b/sdks/python/apache_beam/ml/rag/test_utils.py
@@ -4,7 +4,7 @@
 import socket
 import tempfile
 import logging
-from typing import Dict, List, Optional
+from typing import List, Optional
 from testcontainers.core.config import testcontainers_config
 from testcontainers.core.generic import DbContainer
 from testcontainers.milvus import MilvusContainer

From 4cbe014851de5db9280930d5c6366690df24e0e3 Mon Sep 17 00:00:00 2001
From: Mohamed Awnallah
Date: Fri, 31 Oct 2025 16:42:13 +0000
Subject: [PATCH 06/35] sdks/python: add missing apache beam license header
 for `test_utils.py`

---
 sdks/python/apache_beam/ml/rag/test_utils.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/sdks/python/apache_beam/ml/rag/test_utils.py b/sdks/python/apache_beam/ml/rag/test_utils.py
index 4e87f2e884a1..325babdc7037 100644
--- a/sdks/python/apache_beam/ml/rag/test_utils.py
+++ b/sdks/python/apache_beam/ml/rag/test_utils.py
@@ -1,3 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# + import contextlib from dataclasses import dataclass import os From 461c8fee9d1d4b63b63558d188f88f3e79856309 Mon Sep 17 00:00:00 2001 From: Mohamed Awnallah Date: Fri, 31 Oct 2025 17:18:21 +0000 Subject: [PATCH 07/35] notebooks/beam-ml: use new refactored code in milvus enrichment handler --- .../beam-ml/milvus_enrichment_transform.ipynb | 338 +++++++++++++----- 1 file changed, 243 insertions(+), 95 deletions(-) diff --git a/examples/notebooks/beam-ml/milvus_enrichment_transform.ipynb b/examples/notebooks/beam-ml/milvus_enrichment_transform.ipynb index 2dbd038f3086..113038e56984 100644 --- a/examples/notebooks/beam-ml/milvus_enrichment_transform.ipynb +++ b/examples/notebooks/beam-ml/milvus_enrichment_transform.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 14, "id": "47053bac", "metadata": {}, "outputs": [], @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 15, "id": "e550cd55-e91e-4d43-b1bd-b0e89bb8cbd9", "metadata": {}, "outputs": [], @@ -80,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 16, "id": "31747c45-107a-49be-8885-5a6cc9dc1236", "metadata": {}, "outputs": [ @@ -88,9 +88,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: There was an error checking the latest version of pip.\u001b[0m\u001b[33m\n", - "\u001b[0m" + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.3\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.3\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" ] } ], @@ -103,19 +106,10 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, "id": "666e0c2b-0341-4b0e-8d73-561abc39bb10", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dev/beam/sdks/python/.venv/lib/python3.9/site-packages/pydantic/_internal/_generate_schema.py:2249: UnsupportedFieldAttributeWarning: The 'validate_default' attribute with value True was provided to the `Field()` function, which has no effect in the context it was used. 'validate_default' is field-specific metadata, and can only be attached to a model field using `Annotated` metadata or by assignment. 
This may have happened because an `Annotated` type alias using the `type` statement was used, or if the `Field()` function was attached to a single member of a union type.\n", - " warnings.warn(\n" - ] - } - ], + "outputs": [], "source": [ "# Standard library imports\n", "from collections import defaultdict\n", @@ -149,13 +143,13 @@ "from apache_beam.ml.rag.types import Chunk, Content, Embedding\n", "from apache_beam.ml.rag.chunking.base import ChunkingTransformProvider\n", "from apache_beam.ml.rag.embeddings.huggingface import HuggingfaceTextEmbeddings\n", - "from apache_beam.ml.rag.enrichment.milvus_search_it_test import MilvusEnrichmentTestHelper\n", + "from apache_beam.ml.rag.enrichment.milvus_search_it_test import MilvusTestHelpers\n", + "from apache_beam.ml.rag.utils import MilvusConnectionParameters\n", "from apache_beam.ml.rag.enrichment.milvus_search import (\n", " HybridSearchParameters, \n", " KeywordSearchMetrics, \n", " KeywordSearchParameters,\n", " MilvusCollectionLoadParameters, \n", - " MilvusConnectionParameters, \n", " MilvusSearchEnrichmentHandler,\n", " MilvusSearchParameters, \n", " SearchStrategy, \n", @@ -194,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 18, "id": "38781cf5-e18f-40f5-827e-2d441ae7d2fa", "metadata": {}, "outputs": [], @@ -287,7 +281,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 19, "id": "489e93b6-de41-4ec3-be33-a15c3cba12e8", "metadata": {}, "outputs": [ @@ -364,7 +358,7 @@ "max 312.000000" ] }, - "execution_count": 6, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -379,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "id": "eb32aad0-febd-45af-b4bd-e2176b07e2dc", "metadata": {}, "outputs": [ @@ -424,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 21, "id": "5ae9bc82-9ad7-46dd-b254-19cbdcdd0e07", "metadata": {}, "outputs": [], @@ -435,30 +429,30 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 22, "id": "aff7b261-3330-4fa9-9a54-3fd87b42521f", "metadata": {}, "outputs": [], "source": [ "if db:\n", " # Stop existing Milvus DB container to prevent duplicates.\n", - " MilvusEnrichmentTestHelper.stop_db_container(db)\n", - "db = MilvusEnrichmentTestHelper.start_db_container(milvus_version)" + " MilvusTestHelpers.stop_db_container(db)\n", + "db = MilvusTestHelpers.start_db_container(milvus_version)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 26, "id": "31496ee0-75a2-48ad-954e-9c4ae5abbf5e", "metadata": {}, "outputs": [], "source": [ - "milvus_connection_parameters = MilvusConnectionParameters(uri=db.uri, user=db.user, password=db.password, db_id=db.id)" + "milvus_connection_parameters = MilvusConnectionParameters(uri=db.uri, user=db.user, password=db.password, db_name=db.id)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 27, "id": "82627714-2425-4058-9b47-d262f015caf7", "metadata": {}, "outputs": [], @@ -468,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 28, "id": "e8a85f51-5d5f-4533-bf0f-ec825e613dc2", "metadata": {}, "outputs": [ @@ -478,7 +472,7 @@ "'2.5.10'" ] }, - "execution_count": 12, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -505,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 29, "id": "e3847821-069c-412f-8c20-2406bcac1e55", "metadata": {}, "outputs": [], @@ -520,7 +514,7 @@ }, { 
"cell_type": "code", - "execution_count": 14, + "execution_count": 30, "id": "c014af94-1bb7-44e4-842c-1039f4a2a11d", "metadata": {}, "outputs": [], @@ -545,7 +539,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 31, "id": "54fb3428-b007-4804-9d79-b3933d3256c5", "metadata": {}, "outputs": [], @@ -561,7 +555,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 32, "id": "4c2f123a-5949-4974-af48-a5db5b168c11", "metadata": {}, "outputs": [ @@ -571,7 +565,7 @@ "{'auto_id': False, 'description': '', 'fields': [{'name': 'id', 'description': '', 'type': , 'params': {'max_length': 100}, 'is_primary': True, 'auto_id': False}, {'name': 'content', 'description': '', 'type': , 'params': {'max_length': 65279}}, {'name': 'embedding', 'description': '', 'type': , 'params': {'dim': 384}}, {'name': 'sparse_embedding', 'description': '', 'type': , 'is_function_output': True}, {'name': 'metadata', 'description': '', 'type': }, {'name': 'title_and_content', 'description': '', 'type': , 'params': {'max_length': 65535, 'enable_analyzer': True}}], 'enable_dynamic_field': False, 'functions': [{'name': 'content_bm25_emb', 'description': '', 'type': , 'input_field_names': ['title_and_content'], 'output_field_names': ['sparse_embedding'], 'params': {}}]}" ] }, - "execution_count": 16, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" } @@ -591,7 +585,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 33, "id": "671f4352-2086-4428-83be-0de48926682d", "metadata": {}, "outputs": [], @@ -609,7 +603,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 34, "id": "aa8baae5-7c38-4e78-ace4-304c7dc6b127", "metadata": {}, "outputs": [], @@ -632,7 +626,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 35, "id": "d970a35b-f9b2-4f8f-93ef-8de5c83c31b5", "metadata": {}, "outputs": [], @@ -647,7 +641,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 36, "id": "0d45a6ad-2009-4e30-b38d-73266da98a06", "metadata": {}, "outputs": [ @@ -658,7 +652,7 @@ " {'field_name': 'sparse_embedding', 'index_type': 'SPARSE_INVERTED_INDEX', 'index_name': 'sparse_inverted_index', 'inverted_index_algo': 'DAAT_MAXSCORE', 'bm25_k1': 1.2, 'bm25_b': 0.75, 'metric_type': 'BM25'}]" ] }, - "execution_count": 20, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -677,7 +671,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 37, "id": "51dd4423-240c-4271-bb8c-6270f399a25c", "metadata": {}, "outputs": [], @@ -687,7 +681,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 38, "id": "9620b1f2-51fa-491c-ad3f-f0676b9b25f6", "metadata": {}, "outputs": [], @@ -697,7 +691,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 39, "id": "e6cf3a1d-265c-44db-aba8-d491fab290d5", "metadata": {}, "outputs": [], @@ -707,7 +701,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 40, "id": "94497411-43d3-4300-98b3-1cb33759738e", "metadata": {}, "outputs": [ @@ -717,7 +711,7 @@ "True" ] }, - "execution_count": 24, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -736,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 41, "id": "25c5c202-abe0-4d11-82df-e731f0d6201e", "metadata": { "scrolled": true @@ -783,6 +777,160 @@ "WARNING:root:This output type hint will be ignored and not used for type-checking purposes. 
Typically, output type hints for a PTransform are single (or nested) types wrapped by a PCollection, PDone, or None. Got: Union[Tuple[apache_beam.pvalue.PCollection[~MLTransformOutputT], apache_beam.pvalue.PCollection[apache_beam.pvalue.Row]], apache_beam.pvalue.PCollection[~MLTransformOutputT]] instead.\n" ] }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fb92c794ace141d6a6673d8cb5cffc54", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "modules.json: 0%| | 0.00/349 [00:00\n", - "
\n", + "
\n", "
\n", " Processing... show\n", "
\n", @@ -830,7 +978,7 @@ " }\n", " \n", " \n", - "
\n", + "
\n", "