Skip to content

Commit c0acd33

Browse files
authored
Merge pull request #647 from Anush008/main
feat: Qdrant Vector Search Support
2 parents e8750a8 + 605d758 commit c0acd33

File tree

4 files changed

+625
-0
lines changed

4 files changed

+625
-0
lines changed

aworld/core/context/amni/retrieval/vector/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ def get_vector_db(vector_db_config: VectorDBConfig) -> Optional[VectorDB]:
1515
if vector_db_config.provider == "elasticsearch":
1616
from .elasticsearch import ElasticsearchVectorDB
1717
return ElasticsearchVectorDB(vector_db_config.config)
18+
if vector_db_config.provider == "qdrant":
19+
from .qdrant import QdrantVectorDB
20+
return QdrantVectorDB(vector_db_config.config)
1821
else:
1922
raise ValueError(f"Vector database {vector_db_config.provider} is not supported")
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
"""
2+
Qdrant vector search implementation for amnicontext.
3+
"""
4+
5+
import logging
6+
import time
7+
import traceback
8+
import uuid
9+
from typing import Optional, Dict, Any
10+
11+
from ..embeddings import (
12+
EmbeddingsResults,
13+
EmbeddingsResult,
14+
EmbeddingsMetadata,
15+
SearchResult,
16+
SearchResults,
17+
)
18+
from .base import VectorDB
19+
20+
21+
class QdrantVectorDB(VectorDB):
22+
"""Qdrant implementation of the VectorDB interface."""
23+
24+
CONTENT_KEY = "content"
25+
METADATA_KEY = "metadata"
26+
DEFAULT_DISTANCE = "cosine"
27+
28+
def _parse_distance(self, distance_config):
29+
from qdrant_client.models import Distance
30+
31+
distance_map = {
32+
"cosine": Distance.COSINE,
33+
"euclid": Distance.EUCLID,
34+
"euclidean": Distance.EUCLID,
35+
"dot": Distance.DOT,
36+
"manhattan": Distance.MANHATTAN,
37+
}
38+
39+
distance_config_lower = distance_config.lower()
40+
if distance_config_lower in distance_map:
41+
return distance_map[distance_config_lower]
42+
else:
43+
raise ValueError(
44+
f"Unsupported distance: {distance_config}. Supported: {list(distance_map.keys())}"
45+
)
46+
47+
def __init__(self, config: Dict[str, Any]):
48+
from qdrant_client import QdrantClient
49+
from qdrant_client.models import Distance
50+
51+
self.distance = self._parse_distance(
52+
config.get("qdrant_distance", self.DEFAULT_DISTANCE)
53+
)
54+
self.vector_size = config.get("qdrant_vector_size", 1536)
55+
56+
client_config = {
57+
"url": config.get("qdrant_url"),
58+
"host": config.get("qdrant_host"),
59+
"port": config.get("qdrant_port"),
60+
"grpc_port": config.get("qdrant_grpc_port"),
61+
"prefer_grpc": config.get("qdrant_prefer_grpc"),
62+
"path": config.get("qdrant_path"),
63+
"api_key": config.get("qdrant_api_key"),
64+
"timeout": config.get("qdrant_timeout"),
65+
}
66+
client_config = {k: v for k, v in client_config.items() if v is not None}
67+
68+
self.client = QdrantClient(**client_config)
69+
70+
def has_collection(self, collection_name: str) -> bool:
71+
return self.client.collection_exists(collection_name=collection_name)
72+
73+
def delete_collection(self, collection_name: str):
74+
if self.has_collection(collection_name):
75+
self.client.delete_collection(collection_name=collection_name)
76+
77+
def _get_or_create_collection(self, collection_name: str):
78+
from qdrant_client.models import VectorParams
79+
80+
if not self.has_collection(collection_name):
81+
self.client.create_collection(
82+
collection_name=collection_name,
83+
vectors_config=VectorParams(
84+
size=self.vector_size, distance=self.distance
85+
),
86+
)
87+
88+
def _convert_id_to_uuid(self, id_str: str):
89+
"""Qdrant requires point IDs to be either +ve integers or UUIDs.
90+
This method converts string IDs to UUIDs deterministically.
91+
"""
92+
try:
93+
return uuid.UUID(id_str)
94+
except (ValueError, AttributeError):
95+
return uuid.uuid5(uuid.NAMESPACE_DNS, str(id_str))
96+
97+
def _build_filter(self, filter: Optional[dict]) -> Optional[Any]:
98+
if not filter:
99+
return None
100+
101+
from qdrant_client.models import Filter, FieldCondition, MatchValue
102+
103+
conditions = []
104+
for key, value in filter.items():
105+
if value is not None:
106+
filter_key = (
107+
f"metadata.{key}" if not key.startswith("metadata.") else key
108+
)
109+
conditions.append(
110+
FieldCondition(key=filter_key, match=MatchValue(value=value))
111+
)
112+
113+
if not conditions:
114+
return None
115+
116+
return Filter(must=conditions)
117+
118+
def search(
119+
self,
120+
collection_name: str,
121+
vectors: list[list[float | int]],
122+
filter: dict,
123+
threshold: float,
124+
limit: int,
125+
) -> Optional[SearchResults]:
126+
try:
127+
if not self.has_collection(collection_name):
128+
return None
129+
130+
qdrant_filter = self._build_filter(filter)
131+
132+
results = self.client.query_points(
133+
collection_name=collection_name,
134+
query=vectors[0],
135+
query_filter=qdrant_filter,
136+
limit=limit,
137+
score_threshold=threshold if threshold else None,
138+
)
139+
140+
docs = []
141+
for result in results.points:
142+
payload = result.payload or {}
143+
metadata_dict = payload.get(self.METADATA_KEY, {})
144+
metadata_obj = EmbeddingsMetadata.model_validate(metadata_dict)
145+
logging.debug(
146+
f"search embedding_result_with_score {result.id}:{result.score}"
147+
)
148+
149+
docs.append(
150+
SearchResult(
151+
id=str(result.id),
152+
content=payload.get(self.CONTENT_KEY, ""),
153+
metadata=metadata_obj,
154+
score=result.score,
155+
)
156+
)
157+
158+
return SearchResults(docs=docs, search_at=int(time.time()))
159+
except Exception as e:
160+
logging.info(f"Error in search: {e}, trace is {traceback.format_exc()}")
161+
return None
162+
163+
def _convert_scroll_results(self, results):
164+
docs = []
165+
for result in results:
166+
payload = result.payload or {}
167+
metadata_dict = payload.get(self.METADATA_KEY, {})
168+
metadata_obj = EmbeddingsMetadata.model_validate(metadata_dict)
169+
170+
docs.append(
171+
EmbeddingsResult(
172+
id=str(result.id),
173+
embedding=result.vector,
174+
content=payload.get(self.CONTENT_KEY, ""),
175+
metadata=metadata_obj,
176+
score=None,
177+
)
178+
)
179+
return docs
180+
181+
def query(
182+
self, collection_name: str, filter: dict, limit: Optional[int] = None
183+
) -> Optional[EmbeddingsResults]:
184+
try:
185+
if not self.has_collection(collection_name):
186+
return None
187+
188+
qdrant_filter = self._build_filter(filter)
189+
190+
results, _ = self.client.scroll(
191+
collection_name=collection_name,
192+
scroll_filter=qdrant_filter,
193+
limit=limit if limit else 100,
194+
with_payload=True,
195+
with_vectors=False,
196+
)
197+
198+
docs = self._convert_scroll_results(results)
199+
200+
return EmbeddingsResults(docs=docs, retrieved_at=int(time.time()))
201+
except:
202+
return None
203+
204+
def get(self, collection_name: str) -> Optional[EmbeddingsResults]:
205+
if not self.has_collection(collection_name):
206+
return None
207+
208+
all_results = []
209+
offset = None
210+
while True:
211+
batch, offset = self.client.scroll(
212+
collection_name=collection_name,
213+
limit=100,
214+
offset=offset,
215+
with_payload=True,
216+
with_vectors=True,
217+
)
218+
if not batch:
219+
break
220+
all_results.extend(batch)
221+
if offset is None:
222+
break
223+
224+
docs = self._convert_scroll_results(all_results)
225+
return EmbeddingsResults(docs=docs, retrieved_at=int(time.time()))
226+
227+
def insert(self, collection_name: str, items: list[EmbeddingsResult]):
228+
from qdrant_client.models import PointStruct
229+
230+
if not items:
231+
return
232+
233+
self._get_or_create_collection(collection_name)
234+
235+
points = []
236+
for item in items:
237+
metadata_dict = item.metadata.model_dump() if item.metadata else {}
238+
payload = {self.CONTENT_KEY: item.content, self.METADATA_KEY: metadata_dict}
239+
240+
point_id = self._convert_id_to_uuid(item.id)
241+
242+
points.append(
243+
PointStruct(
244+
id=point_id,
245+
vector=item.embedding,
246+
payload=payload,
247+
)
248+
)
249+
250+
batch_size = 100
251+
for i in range(0, len(points), batch_size):
252+
batch = points[i : i + batch_size]
253+
self.client.upsert(collection_name=collection_name, points=batch)
254+
255+
def upsert(self, collection_name: str, items: list[EmbeddingsResult]):
256+
from qdrant_client.models import PointStruct
257+
258+
if not items:
259+
return
260+
261+
self._get_or_create_collection(collection_name)
262+
263+
points = []
264+
for item in items:
265+
metadata_dict = item.metadata.model_dump() if item.metadata else {}
266+
payload = {self.CONTENT_KEY: item.content, self.METADATA_KEY: metadata_dict}
267+
268+
point_id = self._convert_id_to_uuid(item.id)
269+
270+
points.append(
271+
PointStruct(
272+
id=point_id,
273+
vector=item.embedding,
274+
payload=payload,
275+
)
276+
)
277+
278+
self.client.upsert(collection_name=collection_name, points=points)
279+
280+
def delete(
281+
self,
282+
collection_name: str,
283+
ids: Optional[list[str]] = None,
284+
filter: Optional[dict] = None,
285+
):
286+
try:
287+
if not self.has_collection(collection_name):
288+
return
289+
290+
if ids:
291+
uuid_ids = [self._convert_id_to_uuid(id_str) for id_str in ids]
292+
self.client.delete(
293+
collection_name=collection_name, points_selector=uuid_ids
294+
)
295+
elif filter:
296+
qdrant_filter = self._build_filter(filter)
297+
if qdrant_filter:
298+
self.client.delete(
299+
collection_name=collection_name, points_selector=qdrant_filter
300+
)
301+
else:
302+
self.delete_collection(collection_name)
303+
except Exception as e:
304+
logging.debug(
305+
f"Attempted to delete from non-existent collection {collection_name}. Ignoring."
306+
)
307+
308+
def reset(self):
309+
collections = self.client.get_collections().collections
310+
for collection in collections:
311+
self.client.delete_collection(collection_name=collection.name)

0 commit comments

Comments
 (0)