diff --git a/actors/chroma/.actor/input_schema.json b/actors/chroma/.actor/input_schema.json
index 3b810de..1bfe275 100644
--- a/actors/chroma/.actor/input_schema.json
+++ b/actors/chroma/.actor/input_schema.json
@@ -47,6 +47,14 @@
             "type": "string",
             "editor": "textfield"
         },
+        "chromaBatchSize": {
+            "title": "Chroma batch size",
+            "description": "Number of documents to insert in a single request to Chroma. Default is 300. Lower if you experience timeouts or want finer control.",
+            "type": "integer",
+            "default": 300,
+            "minimum": 1,
+            "sectionCaption": "Chroma settings"
+        },
         "embeddingsProvider": {
             "title": "Embeddings provider (as defined in the langchain API)",
             "description": "Choose the embeddings provider to use for generating embeddings",
diff --git a/actors/chroma/README.md b/actors/chroma/README.md
index e668a7c..2284f38 100644
--- a/actors/chroma/README.md
+++ b/actors/chroma/README.md
@@ -198,6 +198,11 @@
 To disable this feature, set `deleteExpiredObjects` to `false`.
 Otherwise, data crawled by one Actor might expire due to inconsistent crawling schedules.
 
+## Batch size configuration
+
+You can control the number of documents sent to Chroma in a single request using the `chromaBatchSize` parameter. The default is 300. Lower this value if you experience timeouts or want finer control over insert operations.
+
+
 ## 💾 Outputs
 
 This integration will save the selected fields from your Actor to Chroma.
diff --git a/code/poetry.lock b/code/poetry.lock
index 15d24e5..a6372ad 100644
--- a/code/poetry.lock
+++ b/code/poetry.lock
@@ -2473,14 +2473,14 @@ xai = ["langchain-xai"]
 
 [[package]]
 name = "langchain-apify"
-version = "0.1.3"
+version = "0.1.4"
 description = "An integration package connecting Apify and LangChain"
 optional = false
 python-versions = "<4.0,>=3.9"
 groups = ["main"]
 files = [
-    {file = "langchain_apify-0.1.3-py3-none-any.whl", hash = "sha256:b3374f2698a372c1b2c3b29efc009b5555244b3f3bd2244270ef795dad9e4e2c"},
-    {file = "langchain_apify-0.1.3.tar.gz", hash = "sha256:5631e6610e940633ff7a2cbadb165a0c2cc3232ae1b10b01f6b48752a1f5840a"},
+    {file = "langchain_apify-0.1.4-py3-none-any.whl", hash = "sha256:06a36685d14eabefce2d7cc6bfdd0b76dd537b42b587c1a9fd6b79044a6bd6e1"},
+    {file = "langchain_apify-0.1.4.tar.gz", hash = "sha256:dfe5d6ae5731f286e3cb84bfd66003fc195057beb6377364e9b5604086dc4305"},
 ]
 
 [package.dependencies]
@@ -7005,4 +7005,4 @@ cffi = ["cffi (>=1.17) ; python_version >= \"3.13\" and platform_python_implemen
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.11,<3.12"
-content-hash = "243ea5baac376238cc90271e29b8a8bfbb58486a39195ed76438e09b1ee25902"
+content-hash = "81593130437ab2caff3ab7e4b27459df89566186a71a1a2d21a19a52e390c55b"
diff --git a/code/pyproject.toml b/code/pyproject.toml
index b09911f..48e4573 100644
--- a/code/pyproject.toml
+++ b/code/pyproject.toml
@@ -18,7 +18,7 @@ langchain-openai = "^0.2.0"
 openai = "^1.17.0"
 python = ">=3.11,<3.12"
 python-dotenv = "^1.0.1"
-langchain-apify = "^0.1.3"
+langchain-apify = "^0.1.4"
 
 [tool.poetry.group.dev.dependencies]
 coverage = "^7.5.4"
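For context on the new `chromaBatchSize` option introduced above, here is an illustrative sketch of how it might appear in the Actor input. Only `chromaBatchSize` and `embeddingsProvider` are taken from the schema in this diff; the value `100` and the omitted fields are placeholders, and the full set of required fields is defined by the Actor's input schema.

```python
# Illustrative Actor input fragment (normally supplied as JSON in the Apify console or via an API call).
run_input = {
    "chromaBatchSize": 100,          # insert at most 100 documents per request instead of the default 300
    "embeddingsProvider": "OpenAI",  # one of the providers allowed by the schema
    # ... remaining fields (Chroma connection details, dataset fields, etc.) as required by the input schema
}
```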
diff --git a/code/src/models/chroma_input_model.py b/code/src/models/chroma_input_model.py
index 4338cb6..4f0a8f8 100644
--- a/code/src/models/chroma_input_model.py
+++ b/code/src/models/chroma_input_model.py
@@ -1,6 +1,6 @@
 # generated by datamodel-codegen:
 #   filename:  input_schema.json
-#   timestamp: 2025-08-17T21:20:07+00:00
+#   timestamp: 2025-08-20T08:08:17+00:00
 
 from __future__ import annotations
@@ -36,6 +36,12 @@ class ChromaIntegration(BaseModel):
     chromaDatabase: Optional[str] = Field(
         None, description='Chroma database name', title='Chroma database'
     )
+    chromaBatchSize: Optional[int] = Field(
+        300,
+        description='Number of documents to insert in a single request to Chroma. Default is 300. Lower if you experience timeouts or want finer control.',
+        ge=1,
+        title='Chroma batch size',
+    )
     embeddingsProvider: Literal['OpenAI', 'Cohere'] = Field(
         ...,
         description='Choose the embeddings provider to use for generating embeddings',
diff --git a/code/src/vector_stores/chroma.py b/code/src/vector_stores/chroma.py
index bbdc77b..2d05cd4 100644
--- a/code/src/vector_stores/chroma.py
+++ b/code/src/vector_stores/chroma.py
@@ -2,7 +2,7 @@
 from datetime import datetime, timezone
 from functools import partial
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterator
 
 import chromadb
 from langchain_chroma import Chroma
@@ -15,10 +15,18 @@
 from ..models import ChromaIntegration
 
+BATCH_SIZE = 300  # Chroma's default (max) size, number of documents to insert in a single request.
+
+
+def batch(seq: list[Document], size: int) -> Iterator[list[Document]]:
+    if size <= 0:
+        raise ValueError("size must be > 0")
+    for i in range(0, len(seq), size):
+        yield seq[i : i + size]
+
 
 class ChromaDatabase(Chroma, VectorDbBase):
     def __init__(self, actor_input: ChromaIntegration, embeddings: Embeddings) -> None:
-        # Create HttpClient using partial to handle optional parameters
         client_factory = partial(
             chromadb.HttpClient,
@@ -43,6 +51,7 @@ def __init__(self, actor_input: ChromaIntegration, embeddings: Embeddings) -> No
         self.client = client
         self.index = self.client.get_or_create_collection(collection_name)
         self._dummy_vector: list[float] = []
+        self.batch_size = actor_input.chromaBatchSize or BATCH_SIZE
 
     @property
     def dummy_vector(self) -> list[float]:
@@ -63,6 +72,21 @@ def get_by_item_id(self, item_id: str) -> list[Document]:
             return [Document(page_content="", metadata={**m, "chunk_id": _id}) for _id, m in zip(ids, metadata)]
         return []
 
+    def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
+        """Add documents to the index.
+
+        We need to batch documents to avoid exceeding the maximum request size.
+        Chroma limits the number of records we can insert in a single request to keep the payload small.
+        """
+        inserted_ids: list[str] = []
+        batch_size = kwargs.pop("batch_size", self.batch_size)
+
+        for docs_batch in batch(documents, batch_size):
+            ids = [str(doc.metadata["chunk_id"]) for doc in docs_batch]
+            batch_kwargs = {**kwargs, "ids": ids}
+            inserted_ids.extend(super().add_documents(docs_batch, **batch_kwargs))
+        return inserted_ids
+
     def update_last_seen_at(self, ids: list[str], last_seen_at: int | None = None) -> None:
         """Update last_seen_at field in the database."""
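To make the batching behaviour above concrete, here is a standalone sketch that mirrors the new `batch()` helper (with a generic element type instead of `Document`, so it runs without the Chroma or LangChain dependencies) and shows how `add_documents` would split one insert into several requests:

```python
from typing import Iterator, TypeVar

T = TypeVar("T")


def batch(seq: list[T], size: int) -> Iterator[list[T]]:
    # Mirrors the helper in chroma.py: yield consecutive slices of at most `size` items.
    if size <= 0:
        raise ValueError("size must be > 0")
    for i in range(0, len(seq), size):
        yield seq[i : i + size]


docs = [f"doc-{i}" for i in range(11)]            # stand-ins for Document objects
print([len(chunk) for chunk in batch(docs, 5)])   # [5, 5, 1] -> three insert requests
```

With the default `BATCH_SIZE` of 300, inserting 1,000 chunks would therefore go out as four requests of 300, 300, 300, and 100 documents.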
diff --git a/code/tests/test_chroma_batch.py b/code/tests/test_chroma_batch.py
new file mode 100644
index 0000000..e529dfe
--- /dev/null
+++ b/code/tests/test_chroma_batch.py
@@ -0,0 +1,72 @@
+# code/tests/test_chroma_batch.py
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from .conftest import DATABASE_FIXTURES
+
+if TYPE_CHECKING:
+    from _pytest.fixtures import FixtureRequest
+    from langchain_core.documents import Document
+
+    from src._types import VectorDb
+
+from langchain_core.documents import Document
+
+from src.vector_stores.chroma import BATCH_SIZE, batch
+
+
+def _make_docs(n: int) -> list[Document]:
+    return [Document(page_content=f"batch {i}", metadata={"chunk_id": f"batch-{i}"}) for i in range(n)]
+
+
+def test_batch_respects_batch_size() -> None:
+    total = BATCH_SIZE * 2 + 5  # two full batches + remainder
+    docs = _make_docs(total)
+
+    chunks = list(batch(docs, BATCH_SIZE))
+
+    assert len(chunks) == 3, "Expected 3 batches"
+    assert len(chunks[-1]) == 5, "Remainder batch size incorrect"
+
+
+@pytest.mark.parametrize("bad_size", [0, -1])
+def test_batch_invalid_size(bad_size: int) -> None:
+    with pytest.raises(ValueError, match="size must be > 0"):
+        list(batch(_make_docs(1), bad_size))
+
+
+@pytest.mark.integration()
+@pytest.mark.parametrize("input_db", DATABASE_FIXTURES)
+@pytest.mark.skipif("db_chroma" not in DATABASE_FIXTURES, reason="chroma database is not enabled")
+def test_add_documents_batches(input_db: str, request: FixtureRequest) -> None:
+    # Force small batch size to minimize embeddings/API calls while ensuring multiple batches.
+
+    db: VectorDb = request.getfixturevalue(input_db)
+    res = db.search_by_vector(db.dummy_vector, k=10)
+    assert len(res) == 6, "Expected 6 initial objects in the database"
+
+    db.delete_all()  # Clear the database before testing
+    res = db.search_by_vector(db.dummy_vector, k=10)
+    assert len(res) == 0, "Expected 0 objects in the database after delete_all"
+
+    total_new = 11  # Will require 3 batches (5 + 5 + 1)
+    docs = _make_docs(total_new)
+
+    ids = [doc.metadata["chunk_id"] for doc in docs]
+    inserted_ids = db.add_documents(docs, batch_size=5, ids=ids)
+
+    assert len(inserted_ids) == len(ids), "Expected all new documents inserted"
+    assert inserted_ids == [d.metadata["chunk_id"] for d in docs], "Order of returned IDs must match input order"
+
+    # Verify they are really stored
+    res = db.search_by_vector(db.dummy_vector, k=20)
+    print(f"Total objects in the database after batch insert: {len(res)}")
+    print(res)
+    assert len(res) == 11, "Expected 11 objects in the database after batch insert"
+
+    for doc in res:
+        assert isinstance(doc, Document), "Expected each result to be a Document instance"
+        assert doc.id in ids, f"Missing inserted id: {doc.id}"
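The first two tests above are pure unit tests of the `batch()` helper, while `test_add_documents_batches` needs a live database fixture. Assuming a standard pytest setup in which the `integration` marker is registered (as the existing suite suggests), the unit tests could be run on their own with something like this sketch:

```python
import pytest

# Hypothetical local run: execute only the new file and deselect tests marked
# "integration", which require a live Chroma instance.
exit_code = pytest.main(["code/tests/test_chroma_batch.py", "-m", "not integration", "-q"])
```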
diff --git a/code/tests/test_pinecone_namespace.py b/code/tests/test_pinecone_namespace.py
index 56039b9..14fca65 100644
--- a/code/tests/test_pinecone_namespace.py
+++ b/code/tests/test_pinecone_namespace.py
@@ -27,6 +27,7 @@ def wait_for_db(sec: int = 3) -> None:
     # Data freshness - Pinecone is eventually consistent, so there can be a slight delay before new or changed records are visible to queries.
     time.sleep(sec)
 
+
 @pytest.mark.integration()
 @pytest.mark.skipif("db_pinecone" not in DATABASE_FIXTURES, reason="pinecone database is not enabled")
 @pytest.fixture()
diff --git a/code/tests/test_vector_stores.py b/code/tests/test_vector_stores.py
index a836b38..6d5574f 100644
--- a/code/tests/test_vector_stores.py
+++ b/code/tests/test_vector_stores.py
@@ -206,6 +206,7 @@ def test_update_db_with_crawled_data_all(input_db: str, crawl_2: list[Document],
         assert d.metadata["item_id"] == expected.metadata["item_id"], f"Expected item_id {expected.metadata['item_id']}"
         assert d.metadata["checksum"] == expected.metadata["checksum"], f"Expected checksum {expected.metadata['checksum']}"
 
+
 @pytest.mark.integration()
 @pytest.mark.parametrize("input_db", DATABASE_FIXTURES)
 def test_get_delete_all(input_db: str, request: FixtureRequest) -> None:
@@ -222,6 +223,7 @@ def test_get_delete_all(input_db: str, request: FixtureRequest) -> None:
     res = db.search_by_vector(db.dummy_vector, k=10)
     assert not res
 
+
 @pytest.mark.integration()
 @pytest.mark.parametrize("input_db", DATABASE_FIXTURES)
 def test_delete_by_item_id(input_db: str, request: FixtureRequest) -> None: