2 changes: 1 addition & 1 deletion README.md
@@ -139,7 +139,7 @@ services:

1. Run tests
```bash
make pytest
make test
```

1. Run the Actor locally
4 changes: 2 additions & 2 deletions code/src/vcs.py
@@ -133,8 +133,8 @@ def _get_item_id(item_id: str) -> tuple[str, list[Document]]:
try:
item_id, documents = future.result()
crawled_db[item_id].extend(documents)
except Exception:
Actor.log.exception("Item_id %s generated an exception", item_id)
except Exception as e:
Actor.log.error("Item_id %s generated an error: %s", item_id, e)

return dict(crawled_db)
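For context on the logging change above, here is a minimal, self-contained sketch of how the two call styles differ, assuming `Actor.log` behaves like a standard `logging.Logger`; the logger name, `item_id` value, and the raised error are hypothetical:

```python
import logging

log = logging.getLogger("actor")  # hypothetical stand-in for Actor.log
logging.basicConfig(level=logging.ERROR)

item_id = "example-item"  # hypothetical value
try:
    raise RuntimeError("boom")
except Exception as e:
    # log.exception() logs at ERROR level and appends the full traceback automatically.
    log.exception("Item_id %s generated an exception", item_id)
    # log.error() logs only the message; include the exception as a format argument
    # (or pass exc_info=True) if the error detail should end up in the record.
    log.error("Item_id %s generated an error: %s", item_id, e)
```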

68 changes: 57 additions & 11 deletions code/src/vector_stores/chroma.py
@@ -2,13 +2,15 @@

from datetime import datetime, timezone
from functools import partial
from typing import TYPE_CHECKING, Any, Iterator
from typing import TYPE_CHECKING, Any, Iterator, TypeVar

import backoff
import chromadb
from chromadb.errors import ChromaError
from langchain_chroma import Chroma
from langchain_core.documents import Document

from .base import VectorDbBase
from .base import BACKOFF_MAX_TIME_DELETE_SECONDS, BACKOFF_MAX_TIME_SECONDS, VectorDbBase

if TYPE_CHECKING:
from langchain_core.embeddings import Embeddings
@@ -17,8 +19,10 @@

BATCH_SIZE = 300 # Chroma's default (max) size, number of documents to insert in a single request.

T = TypeVar("T")

def batch(seq: list[Document], size: int) -> Iterator[list[Document]]:

def batch(seq: list[T], size: int) -> Iterator[list[T]]:
if size <= 0:
raise ValueError("size must be > 0")
for i in range(0, len(seq), size):
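As a quick illustration of the generic `batch()` helper introduced above, a small self-contained sketch; the input list and chunk size are hypothetical, and the loop body mirrors what the helper is expected to yield:

```python
from typing import Iterator, TypeVar

T = TypeVar("T")

def batch(seq: list[T], size: int) -> Iterator[list[T]]:
    # Same shape as the helper above: yield consecutive slices of at most `size` items.
    if size <= 0:
        raise ValueError("size must be > 0")
    for i in range(0, len(seq), size):
        yield seq[i : i + size]

ids = [f"id-{n}" for n in range(7)]  # hypothetical data
print(list(batch(ids, 3)))
# [['id-0', 'id-1', 'id-2'], ['id-3', 'id-4', 'id-5'], ['id-6']]
```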
@@ -64,6 +68,7 @@ async def is_connected(self) -> bool:
return False
return True

@backoff.on_exception(backoff.expo, ChromaError, max_time=BACKOFF_MAX_TIME_SECONDS)
def get_by_item_id(self, item_id: str) -> list[Document]:
"""Get documents by item_id."""

@@ -87,25 +92,66 @@ def add_documents(self, documents: list[Document], **kwargs: Any) -> list[str]:
inserted_ids.extend(super().add_documents(docs_batch, **batch_kwargs))
return inserted_ids

@backoff.on_exception(backoff.expo, ChromaError, max_time=BACKOFF_MAX_TIME_SECONDS)
def update_last_seen_at(self, ids: list[str], last_seen_at: int | None = None) -> None:
"""Update last_seen_at field in the database."""
"""Update last_seen_at field in the database.

Large updates are split into batches (self.batch_size) to avoid oversized requests.
"""
if not ids:
return

last_seen_at = last_seen_at or int(datetime.now(timezone.utc).timestamp())
for _id in ids:
self.index.update(ids=_id, metadatas=[{"last_seen_at": last_seen_at}])
batch_size = self.batch_size
for ids_batch in batch(ids, batch_size):
self.index.update(ids=ids_batch, metadatas=[{"last_seen_at": last_seen_at} for _ in ids_batch])

@backoff.on_exception(backoff.expo, ChromaError, max_time=BACKOFF_MAX_TIME_DELETE_SECONDS)
def delete_expired(self, expired_ts: int) -> None:
"""Delete expired objects."""
self.index.delete(where={"last_seen_at": {"$lt": expired_ts}}) # type: ignore[dict-item]
"""Delete expired objects.

Fetch all expired IDs first, then delete them in batches using the configured batch_size.
"""
r = self.index.get(where={"last_seen_at": {"$lt": expired_ts}}, include=[]) # type: ignore
ids: list[str] = r.get("ids") or []
if not ids:
return
for ids_batch in batch(ids, self.batch_size):
self.index.delete(ids=ids_batch)

def delete_by_item_id(self, item_id: str) -> None:
self.index.delete(where={"item_id": {"$eq": item_id}}) # type: ignore[dict-item]
"""Delete documents by item_id.

Fetch all documents with the given item_id and delete in batches.
"""
r = self.index.get(where={"item_id": item_id}, include=[])
ids: list[str] = r.get("ids") or []
if not ids:
return

for ids_batch in batch(ids, self.batch_size):
self.index.delete(ids=ids_batch)

def delete(self, ids: list[str] | None = None, **kwargs: Any) -> None:
"""Delete objects by ids.

Delete the objects in batches to avoid exceeding the maximum request size.
"""
if not ids:
return

for ids_batch in batch(ids, self.batch_size):
super().delete(ids=ids_batch, **kwargs)

def delete_all(self) -> None:
"""Delete all objects."""
"""Delete all objects.

Delete the objects in batches to avoid exceeding the maximum request size.
"""
r = self.index.get()
if r["ids"]:
self.delete(ids=r["ids"])
for ids_batch in batch(r["ids"], self.batch_size):
self.index.delete(ids=ids_batch)

def search_by_vector(self, vector: list[float], k: int = 1_000_000, filter_: dict | None = None) -> list[Document]:
"""Search by vector."""
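To show how the batched update and deletion paths above fit together, a short usage sketch; the driver function, the `db` handle (an initialized instance of the Chroma store defined in this file), and the 7-day retention window are hypothetical, while `update_last_seen_at` and `delete_expired` are the methods from the diff:

```python
from datetime import datetime, timedelta, timezone

def purge_expired(db, ids_seen_now: list[str], retention_days: int = 7) -> None:
    """Hypothetical driver: refresh last_seen_at for ids seen in this run,
    then drop everything older than the retention window."""
    cutoff = datetime.now(timezone.utc) - timedelta(days=retention_days)
    db.update_last_seen_at(ids=ids_seen_now)    # now updates in batches internally
    db.delete_expired(int(cutoff.timestamp()))  # fetches expired ids, deletes in batches
```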
1 change: 1 addition & 0 deletions code/tests/test_pinecone_namespace.py
@@ -64,6 +64,7 @@ def delete_ns(namespace: str) -> None:
delete_ns(NAMESPACE1)
delete_ns(NAMESPACE2)


@pytest.mark.integration()
@pytest.mark.skipif("db_pinecone" not in DATABASE_FIXTURES, reason="pinecone database is not enabled")
def test_namespace(db_pinecone_ns: PineconeDatabase) -> None: