diff --git a/README.md b/README.md
index bbd0b89..3eac4e5 100644
--- a/README.md
+++ b/README.md
@@ -139,7 +139,7 @@ services:
 1. Run tests

    ```bash
-   make pytest
+   make test
    ```

 1. Run the Actor locally
diff --git a/code/src/vcs.py b/code/src/vcs.py
index fc35421..2ffb0d5 100644
--- a/code/src/vcs.py
+++ b/code/src/vcs.py
@@ -133,8 +133,8 @@ def _get_item_id(item_id: str) -> tuple[str, list[Document]]:
         try:
             item_id, documents = future.result()
             crawled_db[item_id].extend(documents)
-        except Exception:
-            Actor.log.exception("Item_id %s generated an exception", item_id)
+        except Exception as e:
+            Actor.log.error("Item_id %s generated an error: %s", item_id, e)

     return dict(crawled_db)
diff --git a/code/src/vector_stores/chroma.py b/code/src/vector_stores/chroma.py
index 2d05cd4..9831f98 100644
--- a/code/src/vector_stores/chroma.py
+++ b/code/src/vector_stores/chroma.py
@@ -2,13 +2,15 @@

 from datetime import datetime, timezone
 from functools import partial
-from typing import TYPE_CHECKING, Any, Iterator
+from typing import TYPE_CHECKING, Any, Iterator, TypeVar

+import backoff
 import chromadb
+from chromadb.errors import ChromaError
 from langchain_chroma import Chroma
 from langchain_core.documents import Document

-from .base import VectorDbBase
+from .base import BACKOFF_MAX_TIME_DELETE_SECONDS, BACKOFF_MAX_TIME_SECONDS, VectorDbBase

 if TYPE_CHECKING:
     from langchain_core.embeddings import Embeddings
@@ -17,8 +19,10 @@

 BATCH_SIZE = 300  # Chroma's default (max) size, number of documents to insert in a single request.

+T = TypeVar("T")

-def batch(seq: list[Document], size: int) -> Iterator[list[Document]]:
+
+def batch(seq: list[T], size: int) -> Iterator[list[T]]:
     if size <= 0:
         raise ValueError("size must be > 0")
     for i in range(0, len(seq), size):
@@ -64,6 +68,7 @@ async def is_connected(self) -> bool:
             return False
         return True

+    @backoff.on_exception(backoff.expo, ChromaError, max_time=BACKOFF_MAX_TIME_SECONDS)
     def get_by_item_id(self, item_id: str) -> list[Document]:
         """Get documents by item_id."""
@@ -87,25 +92,66 @@ def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
             inserted_ids.extend(super().add_documents(docs_batch, **batch_kwargs))
         return inserted_ids

+    @backoff.on_exception(backoff.expo, ChromaError, max_time=BACKOFF_MAX_TIME_SECONDS)
     def update_last_seen_at(self, ids: list[str], last_seen_at: int | None = None) -> None:
-        """Update last_seen_at field in the database."""
+        """Update last_seen_at field in the database.
+
+        Large updates are split into batches (self.batch_size) to avoid oversized requests.
+        """
+        if not ids:
+            return
         last_seen_at = last_seen_at or int(datetime.now(timezone.utc).timestamp())
-        for _id in ids:
-            self.index.update(ids=_id, metadatas=[{"last_seen_at": last_seen_at}])
+        batch_size = self.batch_size
+        for ids_batch in batch(ids, batch_size):
+            self.index.update(ids=ids_batch, metadatas=[{"last_seen_at": last_seen_at} for _ in ids_batch])

+    @backoff.on_exception(backoff.expo, ChromaError, max_time=BACKOFF_MAX_TIME_DELETE_SECONDS)
     def delete_expired(self, expired_ts: int) -> None:
-        """Delete expired objects."""
-        self.index.delete(where={"last_seen_at": {"$lt": expired_ts}})  # type: ignore[dict-item]
+        """Delete expired objects.
+
+        Fetch all expired IDs first, then delete them in batches using the configured batch_size.
+        """
+        r = self.index.get(where={"last_seen_at": {"$lt": expired_ts}}, include=[])  # type: ignore
+        ids: list[str] = r.get("ids") or []
+        if not ids:
+            return
+        for ids_batch in batch(ids, self.batch_size):
+            self.index.delete(ids=ids_batch)

     def delete_by_item_id(self, item_id: str) -> None:
-        self.index.delete(where={"item_id": {"$eq": item_id}})  # type: ignore[dict-item]
+        """Delete documents by item_id.
+
+        Fetch all documents with the given item_id and delete in batches.
+        """
+        r = self.index.get(where={"item_id": item_id}, include=[])
+        ids: list[str] = r.get("ids") or []
+        if not ids:
+            return
+
+        for ids_batch in batch(ids, self.batch_size):
+            self.index.delete(ids=ids_batch)
+
+    def delete(self, ids: list[str] | None = None, **kwargs: Any) -> None:
+        """Delete objects by ids.
+
+        Delete the objects in batches to avoid exceeding the maximum request size.
+        """
+        if not ids:
+            return
+
+        for ids_batch in batch(ids, self.batch_size):
+            super().delete(ids=ids_batch, **kwargs)

     def delete_all(self) -> None:
-        """Delete all objects."""
+        """Delete all objects.
+
+        Delete the objects in batches to avoid exceeding the maximum request size.
+        """
         r = self.index.get()
         if r["ids"]:
-            self.delete(ids=r["ids"])
+            for ids_batch in batch(r["ids"], self.batch_size):
+                self.index.delete(ids=ids_batch)

     def search_by_vector(self, vector: list[float], k: int = 1_000_000, filter_: dict | None = None) -> list[Document]:
         """Search by vector."""
diff --git a/code/tests/test_pinecone_namespace.py b/code/tests/test_pinecone_namespace.py
index 14fca65..6113301 100644
--- a/code/tests/test_pinecone_namespace.py
+++ b/code/tests/test_pinecone_namespace.py
@@ -64,6 +64,7 @@ def delete_ns(namespace: str) -> None:
     delete_ns(NAMESPACE1)
     delete_ns(NAMESPACE2)


+@pytest.mark.integration()
 @pytest.mark.skipif("db_pinecone" not in DATABASE_FIXTURES, reason="pinecone database is not enabled")
 def test_namespace(db_pinecone_ns: PineconeDatabase) -> None:
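For context on the batching change: `batch` yields consecutive slices of at most `size` items, so every `update`/`delete` call in the new code paths touches at most `self.batch_size` IDs per request. A minimal standalone sketch of that behavior, copied from the diff's helper; the value 300 mirrors the module's `BATCH_SIZE` default and the fake IDs are illustrative only:

```python
from typing import Iterator, TypeVar

T = TypeVar("T")


def batch(seq: list[T], size: int) -> Iterator[list[T]]:
    # Yield consecutive slices of at most `size` items.
    if size <= 0:
        raise ValueError("size must be > 0")
    for i in range(0, len(seq), size):
        yield seq[i : i + size]


# Illustrative only: 750 fake IDs split with the module's default batch
# size of 300 come out as chunks of 300, 300 and 150.
ids = [f"id-{n}" for n in range(750)]
assert [len(chunk) for chunk in batch(ids, 300)] == [300, 300, 150]
```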