|
| 1 | +""" |
| 2 | +Qdrant vector search implementation for amnicontext. |
| 3 | +""" |
| 4 | + |
| 5 | +import logging |
| 6 | +import time |
| 7 | +import traceback |
| 8 | +import uuid |
| 9 | +from typing import Optional, Dict, Any |
| 10 | + |
| 11 | +from ..embeddings import ( |
| 12 | + EmbeddingsResults, |
| 13 | + EmbeddingsResult, |
| 14 | + EmbeddingsMetadata, |
| 15 | + SearchResult, |
| 16 | + SearchResults, |
| 17 | +) |
| 18 | +from .base import VectorDB |
| 19 | + |
| 20 | + |
class QdrantVectorDB(VectorDB):
    """Qdrant implementation of the VectorDB interface.

    Each stored point carries a payload with two keys: ``CONTENT_KEY`` (the
    document text) and ``METADATA_KEY`` (the dumped ``EmbeddingsMetadata``).
    ``qdrant_client`` is imported lazily inside the methods that need it, so
    this module can be imported without the dependency installed.
    """

    CONTENT_KEY = "content"
    METADATA_KEY = "metadata"
    DEFAULT_DISTANCE = "cosine"
    DEFAULT_VECTOR_SIZE = 1536
    # Number of points sent per upsert request and fetched per scroll page.
    BATCH_SIZE = 100
    # Number of points returned by ``query`` when the caller gives no limit.
    DEFAULT_QUERY_LIMIT = 100

    def _parse_distance(self, distance_config):
        """Translate a distance name from the config into a Qdrant ``Distance``.

        Args:
            distance_config: Case-insensitive distance name. Accepted values
                are "cosine", "euclid", "euclidean", "dot" and "manhattan".

        Returns:
            The matching ``qdrant_client.models.Distance`` member.

        Raises:
            ValueError: If the name is not one of the supported distances.
        """
        from qdrant_client.models import Distance

        distance_map = {
            "cosine": Distance.COSINE,
            "euclid": Distance.EUCLID,
            "euclidean": Distance.EUCLID,
            "dot": Distance.DOT,
            "manhattan": Distance.MANHATTAN,
        }

        distance = distance_map.get(distance_config.lower())
        if distance is None:
            raise ValueError(
                f"Unsupported distance: {distance_config}. Supported: {list(distance_map.keys())}"
            )
        return distance

    def __init__(self, config: Dict[str, Any]):
        """Create the Qdrant client from ``qdrant_*`` entries of ``config``.

        Args:
            config: Settings dictionary. ``qdrant_distance`` and
                ``qdrant_vector_size`` configure collections created by this
                instance; ``qdrant_url``, ``qdrant_host``, ``qdrant_port``,
                ``qdrant_grpc_port``, ``qdrant_prefer_grpc``, ``qdrant_path``,
                ``qdrant_api_key`` and ``qdrant_timeout`` are forwarded to
                ``QdrantClient`` when they are not ``None``.

        Raises:
            ValueError: If ``qdrant_distance`` names an unsupported distance.
        """
        from qdrant_client import QdrantClient

        # An explicit None in the config falls back to the default instead of
        # failing later with an AttributeError / invalid collection params.
        self.distance = self._parse_distance(
            config.get("qdrant_distance") or self.DEFAULT_DISTANCE
        )
        vector_size = config.get("qdrant_vector_size")
        self.vector_size = (
            self.DEFAULT_VECTOR_SIZE if vector_size is None else vector_size
        )

        option_names = (
            "url",
            "host",
            "port",
            "grpc_port",
            "prefer_grpc",
            "path",
            "api_key",
            "timeout",
        )
        candidates = {name: config.get(f"qdrant_{name}") for name in option_names}
        # Only pass options that were actually configured so the client keeps
        # its own defaults for everything else.
        client_config = {k: v for k, v in candidates.items() if v is not None}

        self.client = QdrantClient(**client_config)

    def has_collection(self, collection_name: str) -> bool:
        """Return whether a collection named ``collection_name`` exists."""
        return self.client.collection_exists(collection_name=collection_name)

    def delete_collection(self, collection_name: str):
        """Delete the collection if it exists; do nothing otherwise."""
        if self.has_collection(collection_name):
            self.client.delete_collection(collection_name=collection_name)

    def _get_or_create_collection(self, collection_name: str):
        """Create the collection with this instance's vector size and distance
        if it does not exist yet."""
        from qdrant_client.models import VectorParams

        if self.has_collection(collection_name):
            return

        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(
                size=self.vector_size, distance=self.distance
            ),
        )

    def _convert_id_to_uuid(self, id_str: str):
        """Qdrant requires point IDs to be either +ve integers or UUIDs.
        This method converts string IDs to UUIDs deterministically.

        A value that already parses as a UUID is returned as that UUID; any
        other value is hashed with ``uuid5`` under ``NAMESPACE_DNS``.

        NOTE(review): the original ID is not stored in the payload, so results
        read back from Qdrant expose the UUID rather than the caller's ID.
        """
        try:
            return uuid.UUID(id_str)
        except (ValueError, AttributeError):
            return uuid.uuid5(uuid.NAMESPACE_DNS, str(id_str))

    def _build_filter(self, filter: Optional[dict]) -> Optional[Any]:
        """Build a Qdrant ``Filter`` that requires every given key to match.

        Keys are looked up under the ``metadata.`` payload prefix unless they
        already start with it. Entries whose value is ``None`` are skipped.

        Args:
            filter: Mapping of metadata field name to the exact value to match.

        Returns:
            A ``Filter`` with one ``must`` condition per usable entry, or
            ``None`` when there is nothing to filter on.
        """
        if not filter:
            return None

        from qdrant_client.models import Filter, FieldCondition, MatchValue

        conditions = []
        for key, value in filter.items():
            if value is None:
                continue
            filter_key = key if key.startswith("metadata.") else f"metadata.{key}"
            conditions.append(
                FieldCondition(key=filter_key, match=MatchValue(value=value))
            )

        return Filter(must=conditions) if conditions else None

    def search(
        self,
        collection_name: str,
        vectors: list[list[float | int]],
        filter: dict,
        threshold: float,
        limit: int,
    ) -> Optional[SearchResults]:
        """Run a similarity search against a collection.

        Only the first vector in ``vectors`` is used as the query.

        Args:
            collection_name: Collection to search.
            vectors: Query vectors; ``vectors[0]`` is the one searched for.
            filter: Metadata equality filter, see ``_build_filter``.
            threshold: Score threshold passed to Qdrant. A falsy value
                (``None`` or ``0``) disables the threshold.
            limit: Maximum number of hits to return.

        Returns:
            The hits as ``SearchResults``, or ``None`` if the collection does
            not exist or any error occurs (the error is logged).
        """
        try:
            if not self.has_collection(collection_name):
                return None

            response = self.client.query_points(
                collection_name=collection_name,
                query=vectors[0],
                query_filter=self._build_filter(filter),
                limit=limit,
                score_threshold=threshold if threshold else None,
            )

            docs = []
            for point in response.points:
                payload = point.payload or {}
                metadata_obj = EmbeddingsMetadata.model_validate(
                    payload.get(self.METADATA_KEY, {})
                )
                logging.debug(
                    "search embedding_result_with_score %s:%s",
                    point.id,
                    point.score,
                )
                docs.append(
                    SearchResult(
                        id=str(point.id),
                        content=payload.get(self.CONTENT_KEY, ""),
                        metadata=metadata_obj,
                        score=point.score,
                    )
                )

            return SearchResults(docs=docs, search_at=int(time.time()))
        except Exception as e:
            # Best-effort API: callers get None on failure, but the failure is
            # recorded at error level so it is not lost in normal log output.
            logging.error(f"Error in search: {e}, trace is {traceback.format_exc()}")
            return None

    def _convert_scroll_results(self, results):
        """Convert Qdrant scroll records into ``EmbeddingsResult`` objects.

        ``embedding`` is taken from ``record.vector`` as returned by the
        client, and ``score`` is always ``None``.
        """
        docs = []
        for record in results:
            payload = record.payload or {}
            metadata_obj = EmbeddingsMetadata.model_validate(
                payload.get(self.METADATA_KEY, {})
            )
            docs.append(
                EmbeddingsResult(
                    id=str(record.id),
                    embedding=record.vector,
                    content=payload.get(self.CONTENT_KEY, ""),
                    metadata=metadata_obj,
                    score=None,
                )
            )
        return docs

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[EmbeddingsResults]:
        """Fetch points matching a metadata filter, without their vectors.

        Args:
            collection_name: Collection to read from.
            filter: Metadata equality filter, see ``_build_filter``.
            limit: Maximum number of points; a falsy value means
                ``DEFAULT_QUERY_LIMIT``.

        Returns:
            The matching points as ``EmbeddingsResults``, or ``None`` if the
            collection does not exist or any error occurs (the error is
            logged).
        """
        try:
            if not self.has_collection(collection_name):
                return None

            # Single scroll page; the continuation offset is not followed.
            records, _ = self.client.scroll(
                collection_name=collection_name,
                scroll_filter=self._build_filter(filter),
                limit=limit if limit else self.DEFAULT_QUERY_LIMIT,
                with_payload=True,
                with_vectors=False,
            )

            docs = self._convert_scroll_results(records)
            return EmbeddingsResults(docs=docs, retrieved_at=int(time.time()))
        except Exception as e:
            # Previously a bare ``except`` that hid every failure (including
            # KeyboardInterrupt); keep returning None but record the cause.
            logging.error(f"Error in query: {e}, trace is {traceback.format_exc()}")
            return None

    def get(self, collection_name: str) -> Optional[EmbeddingsResults]:
        """Fetch every point of a collection, including vectors.

        Pages through the collection ``BATCH_SIZE`` points at a time until the
        client reports no further offset.

        Returns:
            All points as ``EmbeddingsResults``, or ``None`` if the collection
            does not exist.
        """
        if not self.has_collection(collection_name):
            return None

        records = []
        next_offset = None
        while True:
            page, next_offset = self.client.scroll(
                collection_name=collection_name,
                limit=self.BATCH_SIZE,
                offset=next_offset,
                with_payload=True,
                with_vectors=True,
            )
            records.extend(page)
            if not page or next_offset is None:
                break

        docs = self._convert_scroll_results(records)
        return EmbeddingsResults(docs=docs, retrieved_at=int(time.time()))

    def _build_points(self, items: list[EmbeddingsResult]):
        """Convert embeddings results into Qdrant ``PointStruct`` objects."""
        from qdrant_client.models import PointStruct

        points = []
        for item in items:
            metadata_dict = item.metadata.model_dump() if item.metadata else {}
            points.append(
                PointStruct(
                    id=self._convert_id_to_uuid(item.id),
                    vector=item.embedding,
                    payload={
                        self.CONTENT_KEY: item.content,
                        self.METADATA_KEY: metadata_dict,
                    },
                )
            )
        return points

    def _upsert_items(self, collection_name: str, items: list[EmbeddingsResult]):
        """Ensure the collection exists and upsert ``items`` in batches."""
        if not items:
            return

        self._get_or_create_collection(collection_name)

        points = self._build_points(items)
        for start in range(0, len(points), self.BATCH_SIZE):
            self.client.upsert(
                collection_name=collection_name,
                points=points[start : start + self.BATCH_SIZE],
            )

    def insert(self, collection_name: str, items: list[EmbeddingsResult]):
        """Store ``items`` in the collection, creating it if necessary.

        Points are written with Qdrant's upsert, so an item whose ID already
        exists overwrites the stored point. Empty ``items`` is a no-op.
        """
        self._upsert_items(collection_name, items)

    def upsert(self, collection_name: str, items: list[EmbeddingsResult]):
        """Insert or overwrite ``items`` in the collection, creating it if
        necessary. Empty ``items`` is a no-op.

        Requests are sent in batches of ``BATCH_SIZE`` points, the same way
        ``insert`` does.
        """
        self._upsert_items(collection_name, items)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete points from a collection (best effort, never raises).

        Selection priority:
          1. ``ids`` (non-empty): delete those points; IDs are converted with
             ``_convert_id_to_uuid``.
          2. ``filter`` (non-empty): delete the points matching it. A filter
             with only ``None`` values deletes nothing.
          3. Neither given: the whole collection is deleted.

        NOTE(review): an empty ``ids`` list with no filter falls into case 3
        and drops the entire collection; confirm callers rely on this.

        A missing collection is a no-op. Errors are logged and swallowed.
        """
        try:
            if not self.has_collection(collection_name):
                return

            if ids:
                point_ids = [self._convert_id_to_uuid(raw_id) for raw_id in ids]
                self.client.delete(
                    collection_name=collection_name, points_selector=point_ids
                )
            elif filter:
                qdrant_filter = self._build_filter(filter)
                if qdrant_filter:
                    self.client.delete(
                        collection_name=collection_name,
                        points_selector=qdrant_filter,
                    )
            else:
                self.delete_collection(collection_name)
        except Exception as e:
            # Collection existence was already checked above, so reaching here
            # means a real failure; report it instead of claiming the
            # collection was missing.
            logging.warning(
                f"Failed to delete from collection {collection_name}: {e}"
            )

    def reset(self):
        """Delete every collection known to the connected Qdrant instance."""
        for collection in self.client.get_collections().collections:
            self.client.delete_collection(collection_name=collection.name)
0 commit comments