Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
5df9628
sdks/python: replace the deprecated testcontainer max tries
mohamedawnallah Oct 30, 2025
91266a7
sdks/python: handle transient testcontainer startup/teardown errors
mohamedawnallah Oct 30, 2025
fa6d2f0
sdks/python: bump `testcontainers` py pkg version
mohamedawnallah Oct 31, 2025
9445aaa
sdks/python: integrate milvus sink I/O
mohamedawnallah Oct 31, 2025
f3a0b88
sdks/python: fix linting issues
mohamedawnallah Oct 31, 2025
4cbe014
sdks/python: add missing Apache Beam license header for `test_utils.py`
mohamedawnallah Oct 31, 2025
461c8fe
notebooks/beam-ml: use new refactored code in milvus enrichment handler
mohamedawnallah Oct 31, 2025
9d41879
CHANGES.md: update release notes
mohamedawnallah Oct 31, 2025
c64e9c9
sdks/python: mark milvus itests with `require_docker_in_docker` marker
mohamedawnallah Oct 31, 2025
825bf30
sdks/python: override milvus db version with the default
mohamedawnallah Oct 31, 2025
e6569ba
sdks/python: add missing import in rag utils
mohamedawnallah Oct 31, 2025
281ea3f
sdks/python: fix linting issue
mohamedawnallah Oct 31, 2025
5a350b5
rag/ingestion/milvus_search_itest.py: ensure flushing in-memory data …
mohamedawnallah Nov 1, 2025
ee64600
sdks/python: fix linting issues
mohamedawnallah Nov 1, 2025
dab040a
sdks/python: fix formatting issues
mohamedawnallah Nov 1, 2025
deef266
sdks/python: fix arising linting issue
mohamedawnallah Nov 2, 2025
b4e31e8
rag: reuse `retry_with_backoff` for one-time setup operations
mohamedawnallah Nov 2, 2025
795ed60
sdks/python: fix linting issues
mohamedawnallah Nov 2, 2025
9d35585
sdks/python: fix py docs CI issue
mohamedawnallah Nov 2, 2025
119108f
sdks/python: fix linting issues
mohamedawnallah Nov 2, 2025
cfc44f6
sdks/python: fix linting issues
mohamedawnallah Nov 2, 2025
599c7f4
sdks/python: isolate milvus sink integration to be in follow-up PR
mohamedawnallah Nov 5, 2025
2ba2b33
CHANGES.md: remove milvus from release notes in the refactoring PR
mohamedawnallah Nov 5, 2025
894ab28
sdks/python: remove `with_sparse_embedding_spec` column specs builder
mohamedawnallah Nov 5, 2025
21ce084
sdks/python: fix linting issues
mohamedawnallah Nov 5, 2025
ffa5d2b
Revert "notebooks/beam-ml: use new refactored code in milvus enrichme…
mohamedawnallah Nov 6, 2025
732ae31
Merge remote-tracking branch 'upstream/master' into sinkWithMilvusIO
mohamedawnallah Nov 6, 2025
0c00044
sdks/python: fix linting issues
mohamedawnallah Nov 6, 2025
c9f4b6c
sdks/python: fix linting issues
mohamedawnallah Nov 6, 2025
17ba353
sdks/python: fix linting issues
mohamedawnallah Nov 6, 2025
90c43ea
Merge remote-tracking branch 'upstream/master' into sinkWithMilvusIO
mohamedawnallah Nov 6, 2025
560f926
sdks/python: fix linting issues
mohamedawnallah Nov 6, 2025
f72c2e1
CI: fix import errors in CI
mohamedawnallah Nov 6, 2025
0ef2c7d
sdks/python: fix linting issues
mohamedawnallah Nov 6, 2025
690fd72
sdks/python: fix linting issues
mohamedawnallah Nov 6, 2025
f01a7e5
sdks/python: fix linting issues
mohamedawnallah Nov 10, 2025
3145561
sdks/python: fix linting issues
mohamedawnallah Nov 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,11 @@
ConnectionConfig,
CloudSQLConnectionConfig,
ExternalSQLDBConnectionConfig)
from apache_beam.ml.rag.enrichment.milvus_search import (
MilvusConnectionParameters)
from apache_beam.ml.rag.enrichment.milvus_search_it_test import (
MilvusEnrichmentTestHelper,
MilvusDBContainerInfo,
parse_chunk_strings,
assert_chunks_equivalent)
from apache_beam.ml.rag.enrichment.milvus_search import MilvusConnectionParameters
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
from apache_beam.ml.rag.test_utils import VectorDBContainerInfo
from apache_beam.ml.rag.test_utils import MilvusTestHelpers
from apache_beam.ml.rag.utils import parse_chunk_strings
from apache_beam.io.requestresponse import RequestResponseIO
except ImportError as e:
raise unittest.SkipTest(f'Examples dependencies are not installed: {str(e)}')
Expand All @@ -69,6 +67,11 @@ class TestContainerStartupError(Exception):
pass


class TestContainerTeardownError(Exception):
  """Raised when a test container fails to tear down.

  The enrichment tests catch this error alongside the container startup
  error and convert it into a skipped test rather than a failure.
  """
  pass


def validate_enrichment_with_bigtable():
expected = '''[START enrichment_with_bigtable]
Row(sale_id=1, customer_id=1, product_id=1, quantity=1, product={'product_id': '1', 'product_name': 'pixel 5', 'product_stock': '2'})
Expand Down Expand Up @@ -186,7 +189,7 @@ def test_enrichment_with_external_pg(self, mock_stdout):
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_external_pg()
self.assertEqual(output, expected)
except TestContainerStartupError as e:
except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
Expand All @@ -199,7 +202,7 @@ def test_enrichment_with_external_mysql(self, mock_stdout):
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_external_mysql()
self.assertEqual(output, expected)
except TestContainerStartupError as e:
except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
Expand All @@ -212,7 +215,7 @@ def test_enrichment_with_external_sqlserver(self, mock_stdout):
output = mock_stdout.getvalue().splitlines()
expected = validate_enrichment_with_external_sqlserver()
self.assertEqual(output, expected)
except TestContainerStartupError as e:
except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
Expand All @@ -226,8 +229,8 @@ def test_enrichment_with_milvus(self, mock_stdout):
self.maxDiff = None
output = parse_chunk_strings(output)
expected = parse_chunk_strings(expected)
assert_chunks_equivalent(output, expected)
except TestContainerStartupError as e:
MilvusTestHelpers.assert_chunks_equivalent(output, expected)
except (TestContainerStartupError, TestContainerTeardownError) as e:
raise unittest.SkipTest(str(e))
except Exception as e:
self.fail(f"Test failed with unexpected error: {e}")
Expand Down Expand Up @@ -257,7 +260,7 @@ def sql_test_context(is_cloudsql: bool, db_adapter: DatabaseTypeAdapter):
@staticmethod
@contextmanager
def milvus_test_context():
db: Optional[MilvusDBContainerInfo] = None
db: Optional[VectorDBContainerInfo] = None
try:
db = EnrichmentTestHelpers.pre_milvus_enrichment()
yield
Expand Down Expand Up @@ -370,23 +373,21 @@ def post_sql_enrichment_test(res: CloudSQLEnrichmentTestDataConstruct):
os.environ.pop('GOOGLE_CLOUD_SQL_DB_TABLE_ID', None)

@staticmethod
def pre_milvus_enrichment() -> MilvusDBContainerInfo:
def pre_milvus_enrichment() -> VectorDBContainerInfo:
try:
db = MilvusEnrichmentTestHelper.start_db_container()
db = MilvusTestHelpers.start_db_container()
connection_params = MilvusConnectionParameters(
uri=db.uri,
user=db.user,
password=db.password,
db_id=db.id,
token=db.token)
collection_name = MilvusTestHelpers.initialize_db_with_data(
connection_params)
except Exception as e:
raise TestContainerStartupError(
f"Milvus container failed to start: {str(e)}")

connection_params = MilvusConnectionParameters(
uri=db.uri,
user=db.user,
password=db.password,
db_id=db.id,
token=db.token)

collection_name = MilvusEnrichmentTestHelper.initialize_db_with_data(
connection_params)

# Setup environment variables for db and collection configuration. This will
# be used downstream by the milvus enrichment handler.
os.environ['MILVUS_VECTOR_DB_URI'] = db.uri
Expand All @@ -399,8 +400,13 @@ def pre_milvus_enrichment() -> MilvusDBContainerInfo:
return db

@staticmethod
def post_milvus_enrichment(db: MilvusDBContainerInfo):
MilvusEnrichmentTestHelper.stop_db_container(db)
def post_milvus_enrichment(db: VectorDBContainerInfo):
try:
MilvusTestHelpers.stop_db_container(db)
except Exception as e:
raise TestContainerTeardownError(
f"Milvus container failed to tear down: {str(e)}")

os.environ.pop('MILVUS_VECTOR_DB_URI', None)
os.environ.pop('MILVUS_VECTOR_DB_USER', None)
os.environ.pop('MILVUS_VECTOR_DB_PASSWORD', None)
Expand Down
133 changes: 38 additions & 95 deletions sdks/python/apache_beam/ml/rag/enrichment/milvus_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,14 @@
from pymilvus import Hits
from pymilvus import MilvusClient
from pymilvus import SearchResult
from pymilvus.exceptions import MilvusException

from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.rag.utils import MilvusConnectionParameters
from apache_beam.ml.rag.utils import MilvusHelpers
from apache_beam.ml.rag.utils import retry_with_backoff
from apache_beam.ml.rag.utils import unpack_dataclass_with_kwargs
from apache_beam.transforms.enrichment import EnrichmentSourceHandler


Expand Down Expand Up @@ -104,44 +109,6 @@ def __str__(self):
return self.dict().__str__()


@dataclass
class MilvusConnectionParameters:
  """Parameters for establishing connections to Milvus servers.

  Args:
    uri: URI endpoint for connecting to Milvus server in the format
      "http(s)://hostname:port".
    user: Username for authentication. Required if authentication is enabled
      and not using token authentication.
    password: Password for authentication. Required if authentication is
      enabled and not using token authentication.
    db_id: Database ID to connect to. Specifies which Milvus database to use.
      Defaults to 'default'.
    token: Authentication token as an alternative to username/password.
    timeout: Connection timeout in seconds. Uses client default if None.
    max_retries: Maximum number of connection retry attempts. Must be
      non-negative. Defaults to 3.
    retry_delay: Initial delay between retries in seconds. Must be
      non-negative. Defaults to 1.0.
    retry_backoff_factor: Multiplier for retry delay after each attempt.
      Defaults to 2.0 (exponential backoff).
    kwargs: Optional keyword arguments for additional connection parameters.
      Enables forward compatibility.

  Raises:
    ValueError: If `uri` is empty, or if `max_retries` or `retry_delay` is
      negative.
  """
  uri: str
  user: str = ""
  password: str = ""
  db_id: str = "default"
  token: str = ""
  timeout: Optional[float] = None
  max_retries: int = 3
  retry_delay: float = 1.0
  retry_backoff_factor: float = 2.0
  # Mutable default, so each instance must get its own dict.
  kwargs: Dict[str, Any] = field(default_factory=dict)

  def __post_init__(self):
    if not self.uri:
      raise ValueError("URI must be provided for Milvus connection")
    # Reject nonsensical retry settings at construction time instead of
    # letting them surface later as an obscure failure while connecting.
    if self.max_retries < 0:
      raise ValueError("max_retries must be non-negative")
    if self.retry_delay < 0:
      raise ValueError("retry_delay must be non-negative")


@dataclass
class BaseSearchParameters:
"""Base parameters for both vector and keyword search operations.
Expand Down Expand Up @@ -361,15 +328,15 @@ def __init__(
**kwargs):
"""
Example Usage:
connection_paramters = MilvusConnectionParameters(
connection_parameters = MilvusConnectionParameters(
uri="http://localhost:19530")
search_parameters = MilvusSearchParameters(
collection_name="my_collection",
search_strategy=VectorSearchParameters(anns_field="embedding"))
collection_load_parameters = MilvusCollectionLoadParameters(
load_fields=["embedding", "metadata"]),
milvus_handler = MilvusSearchEnrichmentHandler(
connection_paramters,
connection_parameters,
search_parameters,
collection_load_parameters=collection_load_parameters,
min_batch_size=10,
Expand Down Expand Up @@ -407,52 +374,43 @@ def __init__(
'min_batch_size': min_batch_size, 'max_batch_size': max_batch_size
}
self.kwargs = kwargs
self._client = None
self.join_fn = join_fn
self.use_custom_types = True

def __enter__(self):
import logging
import time

from pymilvus.exceptions import MilvusException

connection_params = unpack_dataclass_with_kwargs(
self._connection_parameters)
collection_load_params = unpack_dataclass_with_kwargs(
self._collection_load_parameters)

# Extract retry parameters from connection_params
max_retries = connection_params.pop('max_retries', 3)
retry_delay = connection_params.pop('retry_delay', 1.0)
retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)

# Retry logic for MilvusClient connection
last_exception = None
for attempt in range(max_retries + 1):
try:
self._client = MilvusClient(**connection_params)
self._client.load_collection(
"""Enters the context manager and establishes Milvus connection.

Returns:
Self, enabling use in 'with' statements.
"""
if not self._client:
connection_params = unpack_dataclass_with_kwargs(
self._connection_parameters)
collection_load_params = unpack_dataclass_with_kwargs(
self._collection_load_parameters)

# Extract retry parameters from connection_params.
max_retries = connection_params.pop('max_retries', 3)
retry_delay = connection_params.pop('retry_delay', 1.0)
retry_backoff_factor = connection_params.pop('retry_backoff_factor', 2.0)

def connect_and_load():
client = MilvusClient(**connection_params)
client.load_collection(
collection_name=self.collection_name,
partition_names=self.partition_names,
**collection_load_params)
logging.info(
"Successfully connected to Milvus on attempt %d", attempt + 1)
return
except MilvusException as e:
last_exception = e
if attempt < max_retries:
delay = retry_delay * (retry_backoff_factor**attempt)
logging.warning(
"Milvus connection attempt %d failed: %s. "
"Retrying in %.2f seconds...",
attempt + 1,
e,
delay)
time.sleep(delay)
else:
logging.error(
"Failed to connect to Milvus after %d attempts", max_retries + 1)
raise last_exception
return client

self._client = retry_with_backoff(
connect_and_load,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff_factor=retry_backoff_factor,
operation_name="Milvus connection and collection load",
exception_types=(MilvusException, ))
return self

def __call__(self, request: Union[Chunk, List[Chunk]], *args,
**kwargs) -> List[Tuple[Chunk, Dict[str, Any]]]:
Expand Down Expand Up @@ -535,10 +493,7 @@ def _get_keyword_search_data(self, chunk: Chunk):
raise ValueError(
f"Chunk {chunk.id} missing both text content and sparse embedding "
"required for keyword search")

sparse_embedding = self.convert_sparse_embedding_to_milvus_format(
chunk.sparse_embedding)

sparse_embedding = MilvusHelpers.sparse_embedding(chunk.sparse_embedding)
return chunk.content.text or sparse_embedding

def _get_call_response(
Expand Down Expand Up @@ -628,15 +583,3 @@ def batch_elements_kwargs(self) -> Dict[str, int]:
def join_fn(left: Embedding, right: Dict[str, Any]) -> Embedding:
left.metadata['enrichment_data'] = right
return left


def unpack_dataclass_with_kwargs(dataclass_instance):
  """Flattens an instance's attributes and its nested `kwargs` into one dict.

  Args:
    dataclass_instance: Object whose instance attributes are unpacked. It may
      carry a `kwargs` attribute holding extra key/value pairs (or None).

  Returns:
    A new dict containing every instance attribute except `kwargs`, updated
    with the entries of `kwargs`. On duplicate keys the value from `kwargs`
    wins. The instance itself is not modified.
  """
  # Work on a shallow copy so popping 'kwargs' does not mutate the instance.
  params_dict: dict = dict(vars(dataclass_instance))

  # Extract the nested kwargs dictionary. A missing or None value is treated
  # as "no extra arguments" instead of failing on `**None` below.
  nested_kwargs = params_dict.pop('kwargs', None) or {}

  # Merge the dictionaries, with nested_kwargs taking precedence
  # in case of duplicate keys.
  return {**params_dict, **nested_kwargs}
Loading
Loading