diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml
index 7c133cd5f..ea346dcd0 100644
--- a/.github/workflows/pull_request.yml
+++ b/.github/workflows/pull_request.yml
@@ -31,6 +31,10 @@ jobs:
python -m pip install --upgrade pip
pip install -e ".[test]"
+ - name: Run coding checks
+ run: |
+ make lint
+
- name: Test with pytest
run: |
make unittest
diff --git a/.ruff.toml b/.ruff.toml
deleted file mode 100644
index fb6b19c3d..000000000
--- a/.ruff.toml
+++ /dev/null
@@ -1,49 +0,0 @@
-# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
-# Enable flake8-bugbear (`B`) rules.
-select = ["E", "F", "B"]
-ignore = [
- "E501", # (line length violations)
-]
-
-# Allow autofix for all enabled rules (when `--fix`) is provided.
-fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM",
- "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT",
- "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT",
-]
-unfixable = []
-
-# Exclude a variety of commonly ignored directories.
-exclude = [
- ".bzr",
- ".direnv",
- ".eggs",
- ".git",
- ".hg",
- ".mypy_cache",
- ".nox",
- ".pants.d",
- ".pytype",
- ".ruff_cache",
- ".svn",
- ".tox",
- ".venv",
- "__pypackages__",
- "_build",
- "buck-out",
- "build",
- "dist",
- "node_modules",
- "venv",
- "__pycache__",
- "__init__.py",
-]
-
-# Allow unused variables when underscore-prefixed.
-dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
-
-# Assume Python 3.11.
-target-version = "py311"
-
-[mccabe]
-# Unlike Flake8, default to a complexity level of 10.
-max-complexity = 10
diff --git a/Makefile b/Makefile
index 562615f6d..ef8207c55 100644
--- a/Makefile
+++ b/Makefile
@@ -1,2 +1,10 @@
unittest:
PYTHONPATH=`pwd` python3 -m pytest tests/test_dataset.py::TestDataSet::test_download_small -svv
+
+format:
+ PYTHONPATH=`pwd` python3 -m black vectordb_bench
+ PYTHONPATH=`pwd` python3 -m ruff check vectordb_bench --fix
+
+lint:
+ PYTHONPATH=`pwd` python3 -m black vectordb_bench --check
+ PYTHONPATH=`pwd` python3 -m ruff check vectordb_bench
diff --git a/README.md b/README.md
index 56d54c49d..737fc6064 100644
--- a/README.md
+++ b/README.md
@@ -240,13 +240,13 @@ After reopen the repository in container, run `python -m vectordb_bench` in the
### Check coding styles
```shell
-$ ruff check vectordb_bench
+$ make lint
```
-Add `--fix` if you want to fix the coding styles automatically
+To fix the coding styles automatically:
```shell
-$ ruff check vectordb_bench --fix
+$ make format
```
## How does it work?
diff --git a/pyproject.toml b/pyproject.toml
index 8363aa4fa..312940634 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dynamic = ["version"]
[project.optional-dependencies]
test = [
+ "black",
"ruff",
"pytest",
]
@@ -93,3 +94,116 @@ init_bench = "vectordb_bench.__main__:main"
vectordbbench = "vectordb_bench.cli.vectordbbench:cli"
[tool.setuptools_scm]
+
+[tool.black]
+line-length = 120
+target-version = ['py311']
+include = '\.pyi?$'
+
+[tool.ruff]
+lint.select = [
+ "E",
+ "F",
+ "C90",
+ "I",
+ "N",
+ "B", "C", "G",
+ "A",
+ "ANN001",
+ "S", "T", "W", "ARG", "BLE", "COM", "DJ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"
+]
+lint.ignore = [
+ "BLE001", # blind-except (BLE001)
+ "SLF001", # SLF001 Private member accessed [E]
+ "TRY003", # [ruff] TRY003 Avoid specifying long messages outside the exception class [E]
+ "FBT001", "FBT002", "FBT003",
+ "G004", # [ruff] G004 Logging statement uses f-string [E]
+ "UP031",
+ "RUF012",
+ "EM101",
+ "N805",
+ "ARG002",
+ "ARG003",
+ "PIE796", # https://github.com/zilliztech/VectorDBBench/issues/438
+ "INP001", # TODO
+ "TID252", # TODO
+ "N801", "N802", "N815",
+ "S101", "S108", "S603", "S311",
+ "PLR2004",
+ "RUF017",
+ "C416",
+ "PLW0603",
+]
+
+# Allow autofix for all enabled rules (when `--fix` is provided).
+lint.fixable = [
+ "A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W",
+ "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT",
+ "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH",
+ "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP",
+ "YTT",
+]
+lint.unfixable = []
+
+show-fixes = true
+
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "venv",
+ "grpc_gen",
+ "__pycache__",
+ "frontend", # TODO
+ "tests",
+]
+
+# Same as Black.
+line-length = 120
+
+# Allow unused variables when underscore-prefixed.
+lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# Assume Python 3.11
+target-version = "py311"
+
+[tool.ruff.lint.mccabe]
+# Unlike Flake8, default to a complexity level of 10.
+max-complexity = 18
+
+[tool.ruff.lint.pycodestyle]
+max-line-length = 120
+max-doc-length = 120
+
+[tool.ruff.lint.pylint]
+max-args = 20
+max-branches = 15
+
+[tool.ruff.lint.flake8-builtins]
+builtins-ignorelist = [
+ # "format",
+ # "next",
+ # "object", # TODO
+ # "id",
+ # "dict", # TODO
+ # "filter",
+]
+
diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py
index 568e97705..c07fc855d 100644
--- a/vectordb_bench/__init__.py
+++ b/vectordb_bench/__init__.py
@@ -22,46 +22,71 @@ class config:
DROP_OLD = env.bool("DROP_OLD", True)
USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True)
- NUM_CONCURRENCY = env.list("NUM_CONCURRENCY", [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100], subcast=int )
+ NUM_CONCURRENCY = env.list(
+ "NUM_CONCURRENCY",
+ [
+ 1,
+ 5,
+ 10,
+ 15,
+ 20,
+ 25,
+ 30,
+ 35,
+ 40,
+ 45,
+ 50,
+ 55,
+ 60,
+ 65,
+ 70,
+ 75,
+ 80,
+ 85,
+ 90,
+ 95,
+ 100,
+ ],
+ subcast=int,
+ )
CONCURRENCY_DURATION = 30
RESULTS_LOCAL_DIR = env.path(
- "RESULTS_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("results")
+ "RESULTS_LOCAL_DIR",
+ pathlib.Path(__file__).parent.joinpath("results"),
)
CONFIG_LOCAL_DIR = env.path(
- "CONFIG_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("config-files")
+ "CONFIG_LOCAL_DIR",
+ pathlib.Path(__file__).parent.joinpath("config-files"),
)
-
K_DEFAULT = 100 # default return top k nearest neighbors during search
CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json")
- CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
- LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h
- LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h
- LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d
- LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d
+ CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h
+ LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h
+ LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h
+ LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d
+ LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d
- LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h
- LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d
+ LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h
+ LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d
- OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h
- OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h
- OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d
- OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d
+ OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h
+ OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h
+ OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d
+ OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d
+ OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h
+ OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d
- OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h
- OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d
-
def display(self) -> str:
- tmp = [
- i for i in inspect.getmembers(self)
- if not inspect.ismethod(i[1])
- and not i[0].startswith('_')
- and "TIMEOUT" not in i[0]
+ return [
+ i
+ for i in inspect.getmembers(self)
+ if not inspect.ismethod(i[1]) and not i[0].startswith("_") and "TIMEOUT" not in i[0]
]
- return tmp
+
log_util.init(config.LOG_LEVEL)
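
For context on the reformatted config block above: these values are read through `environs`, so they can be overridden from the environment. A minimal sketch (assuming environs' default comma-separated list parsing; the override values are illustrative):

```python
# Sketch: override config values via environment variables before importing
# vectordb_bench (environs parses them at import time).
import os

os.environ["NUM_CONCURRENCY"] = "1,10,50"  # env.list(..., subcast=int) -> [1, 10, 50]
os.environ["DROP_OLD"] = "false"           # env.bool(...) -> False

from vectordb_bench import config

print(config.NUM_CONCURRENCY, config.DROP_OLD)
```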
diff --git a/vectordb_bench/__main__.py b/vectordb_bench/__main__.py
index a8b4436bc..6663731f5 100644
--- a/vectordb_bench/__main__.py
+++ b/vectordb_bench/__main__.py
@@ -1,7 +1,8 @@
-import traceback
import logging
+import pathlib
import subprocess
-import os
+import traceback
+
from . import config
log = logging.getLogger("vectordb_bench")
@@ -16,7 +17,7 @@ def run_streamlit():
cmd = [
"streamlit",
"run",
- f"{os.path.dirname(__file__)}/frontend/vdb_benchmark.py",
+ f"{pathlib.Path(__file__).parent}/frontend/vdb_benchmark.py",
"--logger.level",
"info",
"--theme.base",
diff --git a/vectordb_bench/backend/assembler.py b/vectordb_bench/backend/assembler.py
index e7da4d49f..b81d315c2 100644
--- a/vectordb_bench/backend/assembler.py
+++ b/vectordb_bench/backend/assembler.py
@@ -1,24 +1,25 @@
-from .cases import CaseLabel
-from .task_runner import CaseRunner, RunningStatus, TaskRunner
-from ..models import TaskConfig
-from ..backend.clients import EmptyDBCaseConfig
-from ..backend.data_source import DatasetSource
import logging
+from vectordb_bench.backend.clients import EmptyDBCaseConfig
+from vectordb_bench.backend.data_source import DatasetSource
+from vectordb_bench.models import TaskConfig
+
+from .cases import CaseLabel
+from .task_runner import CaseRunner, RunningStatus, TaskRunner
log = logging.getLogger(__name__)
class Assembler:
@classmethod
- def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner:
+ def assemble(cls, run_id: str, task: TaskConfig, source: DatasetSource) -> CaseRunner:
c_cls = task.case_config.case_id.case_cls
c = c_cls(task.case_config.custom_case)
- if type(task.db_case_config) != EmptyDBCaseConfig:
+ if type(task.db_case_config) is not EmptyDBCaseConfig:
task.db_case_config.metric_type = c.dataset.data.metric_type
- runner = CaseRunner(
+ return CaseRunner(
run_id=run_id,
config=task,
ca=c,
@@ -26,8 +27,6 @@ def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunne
dataset_source=source,
)
- return runner
-
@classmethod
def assemble_all(
cls,
@@ -50,12 +49,12 @@ def assemble_all(
db2runner[db].append(r)
# check dbclient installed
- for k in db2runner.keys():
+ for k in db2runner:
_ = k.init_cls
# sort by dataset size
- for k in db2runner.keys():
- db2runner[k].sort(key=lambda x:x.ca.dataset.data.size)
+ for runners in db2runner.values():
+ runners.sort(key=lambda x: x.ca.dataset.data.size)
all_runners = []
all_runners.extend(load_runners)
diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py
index 8b643b66b..15fc069cc 100644
--- a/vectordb_bench/backend/cases.py
+++ b/vectordb_bench/backend/cases.py
@@ -1,7 +1,5 @@
-import typing
import logging
from enum import Enum, auto
-from typing import Type
from vectordb_bench import config
from vectordb_bench.backend.clients.api import MetricType
@@ -12,7 +10,6 @@
from .dataset import CustomDataset, Dataset, DatasetManager
-
log = logging.getLogger(__name__)
@@ -50,11 +47,10 @@ class CaseType(Enum):
Custom = 100
PerformanceCustomDataset = 101
- def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]:
+ def case_cls(self, custom_configs: dict | None = None) -> type["Case"]:
if custom_configs is None:
return type2case.get(self)()
- else:
- return type2case.get(self)(**custom_configs)
+ return type2case.get(self)(**custom_configs)
def case_name(self, custom_configs: dict | None = None) -> str:
c = self.case_cls(custom_configs)
@@ -99,10 +95,10 @@ class Case(BaseModel):
@property
def filters(self) -> dict | None:
if self.filter_rate is not None:
- ID = round(self.filter_rate * self.dataset.data.size)
+ target_id = round(self.filter_rate * self.dataset.data.size)
return {
- "metadata": f">={ID}",
- "id": ID,
+ "metadata": f">={target_id}",
+ "id": target_id,
}
return None
@@ -126,8 +122,8 @@ class CapacityDim960(CapacityCase):
case_id: CaseType = CaseType.CapacityDim960
dataset: DatasetManager = Dataset.GIST.manager(100_000)
name: str = "Capacity Test (960 Dim Repeated)"
- description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension
- vectors (GIST 100K vectors, 960 dimensions) until it is fully loaded. Number of inserted vectors will be
+ description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension
+ vectors (GIST 100K vectors, 960 dimensions) until it is fully loaded. Number of inserted vectors will be
reported."""
@@ -136,7 +132,7 @@ class CapacityDim128(CapacityCase):
dataset: DatasetManager = Dataset.SIFT.manager(500_000)
name: str = "Capacity Test (128 Dim Repeated)"
description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension
- vectors (SIFT 100K vectors, 128 dimensions) until it is fully loaded. Number of inserted vectors will be
+ vectors (SIFT 100K vectors, 128 dimensions) until it is fully loaded. Number of inserted vectors will be
reported."""
@@ -144,8 +140,9 @@ class Performance768D10M(PerformanceCase):
case_id: CaseType = CaseType.Performance768D10M
dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
name: str = "Search Performance Test (10M Dataset, 768 Dim)"
- description: str = """This case tests the search performance of a vector database with a large dataset (Cohere 10M vectors, 768 dimensions) at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a large dataset
+ (Cohere 10M vectors, 768 dimensions) at varying parallel levels.
+ Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_10M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_10M
@@ -154,8 +151,9 @@ class Performance768D1M(PerformanceCase):
case_id: CaseType = CaseType.Performance768D1M
dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
name: str = "Search Performance Test (1M Dataset, 768 Dim)"
- description: str = """This case tests the search performance of a vector database with a medium dataset (Cohere 1M vectors, 768 dimensions) at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium dataset
+ (Cohere 1M vectors, 768 dimensions) at varying parallel levels.
+ Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_1M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_1M
@@ -165,8 +163,9 @@ class Performance768D10M1P(PerformanceCase):
filter_rate: float | int | None = 0.01
dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 1%)"
- description: str = """This case tests the search performance of a vector database with a large dataset (Cohere 10M vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a large dataset
+ (Cohere 10M vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel
+ levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_10M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_10M
@@ -176,8 +175,9 @@ class Performance768D1M1P(PerformanceCase):
filter_rate: float | int | None = 0.01
dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 1%)"
- description: str = """This case tests the search performance of a vector database with a medium dataset (Cohere 1M vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium dataset
+ (Cohere 1M vectors, 768 dimensions) under a low filtering rate (1% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_1M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_1M
@@ -187,8 +187,9 @@ class Performance768D10M99P(PerformanceCase):
filter_rate: float | int | None = 0.99
dataset: DatasetManager = Dataset.COHERE.manager(10_000_000)
name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 99%)"
- description: str = """This case tests the search performance of a vector database with a large dataset (Cohere 10M vectors, 768 dimensions) under a high filtering rate (99% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a large dataset
+ (Cohere 10M vectors, 768 dimensions) under a high filtering rate (99% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_10M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_10M
@@ -198,8 +199,9 @@ class Performance768D1M99P(PerformanceCase):
filter_rate: float | int | None = 0.99
dataset: DatasetManager = Dataset.COHERE.manager(1_000_000)
name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 99%)"
- description: str = """This case tests the search performance of a vector database with a medium dataset (Cohere 1M vectors, 768 dimensions) under a high filtering rate (99% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium dataset
+ (Cohere 1M vectors, 768 dimensions) under a high filtering rate (99% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_1M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_1M
@@ -209,8 +211,9 @@ class Performance768D100M(PerformanceCase):
filter_rate: float | int | None = None
dataset: DatasetManager = Dataset.LAION.manager(100_000_000)
name: str = "Search Performance Test (100M Dataset, 768 Dim)"
- description: str = """This case tests the search performance of a vector database with a large 100M dataset (LAION 100M vectors, 768 dimensions), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a large 100M dataset
+ (LAION 100M vectors, 768 dimensions), at varying parallel levels. Results will show index building time,
+ recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_768D_100M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_100M
@@ -220,8 +223,9 @@ class Performance1536D500K(PerformanceCase):
filter_rate: float | int | None = None
dataset: DatasetManager = Dataset.OPENAI.manager(500_000)
name: str = "Search Performance Test (500K Dataset, 1536 Dim)"
- description: str = """This case tests the search performance of a vector database with a medium 500K dataset (OpenAI 500K vectors, 1536 dimensions), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium 500K dataset
+ (OpenAI 500K vectors, 1536 dimensions), at varying parallel levels. Results will show index building time,
+ recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
@@ -231,8 +235,9 @@ class Performance1536D5M(PerformanceCase):
filter_rate: float | int | None = None
dataset: DatasetManager = Dataset.OPENAI.manager(5_000_000)
name: str = "Search Performance Test (5M Dataset, 1536 Dim)"
- description: str = """This case tests the search performance of a vector database with a medium 5M dataset (OpenAI 5M vectors, 1536 dimensions), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium 5M dataset
+ (OpenAI 5M vectors, 1536 dimensions), at varying parallel levels. Results will show index building time,
+ recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
@@ -242,8 +247,9 @@ class Performance1536D500K1P(PerformanceCase):
filter_rate: float | int | None = 0.01
dataset: DatasetManager = Dataset.OPENAI.manager(500_000)
name: str = "Filtering Search Performance Test (500K Dataset, 1536 Dim, Filter 1%)"
- description: str = """This case tests the search performance of a vector database with a large dataset (OpenAI 500K vectors, 1536 dimensions) under a low filtering rate (1% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a large dataset
+ (OpenAI 500K vectors, 1536 dimensions) under a low filtering rate (1% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
@@ -253,8 +259,9 @@ class Performance1536D5M1P(PerformanceCase):
filter_rate: float | int | None = 0.01
dataset: DatasetManager = Dataset.OPENAI.manager(5_000_000)
name: str = "Filtering Search Performance Test (5M Dataset, 1536 Dim, Filter 1%)"
- description: str = """This case tests the search performance of a vector database with a large dataset (OpenAI 5M vectors, 1536 dimensions) under a low filtering rate (1% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a large dataset
+ (OpenAI 5M vectors, 1536 dimensions) under a low filtering rate (1% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
@@ -264,8 +271,9 @@ class Performance1536D500K99P(PerformanceCase):
filter_rate: float | int | None = 0.99
dataset: DatasetManager = Dataset.OPENAI.manager(500_000)
name: str = "Filtering Search Performance Test (500K Dataset, 1536 Dim, Filter 99%)"
- description: str = """This case tests the search performance of a vector database with a medium dataset (OpenAI 500K vectors, 1536 dimensions) under a high filtering rate (99% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium dataset
+ (OpenAI 500K vectors, 1536 dimensions) under a high filtering rate (99% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K
@@ -275,8 +283,9 @@ class Performance1536D5M99P(PerformanceCase):
filter_rate: float | int | None = 0.99
dataset: DatasetManager = Dataset.OPENAI.manager(5_000_000)
name: str = "Filtering Search Performance Test (5M Dataset, 1536 Dim, Filter 99%)"
- description: str = """This case tests the search performance of a vector database with a medium dataset (OpenAI 5M vectors, 1536 dimensions) under a high filtering rate (99% vectors), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium dataset
+ (OpenAI 5M vectors, 1536 dimensions) under a high filtering rate (99% vectors),
+ at varying parallel levels. Results will show index building time, recall, and maximum QPS."""
load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M
@@ -286,8 +295,9 @@ class Performance1536D50K(PerformanceCase):
filter_rate: float | int | None = None
dataset: DatasetManager = Dataset.OPENAI.manager(50_000)
name: str = "Search Performance Test (50K Dataset, 1536 Dim)"
- description: str = """This case tests the search performance of a vector database with a medium 50K dataset (OpenAI 50K vectors, 1536 dimensions), at varying parallel levels.
-Results will show index building time, recall, and maximum QPS."""
+ description: str = """This case tests the search performance of a vector database with a medium 50K dataset
+ (OpenAI 50K vectors, 1536 dimensions), at varying parallel levels. Results will show index building time,
+ recall, and maximum QPS."""
load_timeout: float | int = 3600
optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT
@@ -312,11 +322,11 @@ class PerformanceCustomDataset(PerformanceCase):
def __init__(
self,
- name,
- description,
- load_timeout,
- optimize_timeout,
- dataset_config,
+ name: str,
+ description: str,
+ load_timeout: float,
+ optimize_timeout: float,
+ dataset_config: dict,
**kwargs,
):
dataset_config = CustomDatasetConfig(**dataset_config)
diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py
index e1b66a81d..773cd4948 100644
--- a/vectordb_bench/backend/clients/__init__.py
+++ b/vectordb_bench/backend/clients/__init__.py
@@ -1,12 +1,12 @@
from enum import Enum
-from typing import Type
+
from .api import (
- VectorDB,
- DBConfig,
DBCaseConfig,
+ DBConfig,
EmptyDBCaseConfig,
IndexType,
MetricType,
+ VectorDB,
)
@@ -41,200 +41,255 @@ class DB(Enum):
Test = "test"
AliyunOpenSearch = "AliyunOpenSearch"
-
@property
- def init_cls(self) -> Type[VectorDB]:
+ def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912
"""Import while in use"""
if self == DB.Milvus:
from .milvus.milvus import Milvus
+
return Milvus
if self == DB.ZillizCloud:
from .zilliz_cloud.zilliz_cloud import ZillizCloud
+
return ZillizCloud
if self == DB.Pinecone:
from .pinecone.pinecone import Pinecone
+
return Pinecone
if self == DB.ElasticCloud:
from .elastic_cloud.elastic_cloud import ElasticCloud
+
return ElasticCloud
if self == DB.QdrantCloud:
from .qdrant_cloud.qdrant_cloud import QdrantCloud
+
return QdrantCloud
if self == DB.WeaviateCloud:
from .weaviate_cloud.weaviate_cloud import WeaviateCloud
+
return WeaviateCloud
if self == DB.PgVector:
from .pgvector.pgvector import PgVector
+
return PgVector
if self == DB.PgVectoRS:
from .pgvecto_rs.pgvecto_rs import PgVectoRS
+
return PgVectoRS
-
+
if self == DB.PgVectorScale:
from .pgvectorscale.pgvectorscale import PgVectorScale
+
return PgVectorScale
if self == DB.PgDiskANN:
from .pgdiskann.pgdiskann import PgDiskANN
+
return PgDiskANN
if self == DB.Redis:
from .redis.redis import Redis
+
return Redis
-
+
if self == DB.MemoryDB:
from .memorydb.memorydb import MemoryDB
+
return MemoryDB
if self == DB.Chroma:
from .chroma.chroma import ChromaClient
+
return ChromaClient
if self == DB.AWSOpenSearch:
from .aws_opensearch.aws_opensearch import AWSOpenSearch
+
return AWSOpenSearch
-
+
if self == DB.AlloyDB:
from .alloydb.alloydb import AlloyDB
+
return AlloyDB
if self == DB.AliyunElasticsearch:
from .aliyun_elasticsearch.aliyun_elasticsearch import AliyunElasticsearch
+
return AliyunElasticsearch
if self == DB.AliyunOpenSearch:
from .aliyun_opensearch.aliyun_opensearch import AliyunOpenSearch
+
return AliyunOpenSearch
+ msg = f"Unknown DB: {self.name}"
+ raise ValueError(msg)
+
@property
- def config_cls(self) -> Type[DBConfig]:
+ def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912
"""Import while in use"""
if self == DB.Milvus:
from .milvus.config import MilvusConfig
+
return MilvusConfig
if self == DB.ZillizCloud:
from .zilliz_cloud.config import ZillizCloudConfig
+
return ZillizCloudConfig
if self == DB.Pinecone:
from .pinecone.config import PineconeConfig
+
return PineconeConfig
if self == DB.ElasticCloud:
from .elastic_cloud.config import ElasticCloudConfig
+
return ElasticCloudConfig
if self == DB.QdrantCloud:
from .qdrant_cloud.config import QdrantConfig
+
return QdrantConfig
if self == DB.WeaviateCloud:
from .weaviate_cloud.config import WeaviateConfig
+
return WeaviateConfig
if self == DB.PgVector:
from .pgvector.config import PgVectorConfig
+
return PgVectorConfig
if self == DB.PgVectoRS:
from .pgvecto_rs.config import PgVectoRSConfig
+
return PgVectoRSConfig
if self == DB.PgVectorScale:
from .pgvectorscale.config import PgVectorScaleConfig
+
return PgVectorScaleConfig
if self == DB.PgDiskANN:
from .pgdiskann.config import PgDiskANNConfig
+
return PgDiskANNConfig
if self == DB.Redis:
from .redis.config import RedisConfig
+
return RedisConfig
-
+
if self == DB.MemoryDB:
from .memorydb.config import MemoryDBConfig
+
return MemoryDBConfig
if self == DB.Chroma:
from .chroma.config import ChromaConfig
+
return ChromaConfig
if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchConfig
+
return AWSOpenSearchConfig
-
+
if self == DB.AlloyDB:
from .alloydb.config import AlloyDBConfig
+
return AlloyDBConfig
if self == DB.AliyunElasticsearch:
from .aliyun_elasticsearch.config import AliyunElasticsearchConfig
+
return AliyunElasticsearchConfig
if self == DB.AliyunOpenSearch:
from .aliyun_opensearch.config import AliyunOpenSearchConfig
+
return AliyunOpenSearchConfig
- def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+ msg = f"Unknown DB: {self.name}"
+ raise ValueError(msg)
+
+ def case_config_cls( # noqa: PLR0911
+ self,
+ index_type: IndexType | None = None,
+ ) -> type[DBCaseConfig]:
if self == DB.Milvus:
from .milvus.config import _milvus_case_config
+
return _milvus_case_config.get(index_type)
if self == DB.ZillizCloud:
from .zilliz_cloud.config import AutoIndexConfig
+
return AutoIndexConfig
if self == DB.ElasticCloud:
from .elastic_cloud.config import ElasticCloudIndexConfig
+
return ElasticCloudIndexConfig
if self == DB.QdrantCloud:
from .qdrant_cloud.config import QdrantIndexConfig
+
return QdrantIndexConfig
if self == DB.WeaviateCloud:
from .weaviate_cloud.config import WeaviateIndexConfig
+
return WeaviateIndexConfig
if self == DB.PgVector:
from .pgvector.config import _pgvector_case_config
+
return _pgvector_case_config.get(index_type)
if self == DB.PgVectoRS:
from .pgvecto_rs.config import _pgvecto_rs_case_config
+
return _pgvecto_rs_case_config.get(index_type)
if self == DB.AWSOpenSearch:
from .aws_opensearch.config import AWSOpenSearchIndexConfig
+
return AWSOpenSearchIndexConfig
if self == DB.PgVectorScale:
from .pgvectorscale.config import _pgvectorscale_case_config
+
return _pgvectorscale_case_config.get(index_type)
if self == DB.PgDiskANN:
from .pgdiskann.config import _pgdiskann_case_config
+
return _pgdiskann_case_config.get(index_type)
-
+
if self == DB.AlloyDB:
from .alloydb.config import _alloydb_case_config
+
return _alloydb_case_config.get(index_type)
if self == DB.AliyunElasticsearch:
from .elastic_cloud.config import ElasticCloudIndexConfig
+
return ElasticCloudIndexConfig
if self == DB.AliyunOpenSearch:
from .aliyun_opensearch.config import AliyunOpenSearchIndexConfig
+
return AliyunOpenSearchIndexConfig
# DB.Pinecone, DB.Chroma, DB.Redis
@@ -242,5 +297,11 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon
__all__ = [
- "DB", "VectorDB", "DBConfig", "DBCaseConfig", "IndexType", "MetricType", "EmptyDBCaseConfig",
+ "DB",
+ "DBCaseConfig",
+ "DBConfig",
+ "EmptyDBCaseConfig",
+ "IndexType",
+ "MetricType",
+ "VectorDB",
]
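
For reference, a sketch of how the reworked lazy-import properties are used (it assumes the optional Milvus client dependencies are installed; unknown enum values now raise `ValueError` instead of falling through):

```python
# Sketch: the DB enum defers importing client modules until a property is accessed.
from vectordb_bench.backend.clients import DB

db = DB.Milvus
client_cls = db.init_cls    # triggers `from .milvus.milvus import Milvus`
config_cls = db.config_cls  # triggers `from .milvus.config import MilvusConfig`
print(client_cls.__name__, config_cls.__name__)
```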
diff --git a/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py b/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py
index 41253ca1e..92ec905e5 100644
--- a/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py
+++ b/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py
@@ -1,5 +1,5 @@
-from ..elastic_cloud.elastic_cloud import ElasticCloud
from ..elastic_cloud.config import ElasticCloudIndexConfig
+from ..elastic_cloud.elastic_cloud import ElasticCloud
class AliyunElasticsearch(ElasticCloud):
@@ -24,4 +24,3 @@ def __init__(
drop_old=drop_old,
**kwargs,
)
-
diff --git a/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py b/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py
index a2de4dc75..a8f5100bf 100644
--- a/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py
+++ b/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py
@@ -1,7 +1,6 @@
-from enum import Enum
-from pydantic import SecretStr, BaseModel
+from pydantic import BaseModel, SecretStr
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from ..api import DBConfig
class AliyunElasticsearchConfig(DBConfig, BaseModel):
@@ -14,6 +13,6 @@ class AliyunElasticsearchConfig(DBConfig, BaseModel):
def to_dict(self) -> dict:
return {
- "hosts": [{'scheme': self.scheme, 'host': self.host, 'port': self.port}],
+ "hosts": [{"scheme": self.scheme, "host": self.host, "port": self.port}],
"basic_auth": (self.user, self.password.get_secret_value()),
}
diff --git a/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py b/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py
index 5d4dcbfa6..00227cfff 100644
--- a/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py
+++ b/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py
@@ -1,32 +1,32 @@
import json
import logging
-from contextlib import contextmanager
import time
+from contextlib import contextmanager
+from alibabacloud_ha3engine_vector import client, models
from alibabacloud_ha3engine_vector.models import QueryRequest
-
-from ..api import VectorDB, MetricType
-from .config import AliyunOpenSearchIndexConfig
-
-from alibabacloud_searchengine20211025.client import Client as searchengineClient
from alibabacloud_searchengine20211025 import models as searchengine_models
+from alibabacloud_searchengine20211025.client import Client as searchengineClient
from alibabacloud_tea_openapi import models as open_api_models
-from alibabacloud_ha3engine_vector import models, client
+
+from ..api import MetricType, VectorDB
+from .config import AliyunOpenSearchIndexConfig
log = logging.getLogger(__name__)
ALIYUN_OPENSEARCH_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB
ALIYUN_OPENSEARCH_MAX_NUM_PER_BATCH = 100
+
class AliyunOpenSearch(VectorDB):
def __init__(
- self,
- dim: int,
- db_config: dict,
- db_case_config: AliyunOpenSearchIndexConfig,
- collection_name: str = "VectorDBBenchCollection",
- drop_old: bool = False,
- **kwargs,
+ self,
+ dim: int,
+ db_config: dict,
+ db_case_config: AliyunOpenSearchIndexConfig,
+ collection_name: str = "VectorDBBenchCollection",
+ drop_old: bool = False,
+ **kwargs,
):
self.control_client = None
self.dim = dim
@@ -41,14 +41,17 @@ def __init__(
self._index_name = "vector_idx"
self.batch_size = int(
- min(ALIYUN_OPENSEARCH_MAX_SIZE_PER_BATCH / (dim * 25), ALIYUN_OPENSEARCH_MAX_NUM_PER_BATCH)
+ min(
+ ALIYUN_OPENSEARCH_MAX_SIZE_PER_BATCH / (dim * 25),
+ ALIYUN_OPENSEARCH_MAX_NUM_PER_BATCH,
+ ),
)
log.info(f"Aliyun_OpenSearch client config: {self.db_config}")
control_config = open_api_models.Config(
access_key_id=self.db_config["ak"],
access_key_secret=self.db_config["sk"],
- endpoint=self.db_config["control_host"]
+ endpoint=self.db_config["control_host"],
)
self.control_client = searchengineClient(control_config)
@@ -67,7 +70,7 @@ def _create_index(self, client: searchengineClient):
create_table_request.field_schema = {
self._primary_field: "INT64",
self._vector_field: "MULTI_FLOAT",
- self._scalar_field: "INT64"
+ self._scalar_field: "INT64",
}
vector_index = searchengine_models.ModifyTableRequestVectorIndex()
vector_index.index_name = self._index_name
@@ -77,8 +80,25 @@ def _create_index(self, client: searchengineClient):
vector_index.vector_index_type = "HNSW"
advance_params = searchengine_models.ModifyTableRequestVectorIndexAdvanceParams()
- advance_params.build_index_params = "{\"proxima.hnsw.builder.max_neighbor_count\":" + str(self.case_config.M) + ",\"proxima.hnsw.builder.efconstruction\":" + str(self.case_config.efConstruction) + ",\"proxima.hnsw.builder.enable_adsampling\":true,\"proxima.hnsw.builder.slack_pruning_factor\":1.1,\"proxima.hnsw.builder.thread_count\":16}"
- advance_params.search_index_params = "{\"proxima.hnsw.searcher.ef\":400,\"proxima.hnsw.searcher.dynamic_termination.prob_threshold\":0.7}"
+ str_max_neighbor_count = f'"proxima.hnsw.builder.max_neighbor_count":{self.case_config.M}'
+ str_efc = f'"proxima.hnsw.builder.efconstruction":{self.case_config.ef_construction}'
+ str_enable_adsampling = '"proxima.hnsw.builder.enable_adsampling":true'
+ str_slack_pruning_factor = '"proxima.hnsw.builder.slack_pruning_factor":1.1'
+ str_thread_count = '"proxima.hnsw.builder.thread_count":16'
+
+ params = ",".join(
+ [
+ str_max_neighbor_count,
+ str_efc,
+ str_enable_adsampling,
+ str_slack_pruning_factor,
+ str_thread_count,
+ ],
+ )
+ advance_params.build_index_params = "{" + params + "}"
+ advance_params.search_index_params = (
+ '{"proxima.hnsw.searcher.ef":400,"proxima.hnsw.searcher.dynamic_termination.prob_threshold":0.7}'
+ )
vector_index.advance_params = advance_params
create_table_request.vector_index = [vector_index]
@@ -88,7 +108,7 @@ def _create_index(self, client: searchengineClient):
except Exception as error:
log.info(error.message)
log.info(error.data.get("Recommend"))
- log.info(f"Failed to create index: error: {str(error)}")
+ log.info(f"Failed to create index: error: {error!s}")
raise error from None
# check if index create success
@@ -102,22 +122,22 @@ def _active_index(self, client: searchengineClient) -> None:
log.info(f"begin to {retry_times} times get table")
retry_times += 1
response = client.get_table(self.instance_id, self.collection_name)
- if response.body.result.status == 'IN_USE':
+ if response.body.result.status == "IN_USE":
log.info(f"{self.collection_name} table begin to use.")
return
def _index_exists(self, client: searchengineClient) -> bool:
try:
client.get_table(self.instance_id, self.collection_name)
- return True
- except Exception as error:
- log.info(f'get table from searchengine error')
- log.info(error.message)
+ except Exception as err:
+ log.warning(f"get table from searchengine error, err={err}")
return False
+ else:
+ return True
# check if index build success, Insert the embeddings to the vector database after index build success
def _index_build_success(self, client: searchengineClient) -> None:
- log.info(f"begin to check if table build success.")
+ log.info("begin to check if table build success.")
time.sleep(50)
retry_times = 0
@@ -139,9 +159,9 @@ def _index_build_success(self, client: searchengineClient) -> None:
cur_fsm = fsm
break
if cur_fsm is None:
- print("no build index fsm")
+ log.warning("no build index fsm")
return
- if "success" == cur_fsm["status"]:
+ if cur_fsm["status"] == "success":
return
def _modify_index(self, client: searchengineClient) -> None:
@@ -154,7 +174,7 @@ def _modify_index(self, client: searchengineClient) -> None:
modify_table_request.field_schema = {
self._primary_field: "INT64",
self._vector_field: "MULTI_FLOAT",
- self._scalar_field: "INT64"
+ self._scalar_field: "INT64",
}
vector_index = searchengine_models.ModifyTableRequestVectorIndex()
vector_index.index_name = self._index_name
@@ -163,19 +183,41 @@ def _modify_index(self, client: searchengineClient) -> None:
vector_index.vector_field = self._vector_field
vector_index.vector_index_type = "HNSW"
advance_params = searchengine_models.ModifyTableRequestVectorIndexAdvanceParams()
- advance_params.build_index_params = "{\"proxima.hnsw.builder.max_neighbor_count\":" + str(self.case_config.M) + ",\"proxima.hnsw.builder.efconstruction\":" + str(self.case_config.efConstruction) + ",\"proxima.hnsw.builder.enable_adsampling\":true,\"proxima.hnsw.builder.slack_pruning_factor\":1.1,\"proxima.hnsw.builder.thread_count\":16}"
- advance_params.search_index_params = "{\"proxima.hnsw.searcher.ef\":400,\"proxima.hnsw.searcher.dynamic_termination.prob_threshold\":0.7}"
+
+ str_max_neighbor_count = f'"proxima.hnsw.builder.max_neighbor_count":{self.case_config.M}'
+ str_efc = f'"proxima.hnsw.builder.efconstruction":{self.case_config.ef_construction}'
+ str_enable_adsampling = '"proxima.hnsw.builder.enable_adsampling":true'
+ str_slack_pruning_factor = '"proxima.hnsw.builder.slack_pruning_factor":1.1'
+ str_thread_count = '"proxima.hnsw.builder.thread_count":16'
+
+ params = ",".join(
+ [
+ str_max_neighbor_count,
+ str_efc,
+ str_enable_adsampling,
+ str_slack_pruning_factor,
+ str_thread_count,
+ ],
+ )
+ advance_params.build_index_params = "{" + params + "}"
+ advance_params.search_index_params = (
+ '{"proxima.hnsw.searcher.ef":400,"proxima.hnsw.searcher.dynamic_termination.prob_threshold":0.7}'
+ )
vector_index.advance_params = advance_params
modify_table_request.vector_index = [vector_index]
try:
- response = client.modify_table(self.instance_id, self.collection_name, modify_table_request)
+ response = client.modify_table(
+ self.instance_id,
+ self.collection_name,
+ modify_table_request,
+ )
log.info(f"modify table success: {response.body}")
except Exception as error:
log.info(error.message)
log.info(error.data.get("Recommend"))
- log.info(f"Failed to modify index: error: {str(error)}")
+ log.info(f"Failed to modify index: error: {error!s}")
raise error from None
# check if modify index & delete data fsm success
@@ -185,15 +227,14 @@ def _modify_index(self, client: searchengineClient) -> None:
def _get_total_count(self):
try:
response = self.client.stats(self.collection_name)
+ except Exception as e:
+ log.warning(f"Error querying index: {e}")
+ else:
body = json.loads(response.body)
log.info(f"stats info: {response.body}")
if "result" in body and "totalDocCount" in body.get("result"):
return body.get("result").get("totalDocCount")
- else:
- return 0
- except Exception as e:
- print(f"Error querying index: {e}")
return 0
@contextmanager
@@ -203,21 +244,20 @@ def init(self) -> None:
endpoint=self.db_config["host"],
protocol="http",
access_user_name=self.db_config["user"],
- access_pass_word=self.db_config["password"]
+ access_pass_word=self.db_config["password"],
)
self.client = client.Client(config)
yield
- # self.client.transport.close()
self.client = None
del self.client
def insert_embeddings(
- self,
- embeddings: list[list[float]],
- metadata: list[int],
- **kwargs,
+ self,
+ embeddings: list[list[float]],
+ metadata: list[int],
+ **kwargs,
) -> tuple[int, Exception]:
"""Insert the embeddings to the opensearch."""
assert self.client is not None, "should self.init() first"
@@ -226,25 +266,24 @@ def insert_embeddings(
try:
for batch_start_offset in range(0, len(embeddings), self.batch_size):
- batch_end_offset = min(
- batch_start_offset + self.batch_size, len(embeddings)
- )
+ batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
documents = []
for i in range(batch_start_offset, batch_end_offset):
- documentFields = {
+ document_fields = {
self._primary_field: metadata[i],
self._vector_field: embeddings[i],
self._scalar_field: metadata[i],
- "ops_build_channel": "inc"
- }
- document = {
- "fields": documentFields,
- "cmd": "add"
+ "ops_build_channel": "inc",
}
+ document = {"fields": document_fields, "cmd": "add"}
documents.append(document)
- pushDocumentsRequest = models.PushDocumentsRequest({}, documents)
- self.client.push_documents(self.collection_name, self._primary_field, pushDocumentsRequest)
+ push_doc_req = models.PushDocumentsRequest({}, documents)
+ self.client.push_documents(
+ self.collection_name,
+ self._primary_field,
+ push_doc_req,
+ )
insert_count += batch_end_offset - batch_start_offset
except Exception as e:
log.info(f"Failed to insert data: {e}")
@@ -252,33 +291,36 @@ def insert_embeddings(
return (insert_count, None)
def search_embedding(
- self,
- query: list[float],
- k: int = 100,
- filters: dict | None = None,
+ self,
+ query: list[float],
+ k: int = 100,
+ filters: dict | None = None,
) -> list[int]:
assert self.client is not None, "should self.init() first"
- search_params = "{\"proxima.hnsw.searcher.ef\":"+ str(self.case_config.ef_search) +"}"
+ search_params = '{"proxima.hnsw.searcher.ef":' + str(self.case_config.ef_search) + "}"
os_filter = f"{self._scalar_field} {filters.get('metadata')}" if filters else ""
try:
- request = QueryRequest(table_name=self.collection_name,
- vector=query,
- top_k=k,
- search_params=search_params, filter=os_filter)
+ request = QueryRequest(
+ table_name=self.collection_name,
+ vector=query,
+ top_k=k,
+ search_params=search_params,
+ filter=os_filter,
+ )
result = self.client.query(request)
except Exception as e:
log.info(f"Error querying index: {e}")
- raise e
- res = json.loads(result.body)
- id_res = [one_res["id"] for one_res in res["result"]]
- return id_res
+ raise e from e
+ else:
+ res = json.loads(result.body)
+ return [one_res["id"] for one_res in res["result"]]
def need_normalize_cosine(self) -> bool:
"""Wheather this database need to normalize dataset to support COSINE"""
if self.case_config.metric_type == MetricType.COSINE:
- log.info(f"cosine dataset need normalize.")
+ log.info("cosine dataset need normalize.")
return True
return False
@@ -296,9 +338,8 @@ def optimize_with_size(self, data_size: int):
total_count = self._get_total_count()
# check if the data is inserted
if total_count == data_size:
- log.info(f"optimize table finish.")
+ log.info("optimize table finish.")
return
def ready_to_load(self):
"""ready_to_load will be called before load in load cases."""
- pass
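
The rewritten `build_index_params` above assembles a JSON object from per-parameter fragments (wrapped in braces, matching the string the replaced one-liner produced). A standalone sketch, using the defaults `M=100` and `ef_construction=500` from `AliyunOpenSearchIndexConfig`:

```python
# Sketch: rebuild and validate the HNSW build_index_params JSON string.
import json

m, ef_construction = 100, 500
fragments = [
    f'"proxima.hnsw.builder.max_neighbor_count":{m}',
    f'"proxima.hnsw.builder.efconstruction":{ef_construction}',
    '"proxima.hnsw.builder.enable_adsampling":true',
    '"proxima.hnsw.builder.slack_pruning_factor":1.1',
    '"proxima.hnsw.builder.thread_count":16',
]
build_index_params = "{" + ",".join(fragments) + "}"
assert json.loads(build_index_params)["proxima.hnsw.builder.efconstruction"] == 500
```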
diff --git a/vectordb_bench/backend/clients/aliyun_opensearch/config.py b/vectordb_bench/backend/clients/aliyun_opensearch/config.py
index 7b2b9ad13..e215b7d68 100644
--- a/vectordb_bench/backend/clients/aliyun_opensearch/config.py
+++ b/vectordb_bench/backend/clients/aliyun_opensearch/config.py
@@ -1,8 +1,8 @@
import logging
-from enum import Enum
-from pydantic import SecretStr, BaseModel
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
log = logging.getLogger(__name__)
@@ -26,18 +26,17 @@ def to_dict(self) -> dict:
"control_host": self.control_host,
}
+
class AliyunOpenSearchIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType = MetricType.L2
- efConstruction: int = 500
+ ef_construction: int = 500
M: int = 100
ef_search: int = 40
def distance_type(self) -> str:
if self.metric_type == MetricType.L2:
return "SquaredEuclidean"
- elif self.metric_type == MetricType.IP:
- return "InnerProduct"
- elif self.metric_type == MetricType.COSINE:
+ if self.metric_type in (MetricType.IP, MetricType.COSINE):
return "InnerProduct"
return "SquaredEuclidean"
diff --git a/vectordb_bench/backend/clients/alloydb/alloydb.py b/vectordb_bench/backend/clients/alloydb/alloydb.py
index 5b275b30f..c81f77675 100644
--- a/vectordb_bench/backend/clients/alloydb/alloydb.py
+++ b/vectordb_bench/backend/clients/alloydb/alloydb.py
@@ -1,9 +1,9 @@
"""Wrapper around the alloydb vector database over VectorDB"""
import logging
-import pprint
+from collections.abc import Generator, Sequence
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple, Sequence
+from typing import Any
import numpy as np
import psycopg
@@ -11,7 +11,7 @@
from psycopg import Connection, Cursor, sql
from ..api import VectorDB
-from .config import AlloyDBConfigDict, AlloyDBIndexConfig, AlloyDBScaNNConfig
+from .config import AlloyDBConfigDict, AlloyDBIndexConfig
log = logging.getLogger(__name__)
@@ -56,13 +56,14 @@ def __init__(
(
self.case_config.create_index_before_load,
self.case_config.create_index_after_load,
- )
+ ),
):
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
- log.error(err)
- raise RuntimeError(
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+ msg = (
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+ "\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
)
+ log.warning(msg)
+ raise RuntimeError(msg)
if drop_old:
self._drop_index()
@@ -77,7 +78,7 @@ def __init__(
self.conn = None
@staticmethod
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
conn = psycopg.connect(**kwargs)
register_vector(conn)
conn.autocommit = False
@@ -86,21 +87,20 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
assert conn is not None, "Connection is not initialized"
assert cursor is not None, "Cursor is not initialized"
return conn, cursor
-
- def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
- search_query = sql.Composed(
+
+ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
+ return sql.Composed(
[
sql.SQL(
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
+ "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
).format(
table_name=sql.Identifier(self.table_name),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
- return search_query
@contextmanager
def init(self) -> Generator[None, None, None]:
@@ -119,8 +119,8 @@ def init(self) -> Generator[None, None, None]:
if len(session_options) > 0:
for setting in session_options:
command = sql.SQL("SET {setting_name} " + "= {val};").format(
- setting_name=sql.Identifier(setting['parameter']['setting_name']),
- val=sql.Identifier(str(setting['parameter']['val'])),
+ setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
+ val=sql.Identifier(str(setting["parameter"]["val"])),
)
log.debug(command.as_string(self.cursor))
self.cursor.execute(command)
@@ -144,8 +144,8 @@ def _drop_table(self):
self.cursor.execute(
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
)
self.conn.commit()
@@ -167,7 +167,7 @@ def _drop_index(self):
log.info(f"{self.name} client drop index : {self._index_name}")
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
- index_name=sql.Identifier(self._index_name)
+ index_name=sql.Identifier(self._index_name),
)
log.debug(drop_index_sql.as_string(self.cursor))
self.cursor.execute(drop_index_sql)
@@ -181,78 +181,64 @@ def _set_parallel_index_build_param(self):
if index_param["enable_pca"] is not None:
self.cursor.execute(
- sql.SQL("SET scann.enable_pca TO {};").format(
- index_param["enable_pca"]
- )
+ sql.SQL("SET scann.enable_pca TO {};").format(index_param["enable_pca"]),
)
self.cursor.execute(
sql.SQL("ALTER USER {} SET scann.enable_pca TO {};").format(
sql.Identifier(self.db_config["user"]),
index_param["enable_pca"],
- )
+ ),
)
self.conn.commit()
if index_param["maintenance_work_mem"] is not None:
self.cursor.execute(
sql.SQL("SET maintenance_work_mem TO {};").format(
- index_param["maintenance_work_mem"]
- )
+ index_param["maintenance_work_mem"],
+ ),
)
self.cursor.execute(
sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
sql.Identifier(self.db_config["user"]),
index_param["maintenance_work_mem"],
- )
+ ),
)
self.conn.commit()
if index_param["max_parallel_workers"] is not None:
self.cursor.execute(
sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
- index_param["max_parallel_workers"]
- )
+ index_param["max_parallel_workers"],
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
- ).format(
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
sql.Identifier(self.db_config["user"]),
index_param["max_parallel_workers"],
- )
+ ),
)
self.cursor.execute(
sql.SQL("SET max_parallel_workers TO '{}';").format(
- index_param["max_parallel_workers"]
- )
+ index_param["max_parallel_workers"],
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER USER {} SET max_parallel_workers TO '{}';"
- ).format(
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
sql.Identifier(self.db_config["user"]),
index_param["max_parallel_workers"],
- )
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER TABLE {} SET (parallel_workers = {});"
- ).format(
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
sql.Identifier(self.table_name),
index_param["max_parallel_workers"],
- )
+ ),
)
self.conn.commit()
- results = self.cursor.execute(
- sql.SQL("SHOW max_parallel_maintenance_workers;")
- ).fetchall()
- results.extend(
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
- )
- results.extend(
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
- )
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
log.info(f"{self.name} parallel index creation parameters: {results}")
def _create_index(self):
@@ -264,23 +250,20 @@ def _create_index(self):
self._set_parallel_index_build_param()
options = []
for option in index_param["index_creation_with_options"]:
- if option['val'] is not None:
+ if option["val"] is not None:
options.append(
sql.SQL("{option_name} = {val}").format(
- option_name=sql.Identifier(option['option_name']),
- val=sql.Identifier(str(option['val'])),
- )
+ option_name=sql.Identifier(option["option_name"]),
+ val=sql.Identifier(str(option["val"])),
+ ),
)
- if any(options):
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
- else:
- with_clause = sql.Composed(())
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
index_create_sql = sql.SQL(
"""
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
- """
+ """,
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
@@ -288,9 +271,7 @@ def _create_index(self):
embedding_metric=sql.Identifier(index_param["metric"]),
)
- index_create_sql_with_with_clause = (
- index_create_sql + with_clause
- ).join(" ")
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
self.cursor.execute(index_create_sql_with_with_clause)
self.conn.commit()
@@ -305,14 +286,12 @@ def _create_table(self, dim: int):
# create table
self.cursor.execute(
sql.SQL(
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
)
self.conn.commit()
except Exception as e:
- log.warning(
- f"Failed to create alloydb table: {self.table_name} error: {e}"
- )
+ log.warning(f"Failed to create alloydb table: {self.table_name} error: {e}")
raise e from None
def insert_embeddings(
@@ -320,7 +299,7 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
@@ -330,8 +309,8 @@ def insert_embeddings(
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
@@ -343,9 +322,7 @@ def insert_embeddings(
return len(metadata), None
except Exception as e:
- log.warning(
- f"Failed to insert data into alloydb table ({self.table_name}), error: {e}"
- )
+ log.warning(f"Failed to insert data into alloydb table ({self.table_name}), error: {e}")
return 0, e
def search_embedding(
@@ -362,11 +339,12 @@ def search_embedding(
if filters:
gt = filters.get("id")
result = self.cursor.execute(
- self._filtered_search, (gt, q, k), prepare=True, binary=True
+ self._filtered_search,
+ (gt, q, k),
+ prepare=True,
+ binary=True,
)
else:
- result = self.cursor.execute(
- self._unfiltered_search, (q, k), prepare=True, binary=True
- )
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
return [int(i[0]) for i in result.fetchall()]
diff --git a/vectordb_bench/backend/clients/alloydb/cli.py b/vectordb_bench/backend/clients/alloydb/cli.py
index 54d9b9fa3..e207c9c21 100644
--- a/vectordb_bench/backend/clients/alloydb/cli.py
+++ b/vectordb_bench/backend/clients/alloydb/cli.py
@@ -1,10 +1,10 @@
-from typing import Annotated, Optional, TypedDict, Unpack
+import os
+from typing import Annotated, Unpack
import click
-import os
from pydantic import SecretStr
-from vectordb_bench.backend.clients.api import MetricType
+from vectordb_bench.backend.clients import DB
from ....cli.cli import (
CommonTypedDict,
@@ -13,31 +13,28 @@
get_custom_case_config,
run,
)
-from vectordb_bench.backend.clients import DB
class AlloyDBTypedDict(CommonTypedDict):
user_name: Annotated[
- str, click.option("--user-name", type=str, help="Db username", required=True)
+ str,
+ click.option("--user-name", type=str, help="Db username", required=True),
]
password: Annotated[
str,
- click.option("--password",
- type=str,
- help="Postgres database password",
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
- show_default="$POSTGRES_PASSWORD",
- ),
+ click.option(
+ "--password",
+ type=str,
+ help="Postgres database password",
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+ show_default="$POSTGRES_PASSWORD",
+ ),
]
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
- db_name: Annotated[
- str, click.option("--db-name", type=str, help="Db name", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
maintenance_work_mem: Annotated[
- Optional[str],
+ str | None,
click.option(
"--maintenance-work-mem",
type=str,
@@ -49,7 +46,7 @@ class AlloyDBTypedDict(CommonTypedDict):
),
]
max_parallel_workers: Annotated[
- Optional[int],
+ int | None,
click.option(
"--max-parallel-workers",
type=int,
@@ -58,32 +55,51 @@ class AlloyDBTypedDict(CommonTypedDict):
),
]
-
class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
num_leaves: Annotated[
int,
- click.option("--num-leaves", type=int, help="Number of leaves", required=True)
+ click.option("--num-leaves", type=int, help="Number of leaves", required=True),
]
num_leaves_to_search: Annotated[
int,
- click.option("--num-leaves-to-search", type=int, help="Number of leaves to search", required=True)
+ click.option(
+ "--num-leaves-to-search",
+ type=int,
+ help="Number of leaves to search",
+ required=True,
+ ),
]
pre_reordering_num_neighbors: Annotated[
int,
- click.option("--pre-reordering-num-neighbors", type=int, help="Pre-reordering number of neighbors", default=200)
+ click.option(
+ "--pre-reordering-num-neighbors",
+ type=int,
+ help="Pre-reordering number of neighbors",
+ default=200,
+ ),
]
max_top_neighbors_buffer_size: Annotated[
int,
- click.option("--max-top-neighbors-buffer-size", type=int, help="Maximum top neighbors buffer size", default=20_000)
+ click.option(
+ "--max-top-neighbors-buffer-size",
+ type=int,
+ help="Maximum top neighbors buffer size",
+ default=20_000,
+ ),
]
num_search_threads: Annotated[
int,
- click.option("--num-search-threads", type=int, help="Number of search threads", default=2)
+ click.option("--num-search-threads", type=int, help="Number of search threads", default=2),
]
max_num_prefetch_datasets: Annotated[
int,
- click.option("--max-num-prefetch-datasets", type=int, help="Maximum number of prefetch datasets", default=100)
+ click.option(
+ "--max-num-prefetch-datasets",
+ type=int,
+ help="Maximum number of prefetch datasets",
+ default=100,
+ ),
]
quantizer: Annotated[
str,
@@ -91,16 +107,17 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
"--quantizer",
type=click.Choice(["SQ8", "FLAT"]),
help="Quantizer type",
- default="SQ8"
- )
+ default="SQ8",
+ ),
]
enable_pca: Annotated[
- bool, click.option(
+ bool,
+ click.option(
"--enable-pca",
type=click.Choice(["on", "off"]),
help="Enable PCA",
- default="on"
- )
+ default="on",
+ ),
]
max_num_levels: Annotated[
int,
@@ -108,8 +125,8 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict):
"--max-num-levels",
type=click.Choice(["1", "2"]),
help="Maximum number of levels",
- default=1
- )
+ default=1,
+ ),
]
@@ -144,4 +161,4 @@ def AlloyDBScaNN(
maintenance_work_mem=parameters["maintenance_work_mem"],
),
**parameters,
- )
\ No newline at end of file
+ )
diff --git a/vectordb_bench/backend/clients/alloydb/config.py b/vectordb_bench/backend/clients/alloydb/config.py
index 1d5dde519..d6e54e487 100644
--- a/vectordb_bench/backend/clients/alloydb/config.py
+++ b/vectordb_bench/backend/clients/alloydb/config.py
@@ -1,7 +1,9 @@
from abc import abstractmethod
-from typing import Any, Mapping, Optional, Sequence, TypedDict
+from collections.abc import Mapping, Sequence
+from typing import Any, LiteralString, TypedDict
+
from pydantic import BaseModel, SecretStr
-from typing_extensions import LiteralString
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +11,7 @@
class AlloyDBConfigDict(TypedDict):
"""These keys will be directly used as kwargs in psycopg connection string,
- so the names must match exactly psycopg API"""
+    so the names must exactly match the psycopg API"""
user: str
password: str
@@ -41,8 +43,8 @@ class AlloyDBIndexParam(TypedDict):
metric: str
index_type: str
index_creation_with_options: Sequence[dict[str, Any]]
- maintenance_work_mem: Optional[str]
- max_parallel_workers: Optional[int]
+ maintenance_work_mem: str | None
+ max_parallel_workers: int | None
class AlloyDBSearchParam(TypedDict):
@@ -61,31 +63,30 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig):
def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2"
- elif self.metric_type == MetricType.DP:
+ if self.metric_type == MetricType.DP:
return "dot_product"
return "cosine"
def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
return "<->"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "<#>"
return "<=>"
@abstractmethod
- def index_param(self) -> AlloyDBIndexParam:
- ...
+ def index_param(self) -> AlloyDBIndexParam: ...
@abstractmethod
- def search_param(self) -> AlloyDBSearchParam:
- ...
+ def search_param(self) -> AlloyDBSearchParam: ...
@abstractmethod
- def session_param(self) -> AlloyDBSessionCommands:
- ...
+ def session_param(self) -> AlloyDBSessionCommands: ...
@staticmethod
- def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
+ def _optionally_build_with_options(
+ with_options: Mapping[str, Any],
+ ) -> Sequence[dict[str, Any]]:
"""Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
options = []
for option_name, value in with_options.items():
@@ -94,24 +95,25 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[
{
"option_name": option_name,
"val": str(value),
- }
+ },
)
return options
@staticmethod
def _optionally_build_set_options(
- set_mapping: Mapping[str, Any]
+ set_mapping: Mapping[str, Any],
) -> Sequence[dict[str, Any]]:
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
session_options = []
for setting_name, value in set_mapping.items():
if value:
session_options.append(
- {"parameter": {
+ {
+ "parameter": {
"setting_name": setting_name,
"val": str(value),
},
- }
+ },
)
return session_options
@@ -124,22 +126,22 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig):
max_num_levels: int | None
num_leaves_to_search: int | None
max_top_neighbors_buffer_size: int | None
- pre_reordering_num_neighbors: int | None
- num_search_threads: int | None
+ pre_reordering_num_neighbors: int | None
+ num_search_threads: int | None
max_num_prefetch_datasets: int | None
- maintenance_work_mem: Optional[str] = None
- max_parallel_workers: Optional[int] = None
+ maintenance_work_mem: str | None = None
+ max_parallel_workers: int | None = None
def index_param(self) -> AlloyDBIndexParam:
index_parameters = {
- "num_leaves": self.num_leaves, "max_num_levels": self.max_num_levels, "quantizer": self.quantizer,
+ "num_leaves": self.num_leaves,
+ "max_num_levels": self.max_num_levels,
+ "quantizer": self.quantizer,
}
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
- "index_creation_with_options": self._optionally_build_with_options(
- index_parameters
- ),
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"enable_pca": self.enable_pca,
@@ -158,11 +160,9 @@ def session_param(self) -> AlloyDBSessionCommands:
"scann.num_search_threads": self.num_search_threads,
"scann.max_num_prefetch_datasets": self.max_num_prefetch_datasets,
}
- return {
- "session_options": self._optionally_build_set_options(session_parameters)
- }
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
_alloydb_case_config = {
- IndexType.SCANN: AlloyDBScaNNConfig,
+ IndexType.SCANN: AlloyDBScaNNConfig,
}
diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py
index fe2e554f3..aa93abc12 100644
--- a/vectordb_bench/backend/clients/api.py
+++ b/vectordb_bench/backend/clients/api.py
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
-from enum import Enum
-from typing import Any, Type
from contextlib import contextmanager
+from enum import Enum
-from pydantic import BaseModel, validator, SecretStr
+from pydantic import BaseModel, SecretStr, validator
class MetricType(str, Enum):
@@ -65,13 +64,10 @@ def to_dict(self) -> dict:
raise NotImplementedError
@validator("*")
- def not_empty_field(cls, v, field):
- if (
- field.name in cls.common_short_configs()
- or field.name in cls.common_long_configs()
- ):
+ def not_empty_field(cls, v: any, field: any):
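+        # Common short/long configs are allowed to be empty, so skip the emptiness check for them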
+ if field.name in cls.common_short_configs() or field.name in cls.common_long_configs():
return v
- if not v and isinstance(v, (str, SecretStr)):
+ if not v and isinstance(v, str | SecretStr):
raise ValueError("Empty string!")
return v
diff --git a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
index 60e519d7a..487ec67cc 100644
--- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py
@@ -1,14 +1,18 @@
import logging
-from contextlib import contextmanager
import time
-from typing import Iterable, Type
-from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType
-from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
+from collections.abc import Iterable
+from contextlib import contextmanager
+
from opensearchpy import OpenSearch
-from opensearchpy.helpers import bulk
+
+from ..api import IndexType, VectorDB
+from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine
log = logging.getLogger(__name__)
+WAITING_FOR_REFRESH_SEC = 30
+WAITING_FOR_FORCE_MERGE_SEC = 30
+
class AWSOpenSearch(VectorDB):
def __init__(
@@ -27,9 +31,7 @@ def __init__(
self.case_config = db_case_config
self.index_name = index_name
self.id_col_name = id_col_name
- self.category_col_names = [
- f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]
- ]
+ self.category_col_names = [f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]]
self.vector_col_name = vector_col_name
log.info(f"AWS_OpenSearch client config: {self.db_config}")
@@ -46,38 +48,32 @@ def config_cls(cls) -> AWSOpenSearchConfig:
return AWSOpenSearchConfig
@classmethod
- def case_config_cls(
- cls, index_type: IndexType | None = None
- ) -> AWSOpenSearchIndexConfig:
+ def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig:
return AWSOpenSearchIndexConfig
def _create_index(self, client: OpenSearch):
settings = {
"index": {
"knn": True,
- # "number_of_shards": 5,
- # "refresh_interval": "600s",
- }
+ },
}
mappings = {
"properties": {
- **{
- categoryCol: {"type": "keyword"}
- for categoryCol in self.category_col_names
- },
+ **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names},
self.vector_col_name: {
"type": "knn_vector",
"dimension": self.dim,
"method": self.case_config.index_param(),
},
- }
+ },
}
try:
client.indices.create(
- index=self.index_name, body=dict(settings=settings, mappings=mappings)
+ index=self.index_name,
+ body={"settings": settings, "mappings": mappings},
)
except Exception as e:
- log.warning(f"Failed to create index: {self.index_name} error: {str(e)}")
+ log.warning(f"Failed to create index: {self.index_name} error: {e!s}")
raise e from None
@contextmanager
@@ -86,7 +82,6 @@ def init(self) -> None:
self.client = OpenSearch(**self.db_config)
yield
- # self.client.transport.close()
self.client = None
del self.client
@@ -101,16 +96,20 @@ def insert_embeddings(
insert_data = []
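+        # The bulk request body interleaves an action entry with each vector document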
for i in range(len(embeddings)):
- insert_data.append({"index": {"_index": self.index_name, self.id_col_name: metadata[i]}})
+ insert_data.append(
+ {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}},
+ )
insert_data.append({self.vector_col_name: embeddings[i]})
try:
resp = self.client.bulk(insert_data)
log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}")
resp = self.client.indices.stats(self.index_name)
- log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}")
+ log.info(
+ f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}",
+ )
return (len(embeddings), None)
except Exception as e:
- log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}")
+ log.warning(f"Failed to insert data: {self.index_name} error: {e!s}")
time.sleep(10)
return self.insert_embeddings(embeddings, metadata)
@@ -135,20 +134,23 @@ def search_embedding(
body = {
"size": k,
"query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}},
- **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {})
+ **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}),
}
try:
- resp = self.client.search(index=self.index_name, body=body,size=k,_source=False,docvalue_fields=[self.id_col_name],stored_fields="_none_",)
+ resp = self.client.search(
+ index=self.index_name,
+ body=body,
+ size=k,
+ _source=False,
+ docvalue_fields=[self.id_col_name],
+ stored_fields="_none_",
+ )
log.info(f'Search took: {resp["took"]}')
log.info(f'Search shards: {resp["_shards"]}')
log.info(f'Search hits total: {resp["hits"]["total"]}')
- result = [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
- #result = [int(d["_id"]) for d in resp["hits"]["hits"]]
- # log.info(f'success! length={len(res)}')
-
- return result
+ return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]]
except Exception as e:
- log.warning(f"Failed to search: {self.index_name} error: {str(e)}")
+ log.warning(f"Failed to search: {self.index_name} error: {e!s}")
raise e from None
def optimize(self):
@@ -163,37 +165,35 @@ def optimize(self):
def _refresh_index(self):
log.debug(f"Starting refresh for index {self.index_name}")
- SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
while True:
try:
- log.info(f"Starting the Refresh Index..")
+ log.info("Starting the Refresh Index..")
self.client.indices.refresh(index=self.index_name)
break
except Exception as e:
log.info(
- f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
- time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
+                    f"Refresh errored out. Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then retrying: {e}",
+ )
+ time.sleep(WAITING_FOR_REFRESH_SEC)
continue
log.debug(f"Completed refresh for index {self.index_name}")
def _do_force_merge(self):
log.debug(f"Starting force merge for index {self.index_name}")
- force_merge_endpoint = f'/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
- force_merge_task_id = self.client.transport.perform_request('POST', force_merge_endpoint)['task']
- SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+ force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+ force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"]
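+        # Poll the tasks API until the asynchronous force merge reports completion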
while True:
- time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
+ time.sleep(WAITING_FOR_FORCE_MERGE_SEC)
task_status = self.client.tasks.get(task_id=force_merge_task_id)
- if task_status['completed']:
+ if task_status["completed"]:
break
log.debug(f"Completed force merge for index {self.index_name}")
def _load_graphs_to_memory(self):
if self.case_config.engine != AWSOS_Engine.lucene:
log.info("Calling warmup API to load graphs into memory")
- warmup_endpoint = f'/_plugins/_knn/warmup/{self.index_name}'
- self.client.transport.perform_request('GET', warmup_endpoint)
+ warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}"
+ self.client.transport.perform_request("GET", warmup_endpoint)
def ready_to_load(self):
"""ready_to_load will be called before load in load cases."""
- pass
diff --git a/vectordb_bench/backend/clients/aws_opensearch/cli.py b/vectordb_bench/backend/clients/aws_opensearch/cli.py
index 5cb4ebbe1..bb0c2450d 100644
--- a/vectordb_bench/backend/clients/aws_opensearch/cli.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py
@@ -14,22 +14,20 @@
class AWSOpenSearchTypedDict(TypedDict):
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")]
user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")]
password: Annotated[str, click.option("--password", type=str, help="Db password")]
-class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2):
- ...
+class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ...
@cli.command()
@click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict)
def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig
+
run(
db=DB.AWSOpenSearch,
db_config=AWSOpenSearchConfig(
@@ -38,7 +36,6 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]):
user=parameters["user"],
password=SecretStr(parameters["password"]),
),
- db_case_config=AWSOpenSearchIndexConfig(
- ),
+ db_case_config=AWSOpenSearchIndexConfig(),
**parameters,
)
diff --git a/vectordb_bench/backend/clients/aws_opensearch/config.py b/vectordb_bench/backend/clients/aws_opensearch/config.py
index 15cd4ead8..e9ccc7277 100644
--- a/vectordb_bench/backend/clients/aws_opensearch/config.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/config.py
@@ -1,10 +1,13 @@
import logging
from enum import Enum
-from pydantic import SecretStr, BaseModel
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, MetricType
log = logging.getLogger(__name__)
+
+
class AWSOpenSearchConfig(DBConfig, BaseModel):
host: str = ""
port: int = 443
@@ -13,7 +16,7 @@ class AWSOpenSearchConfig(DBConfig, BaseModel):
def to_dict(self) -> dict:
return {
- "hosts": [{'host': self.host, 'port': self.port}],
+ "hosts": [{"host": self.host, "port": self.port}],
"http_auth": (self.user, self.password.get_secret_value()),
"use_ssl": True,
"http_compress": True,
@@ -40,25 +43,26 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig):
def parse_metric(self) -> str:
if self.metric_type == MetricType.IP:
return "innerproduct"
- elif self.metric_type == MetricType.COSINE:
+ if self.metric_type == MetricType.COSINE:
if self.engine == AWSOS_Engine.faiss:
- log.info(f"Using metric type as innerproduct because faiss doesn't support cosine as metric type for Opensearch")
+ log.info(
+ "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch",
+ )
return "innerproduct"
return "cosinesimil"
return "l2"
def index_param(self) -> dict:
- params = {
+ return {
"name": "hnsw",
"space_type": self.parse_metric(),
"engine": self.engine.value,
"parameters": {
"ef_construction": self.efConstruction,
"m": self.M,
- "ef_search": self.efSearch
- }
+ "ef_search": self.efSearch,
+ },
}
- return params
def search_param(self) -> dict:
return {}
diff --git a/vectordb_bench/backend/clients/aws_opensearch/run.py b/vectordb_bench/backend/clients/aws_opensearch/run.py
index d2698d139..68aa200a5 100644
--- a/vectordb_bench/backend/clients/aws_opensearch/run.py
+++ b/vectordb_bench/backend/clients/aws_opensearch/run.py
@@ -1,12 +1,16 @@
-import time, random
+import logging
+import random
+import time
+
from opensearchpy import OpenSearch
-from opensearch_dsl import Search, Document, Text, Keyword
-_HOST = 'xxxxxx.us-west-2.es.amazonaws.com'
+log = logging.getLogger(__name__)
+
+_HOST = "xxxxxx.us-west-2.es.amazonaws.com"
_PORT = 443
-_AUTH = ('admin', 'xxxxxx') # For testing only. Don't store credentials in code.
+_AUTH = ("admin", "xxxxxx") # For testing only. Don't store credentials in code.
-_INDEX_NAME = 'my-dsl-index'
+_INDEX_NAME = "my-dsl-index"
_BATCH = 100
_ROWS = 100
_DIM = 128
@@ -14,25 +18,24 @@
def create_client():
- client = OpenSearch(
- hosts=[{'host': _HOST, 'port': _PORT}],
- http_compress=True, # enables gzip compression for request bodies
+ return OpenSearch(
+ hosts=[{"host": _HOST, "port": _PORT}],
+ http_compress=True, # enables gzip compression for request bodies
http_auth=_AUTH,
use_ssl=True,
verify_certs=True,
ssl_assert_hostname=False,
ssl_show_warn=False,
)
- return client
-def create_index(client, index_name):
+def create_index(client: OpenSearch, index_name: str):
settings = {
"index": {
"knn": True,
"number_of_shards": 1,
"refresh_interval": "5s",
- }
+ },
}
mappings = {
"properties": {
@@ -46,41 +49,46 @@ def create_index(client, index_name):
"parameters": {
"ef_construction": 256,
"m": 16,
- }
- }
- }
- }
+ },
+ },
+ },
+ },
}
- response = client.indices.create(index=index_name, body=dict(settings=settings, mappings=mappings))
- print('\nCreating index:')
- print(response)
+ response = client.indices.create(
+ index=index_name,
+ body={"settings": settings, "mappings": mappings},
+ )
+ log.info("\nCreating index:")
+ log.info(response)
-def delete_index(client, index_name):
+def delete_index(client: OpenSearch, index_name: str):
response = client.indices.delete(index=index_name)
- print('\nDeleting index:')
- print(response)
+ log.info("\nDeleting index:")
+ log.info(response)
-def bulk_insert(client, index_name):
+def bulk_insert(client: OpenSearch, index_name: str):
# Perform bulk operations
- ids = [i for i in range(_ROWS)]
+ ids = list(range(_ROWS))
vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)]
docs = []
for i in range(0, _ROWS, _BATCH):
docs.clear()
- for j in range(0, _BATCH):
- docs.append({"index": {"_index": index_name, "_id": ids[i+j]}})
- docs.append({"embedding": vec[i+j]})
+ for j in range(_BATCH):
+ docs.append({"index": {"_index": index_name, "_id": ids[i + j]}})
+ docs.append({"embedding": vec[i + j]})
response = client.bulk(docs)
- print('\nAdding documents:', len(response['items']), response['errors'])
+ log.info(f"Adding documents: {len(response['items'])}, {response['errors']}")
response = client.indices.stats(index_name)
- print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total'])
+ log.info(
+ f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }',
+ )
-def search(client, index_name):
+def search(client: OpenSearch, index_name: str):
# Search for the document.
search_body = {
"size": _TOPK,
@@ -89,53 +97,55 @@ def search(client, index_name):
"embedding": {
"vector": [random.random() for _ in range(_DIM)],
"k": _TOPK,
- }
- }
- }
+ },
+ },
+ },
}
while True:
response = client.search(index=index_name, body=search_body)
- print(f'\nSearch took: {response["took"]}')
- print(f'\nSearch shards: {response["_shards"]}')
- print(f'\nSearch hits total: {response["hits"]["total"]}')
+ log.info(f'\nSearch took: {response["took"]}')
+ log.info(f'\nSearch shards: {response["_shards"]}')
+ log.info(f'\nSearch hits total: {response["hits"]["total"]}')
result = response["hits"]["hits"]
if len(result) != 0:
- print('\nSearch results:')
+ log.info("\nSearch results:")
for hit in response["hits"]["hits"]:
- print(hit["_id"], hit["_score"])
+                log.info(f'{hit["_id"]} {hit["_score"]}')
break
- else:
- print('\nSearch not ready, sleep 1s')
- time.sleep(1)
-
-def optimize_index(client, index_name):
- print(f"Starting force merge for index {index_name}")
- force_merge_endpoint = f'/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false'
- force_merge_task_id = client.transport.perform_request('POST', force_merge_endpoint)['task']
- SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+ log.info("\nSearch not ready, sleep 1s")
+ time.sleep(1)
+
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+WAITING_FOR_REFRESH_SEC = 30
+
+
+def optimize_index(client: OpenSearch, index_name: str):
+ log.info(f"Starting force merge for index {index_name}")
+ force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
+ force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"]
while True:
time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
task_status = client.tasks.get(task_id=force_merge_task_id)
- if task_status['completed']:
+ if task_status["completed"]:
break
- print(f"Completed force merge for index {index_name}")
+ log.info(f"Completed force merge for index {index_name}")
-def refresh_index(client, index_name):
- print(f"Starting refresh for index {index_name}")
- SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30
+def refresh_index(client: OpenSearch, index_name: str):
+ log.info(f"Starting refresh for index {index_name}")
while True:
try:
- print(f"Starting the Refresh Index..")
+ log.info("Starting the Refresh Index..")
client.indices.refresh(index=index_name)
break
except Exception as e:
- print(
- f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}")
- time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC)
+ log.info(
+                f"Refresh errored out. Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then retrying: {e}",
+ )
+            time.sleep(WAITING_FOR_REFRESH_SEC)
continue
- print(f"Completed refresh for index {index_name}")
-
+ log.info(f"Completed refresh for index {index_name}")
def main():
@@ -148,9 +158,9 @@ def main():
search(client, _INDEX_NAME)
delete_index(client, _INDEX_NAME)
except Exception as e:
- print(e)
+ log.info(e)
delete_index(client, _INDEX_NAME)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py
index 235cb595b..a148fa141 100644
--- a/vectordb_bench/backend/clients/chroma/chroma.py
+++ b/vectordb_bench/backend/clients/chroma/chroma.py
@@ -1,55 +1,55 @@
-import chromadb
-import logging
+import logging
from contextlib import contextmanager
from typing import Any
-from ..api import VectorDB, DBCaseConfig
+
+import chromadb
+
+from ..api import DBCaseConfig, VectorDB
log = logging.getLogger(__name__)
+
+
class ChromaClient(VectorDB):
- """Chroma client for VectorDB.
+ """Chroma client for VectorDB.
To set up Chroma in docker, see https://docs.trychroma.com/usage-guide
or the instructions in tests/test_chroma.py
To change to running in process, modify the HttpClient() in __init__() and init().
- """
+ """
def __init__(
- self,
- dim: int,
- db_config: dict,
- db_case_config: DBCaseConfig,
- drop_old: bool = False,
-
- **kwargs
- ):
-
+ self,
+ dim: int,
+ db_config: dict,
+ db_case_config: DBCaseConfig,
+ drop_old: bool = False,
+ **kwargs,
+ ):
self.db_config = db_config
self.case_config = db_case_config
- self.collection_name = 'example2'
+ self.collection_name = "example2"
- client = chromadb.HttpClient(host=self.db_config["host"],
- port=self.db_config["port"])
+ client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
assert client.heartbeat() is not None
if drop_old:
try:
- client.reset() # Reset the database
- except:
+ client.reset() # Reset the database
+ except Exception:
drop_old = False
log.info(f"Chroma client drop_old collection: {self.collection_name}")
@contextmanager
def init(self) -> None:
- """ create and destory connections to database.
+        """create and destroy connections to database.
Examples:
>>> with self.init():
>>> self.insert_embeddings()
"""
- #create connection
- self.client = chromadb.HttpClient(host=self.db_config["host"],
- port=self.db_config["port"])
-
- self.collection = self.client.get_or_create_collection('example2')
+ # create connection
+ self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"])
+
+ self.collection = self.client.get_or_create_collection("example2")
yield
self.client = None
self.collection = None
@@ -79,12 +79,12 @@ def insert_embeddings(
Returns:
(int, Exception): number of embeddings inserted and exception if any
"""
- ids=[str(i) for i in metadata]
- metadata = [{"id": int(i)} for i in metadata]
+ ids = [str(i) for i in metadata]
+ metadata = [{"id": int(i)} for i in metadata]
if len(embeddings) > 0:
self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata)
return len(embeddings), None
-
+
def search_embedding(
self,
query: list[float],
@@ -100,17 +100,19 @@ def search_embedding(
kwargs: other arguments
Returns:
- Dict {ids: list[list[int]],
- embedding: list[list[float]]
+ Dict {ids: list[list[int]],
+ embedding: list[list[float]]
distance: list[list[float]]}
"""
if filters:
# assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
id_value = filters.get("id")
- results = self.collection.query(query_embeddings=query, n_results=k,
- where={"id": {"$gt": id_value}})
- #return list of id's in results
- return [int(i) for i in results.get('ids')[0]]
+ results = self.collection.query(
+ query_embeddings=query,
+ n_results=k,
+ where={"id": {"$gt": id_value}},
+ )
+ # return list of id's in results
+ return [int(i) for i in results.get("ids")[0]]
results = self.collection.query(query_embeddings=query, n_results=k)
- return [int(i) for i in results.get('ids')[0]]
-
+ return [int(i) for i in results.get("ids")[0]]
diff --git a/vectordb_bench/backend/clients/chroma/config.py b/vectordb_bench/backend/clients/chroma/config.py
index 85c59c973..af34cf513 100644
--- a/vectordb_bench/backend/clients/chroma/config.py
+++ b/vectordb_bench/backend/clients/chroma/config.py
@@ -1,14 +1,16 @@
from pydantic import SecretStr
+
from ..api import DBConfig
+
class ChromaConfig(DBConfig):
password: SecretStr
host: SecretStr
- port: int
+ port: int
def to_dict(self) -> dict:
return {
"host": self.host.get_secret_value(),
"port": self.port,
"password": self.password.get_secret_value(),
- }
\ No newline at end of file
+ }
diff --git a/vectordb_bench/backend/clients/elastic_cloud/config.py b/vectordb_bench/backend/clients/elastic_cloud/config.py
index 204a4fc9d..d35ee68ec 100644
--- a/vectordb_bench/backend/clients/elastic_cloud/config.py
+++ b/vectordb_bench/backend/clients/elastic_cloud/config.py
@@ -1,7 +1,8 @@
from enum import Enum
-from pydantic import SecretStr, BaseModel
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
class ElasticCloudConfig(DBConfig, BaseModel):
@@ -32,12 +33,12 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig):
def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_norm"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "dot_product"
return "cosine"
def index_param(self) -> dict:
- params = {
+ return {
"type": "dense_vector",
"index": True,
"element_type": self.element_type.value,
@@ -48,7 +49,6 @@ def index_param(self) -> dict:
"ef_construction": self.efConstruction,
},
}
- return params
def search_param(self) -> dict:
return {
diff --git a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
index 64f27e490..a3183bcb7 100644
--- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
+++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py
@@ -1,17 +1,22 @@
import logging
import time
+from collections.abc import Iterable
from contextlib import contextmanager
-from typing import Iterable
-from ..api import VectorDB
-from .config import ElasticCloudIndexConfig
+
from elasticsearch.helpers import bulk
+from ..api import VectorDB
+from .config import ElasticCloudIndexConfig
for logger in ("elasticsearch", "elastic_transport"):
logging.getLogger(logger).setLevel(logging.WARNING)
log = logging.getLogger(__name__)
+
+SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
+
+
class ElasticCloud(VectorDB):
def __init__(
self,
@@ -46,14 +51,14 @@ def __init__(
def init(self) -> None:
"""connect to elasticsearch"""
from elasticsearch import Elasticsearch
+
self.client = Elasticsearch(**self.db_config, request_timeout=180)
yield
- # self.client.transport.close()
self.client = None
- del(self.client)
+ del self.client
- def _create_indice(self, client) -> None:
+ def _create_indice(self, client: any) -> None:
mappings = {
"_source": {"excludes": [self.vector_col_name]},
"properties": {
@@ -62,13 +67,13 @@ def _create_indice(self, client) -> None:
"dims": self.dim,
**self.case_config.index_param(),
},
- }
+ },
}
try:
client.indices.create(index=self.indice, mappings=mappings)
except Exception as e:
- log.warning(f"Failed to create indice: {self.indice} error: {str(e)}")
+ log.warning(f"Failed to create indice: {self.indice} error: {e!s}")
raise e from None
def insert_embeddings(
@@ -94,7 +99,7 @@ def insert_embeddings(
bulk_insert_res = bulk(self.client, insert_data)
return (bulk_insert_res[0], None)
except Exception as e:
- log.warning(f"Failed to insert data: {self.indice} error: {str(e)}")
+ log.warning(f"Failed to insert data: {self.indice} error: {e!s}")
return (0, e)
def search_embedding(
@@ -114,16 +119,12 @@ def search_embedding(
list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding.
"""
assert self.client is not None, "should self.init() first"
- # is_existed_res = self.client.indices.exists(index=self.indice)
- # assert is_existed_res.raw == True, "should self.init() first"
knn = {
"field": self.vector_col_name,
"k": k,
"num_candidates": self.case_config.num_candidates,
- "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}]
- if filters
- else [],
+ "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [],
"query_vector": query,
}
size = k
@@ -137,26 +138,26 @@ def search_embedding(
stored_fields="_none_",
filter_path=[f"hits.hits.fields.{self.id_col_name}"],
)
- res = [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
-
- return res
+ return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]]
except Exception as e:
- log.warning(f"Failed to search: {self.indice} error: {str(e)}")
+ log.warning(f"Failed to search: {self.indice} error: {e!s}")
raise e from None
def optimize(self):
"""optimize will be called between insertion and search in performance cases."""
assert self.client is not None, "should self.init() first"
self.client.indices.refresh(index=self.indice)
- force_merge_task_id = self.client.indices.forcemerge(index=self.indice, max_num_segments=1, wait_for_completion=False)['task']
+ force_merge_task_id = self.client.indices.forcemerge(
+ index=self.indice,
+ max_num_segments=1,
+ wait_for_completion=False,
+ )["task"]
log.info(f"Elasticsearch force merge task id: {force_merge_task_id}")
- SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30
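+        # Poll the task status until the force merge finishes before returning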
while True:
time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC)
task_status = self.client.tasks.get(task_id=force_merge_task_id)
- if task_status['completed']:
+ if task_status["completed"]:
return
def ready_to_load(self):
"""ready_to_load will be called before load in load cases."""
- pass
diff --git a/vectordb_bench/backend/clients/memorydb/cli.py b/vectordb_bench/backend/clients/memorydb/cli.py
index 50b5f89ba..ae00bfd17 100644
--- a/vectordb_bench/backend/clients/memorydb/cli.py
+++ b/vectordb_bench/backend/clients/memorydb/cli.py
@@ -14,9 +14,7 @@
class MemoryDBTypedDict(TypedDict):
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
password: Annotated[str, click.option("--password", type=str, help="Db password")]
port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
ssl: Annotated[
@@ -44,7 +42,10 @@ class MemoryDBTypedDict(TypedDict):
is_flag=True,
show_default=True,
default=False,
- help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)",
+ help=(
+                "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance."
+                " In production, MemoryDB only supports cluster mode (CME)"
+ ),
),
]
insert_batch_size: Annotated[
@@ -58,8 +59,7 @@ class MemoryDBTypedDict(TypedDict):
]
-class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2):
- ...
+class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): ...
@cli.command()
@@ -82,7 +82,7 @@ def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]):
M=parameters["m"],
ef_construction=parameters["ef_construction"],
ef_runtime=parameters["ef_runtime"],
- insert_batch_size=parameters["insert_batch_size"]
+ insert_batch_size=parameters["insert_batch_size"],
),
**parameters,
- )
\ No newline at end of file
+ )
diff --git a/vectordb_bench/backend/clients/memorydb/config.py b/vectordb_bench/backend/clients/memorydb/config.py
index 1284d3449..2c40ff546 100644
--- a/vectordb_bench/backend/clients/memorydb/config.py
+++ b/vectordb_bench/backend/clients/memorydb/config.py
@@ -29,7 +29,7 @@ class MemoryDBIndexConfig(BaseModel, DBCaseConfig):
def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "ip"
return "cosine"
@@ -51,4 +51,4 @@ def index_param(self) -> dict:
def search_param(self) -> dict:
return {
"ef_runtime": self.ef_runtime,
- }
\ No newline at end of file
+ }
diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py
index c5f80eb2a..d05e30be1 100644
--- a/vectordb_bench/backend/clients/memorydb/memorydb.py
+++ b/vectordb_bench/backend/clients/memorydb/memorydb.py
@@ -1,30 +1,33 @@
-import logging, time
+import logging
+import time
+from collections.abc import Generator
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple, Type
-from ..api import VectorDB, DBCaseConfig, IndexType
-from .config import MemoryDBIndexConfig
+from typing import Any
+
+import numpy as np
import redis
from redis import Redis
from redis.cluster import RedisCluster
-from redis.commands.search.field import TagField, VectorField, NumericField
-from redis.commands.search.indexDefinition import IndexDefinition, IndexType
+from redis.commands.search.field import NumericField, TagField, VectorField
+from redis.commands.search.indexDefinition import IndexDefinition
from redis.commands.search.query import Query
-import numpy as np
+from ..api import IndexType, VectorDB
+from .config import MemoryDBIndexConfig
log = logging.getLogger(__name__)
-INDEX_NAME = "index" # Vector Index Name
+INDEX_NAME = "index" # Vector Index Name
+
class MemoryDB(VectorDB):
def __init__(
- self,
- dim: int,
- db_config: dict,
- db_case_config: MemoryDBIndexConfig,
- drop_old: bool = False,
- **kwargs
- ):
-
+ self,
+ dim: int,
+ db_config: dict,
+ db_case_config: MemoryDBIndexConfig,
+ drop_old: bool = False,
+ **kwargs,
+ ):
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = INDEX_NAME
@@ -44,10 +47,10 @@ def __init__(
info = conn.ft(INDEX_NAME).info()
log.info(f"Index info: {info}")
except redis.exceptions.ResponseError as e:
- log.error(e)
+ log.warning(e)
drop_old = False
log.info(f"MemoryDB client drop_old collection: {self.collection_name}")
-
+
log.info("Executing FLUSHALL")
conn.flushall()
@@ -59,7 +62,7 @@ def __init__(
self.wait_until(self.wait_for_empty_db, 3, "", rc)
log.debug(f"Flushall done in the host: {host}")
rc.close()
-
+
self.make_index(dim, conn)
conn.close()
conn = None
@@ -69,7 +72,7 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):
# check to see if index exists
conn.ft(INDEX_NAME).info()
except Exception as e:
- log.warn(f"Error getting info for index '{INDEX_NAME}': {e}")
+ log.warning(f"Error getting info for index '{INDEX_NAME}': {e}")
index_param = self.case_config.index_param()
search_param = self.case_config.search_param()
vector_parameters = { # Vector Index Type: FLAT or HNSW
@@ -85,17 +88,19 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):
vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"]
schema = (
- TagField("id"),
- NumericField("metadata"),
- VectorField("vector", # Vector Field Name
- "HNSW", vector_parameters
+ TagField("id"),
+ NumericField("metadata"),
+ VectorField(
+ "vector", # Vector Field Name
+ "HNSW",
+ vector_parameters,
),
)
definition = IndexDefinition(index_type=IndexType.HASH)
rs = conn.ft(INDEX_NAME)
rs.create_index(schema, definition=definition)
-
+
def get_client(self, **kwargs):
"""
Gets either cluster connection or normal connection based on `cmd` flag.
@@ -143,7 +148,7 @@ def get_client(self, **kwargs):
@contextmanager
def init(self) -> Generator[None, None, None]:
- """ create and destory connections to database.
+        """create and destroy connections to database.
Examples:
>>> with self.init():
@@ -170,7 +175,7 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
"""Insert embeddings into the database.
Should call self.init() first.
"""
@@ -178,12 +183,15 @@ def insert_embeddings(
try:
with self.conn.pipeline(transaction=False) as pipe:
for i, embedding in enumerate(embeddings):
- embedding = np.array(embedding).astype(np.float32)
- pipe.hset(metadata[i], mapping = {
- "id": str(metadata[i]),
- "metadata": metadata[i],
- "vector": embedding.tobytes(),
- })
+ ndarr_emb = np.array(embedding).astype(np.float32)
+ pipe.hset(
+ metadata[i],
+ mapping={
+ "id": str(metadata[i]),
+ "metadata": metadata[i],
+ "vector": ndarr_emb.tobytes(),
+ },
+ )
# Execute the pipe so we don't keep too much in memory at once
if (i + 1) % self.insert_batch_size == 0:
pipe.execute()
@@ -192,9 +200,9 @@ def insert_embeddings(
result_len = i + 1
except Exception as e:
return 0, e
-
+
return result_len, None
-
+
def _post_insert(self):
"""Wait for indexing to finish"""
client = self.get_client(primary=True)
@@ -208,21 +216,17 @@ def _post_insert(self):
self.wait_until(*args)
log.debug(f"Background indexing completed in the host: {host_name}")
rc.close()
-
- def wait_until(
- self, condition, interval=5, message="Operation took too long", *args
- ):
+
+ def wait_until(self, condition: any, interval: int = 5, message: str = "Operation took too long", *args):
while not condition(*args):
time.sleep(interval)
-
+
def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis):
- return (
- client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
- )
-
+ return client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY"
+
def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis):
return client.execute_command("DBSIZE") == 0
-
+
def search_embedding(
self,
query: list[float],
@@ -230,13 +234,13 @@ def search_embedding(
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
- ) -> (list[int]):
+ ) -> list[int]:
assert self.conn is not None
-
+
query_vector = np.array(query).astype(np.float32).tobytes()
query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
query_params = {"vec": query_vector}
-
+
if filters:
# benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
# gets exact match for id, and range for metadata if they exist in filters
@@ -244,11 +248,19 @@ def search_embedding(
             # Strip the '>=' prefix from the metadata filter value, e.g. '>=10000'
metadata_value = filters.get("metadata")[2:]
if id_value and metadata_value:
- query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+ query_obj = (
+ Query(
+ f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]",
+ )
+ .return_fields("id")
+ .paging(0, k)
+ )
elif id_value:
- #gets exact match for id
+ # gets exact match for id
query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
- else: #metadata only case, greater than or equal to metadata value
- query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+ else: # metadata only case, greater than or equal to metadata value
+ query_obj = (
+ Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
+ )
res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
- return [int(doc["id"]) for doc in res.docs]
\ No newline at end of file
+ return [int(doc["id"]) for doc in res.docs]
diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py
index 885995def..51ea82eff 100644
--- a/vectordb_bench/backend/clients/milvus/cli.py
+++ b/vectordb_bench/backend/clients/milvus/cli.py
@@ -1,8 +1,9 @@
-from typing import Annotated, TypedDict, Unpack, Optional
+from typing import Annotated, TypedDict, Unpack
import click
from pydantic import SecretStr
+from vectordb_bench.backend.clients import DB
from vectordb_bench.cli.cli import (
CommonTypedDict,
HNSWFlavor3,
@@ -10,33 +11,33 @@
cli,
click_parameter_decorators_from_typed_dict,
run,
-
)
-from vectordb_bench.backend.clients import DB
DBTYPE = DB.Milvus
class MilvusTypedDict(TypedDict):
uri: Annotated[
- str, click.option("--uri", type=str, help="uri connection string", required=True)
+ str,
+ click.option("--uri", type=str, help="uri connection string", required=True),
]
user_name: Annotated[
- Optional[str], click.option("--user-name", type=str, help="Db username", required=False)
+ str | None,
+ click.option("--user-name", type=str, help="Db username", required=False),
]
password: Annotated[
- Optional[str], click.option("--password", type=str, help="Db password", required=False)
+ str | None,
+ click.option("--password", type=str, help="Db password", required=False),
]
-class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict):
- ...
+class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): ...
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict)
def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]):
- from .config import MilvusConfig, AutoIndexConfig
+ from .config import AutoIndexConfig, MilvusConfig
run(
db=DBTYPE,
@@ -54,7 +55,7 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]):
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict)
def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]):
- from .config import MilvusConfig, FLATConfig
+ from .config import FLATConfig, MilvusConfig
run(
db=DBTYPE,
@@ -69,14 +70,13 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]):
)
-class MilvusHNSWTypedDict(CommonTypedDict, MilvusTypedDict, HNSWFlavor3):
- ...
+class MilvusHNSWTypedDict(CommonTypedDict, MilvusTypedDict, HNSWFlavor3): ...
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusHNSWTypedDict)
def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]):
- from .config import MilvusConfig, HNSWConfig
+ from .config import HNSWConfig, MilvusConfig
run(
db=DBTYPE,
@@ -95,14 +95,13 @@ def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]):
)
-class MilvusIVFFlatTypedDict(CommonTypedDict, MilvusTypedDict, IVFFlatTypedDictN):
- ...
+class MilvusIVFFlatTypedDict(CommonTypedDict, MilvusTypedDict, IVFFlatTypedDictN): ...
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict)
def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]):
- from .config import MilvusConfig, IVFFlatConfig
+ from .config import IVFFlatConfig, MilvusConfig
run(
db=DBTYPE,
@@ -123,7 +122,7 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]):
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict)
def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]):
- from .config import MilvusConfig, IVFSQ8Config
+ from .config import IVFSQ8Config, MilvusConfig
run(
db=DBTYPE,
@@ -142,17 +141,13 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]):
class MilvusDISKANNTypedDict(CommonTypedDict, MilvusTypedDict):
- search_list: Annotated[
- str, click.option("--search-list",
- type=int,
- required=True)
- ]
+ search_list: Annotated[str, click.option("--search-list", type=int, required=True)]
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusDISKANNTypedDict)
def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]):
- from .config import MilvusConfig, DISKANNConfig
+ from .config import DISKANNConfig, MilvusConfig
run(
db=DBTYPE,
@@ -171,21 +166,16 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]):
class MilvusGPUIVFTypedDict(CommonTypedDict, MilvusTypedDict, MilvusIVFFlatTypedDict):
cache_dataset_on_device: Annotated[
- str, click.option("--cache-dataset-on-device",
- type=str,
- required=True)
- ]
- refine_ratio: Annotated[
- str, click.option("--refine-ratio",
- type=float,
- required=True)
+ str,
+ click.option("--cache-dataset-on-device", type=str, required=True),
]
+ refine_ratio: Annotated[str, click.option("--refine-ratio", type=float, required=True)]
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUIVFTypedDict)
def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
- from .config import MilvusConfig, GPUIVFFlatConfig
+ from .config import GPUIVFFlatConfig, MilvusConfig
run(
db=DBTYPE,
@@ -205,23 +195,20 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]):
)
-class MilvusGPUIVFPQTypedDict(CommonTypedDict, MilvusTypedDict, MilvusIVFFlatTypedDict, MilvusGPUIVFTypedDict):
- m: Annotated[
- str, click.option("--m",
- type=int, help="hnsw m",
- required=True)
- ]
- nbits: Annotated[
- str, click.option("--nbits",
- type=int,
- required=True)
- ]
+class MilvusGPUIVFPQTypedDict(
+ CommonTypedDict,
+ MilvusTypedDict,
+ MilvusIVFFlatTypedDict,
+ MilvusGPUIVFTypedDict,
+):
+ m: Annotated[str, click.option("--m", type=int, help="hnsw m", required=True)]
+ nbits: Annotated[str, click.option("--nbits", type=int, required=True)]
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUIVFPQTypedDict)
def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]):
- from .config import MilvusConfig, GPUIVFPQConfig
+ from .config import GPUIVFPQConfig, MilvusConfig
run(
db=DBTYPE,
@@ -245,51 +232,22 @@ def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]):
class MilvusGPUCAGRATypedDict(CommonTypedDict, MilvusTypedDict, MilvusGPUIVFTypedDict):
intermediate_graph_degree: Annotated[
- str, click.option("--intermediate-graph-degree",
- type=int,
- required=True)
- ]
- graph_degree: Annotated[
- str, click.option("--graph-degree",
- type=int,
- required=True)
- ]
- build_algo: Annotated[
- str, click.option("--build_algo",
- type=str,
- required=True)
- ]
- team_size: Annotated[
- str, click.option("--team-size",
- type=int,
- required=True)
- ]
- search_width: Annotated[
- str, click.option("--search-width",
- type=int,
- required=True)
- ]
- itopk_size: Annotated[
- str, click.option("--itopk-size",
- type=int,
- required=True)
- ]
- min_iterations: Annotated[
- str, click.option("--min-iterations",
- type=int,
- required=True)
- ]
- max_iterations: Annotated[
- str, click.option("--max-iterations",
- type=int,
- required=True)
+ str,
+ click.option("--intermediate-graph-degree", type=int, required=True),
]
+ graph_degree: Annotated[str, click.option("--graph-degree", type=int, required=True)]
+ build_algo: Annotated[str, click.option("--build_algo", type=str, required=True)]
+ team_size: Annotated[str, click.option("--team-size", type=int, required=True)]
+ search_width: Annotated[str, click.option("--search-width", type=int, required=True)]
+ itopk_size: Annotated[str, click.option("--itopk-size", type=int, required=True)]
+ min_iterations: Annotated[str, click.option("--min-iterations", type=int, required=True)]
+ max_iterations: Annotated[str, click.option("--max-iterations", type=int, required=True)]
@cli.command()
@click_parameter_decorators_from_typed_dict(MilvusGPUCAGRATypedDict)
def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]):
- from .config import MilvusConfig, GPUCAGRAConfig
+ from .config import GPUCAGRAConfig, MilvusConfig
run(
db=DBTYPE,
diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py
index 059ef0461..7d0df803a 100644
--- a/vectordb_bench/backend/clients/milvus/config.py
+++ b/vectordb_bench/backend/clients/milvus/config.py
@@ -1,5 +1,6 @@
from pydantic import BaseModel, SecretStr, validator
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
class MilvusConfig(DBConfig):
@@ -15,10 +16,14 @@ def to_dict(self) -> dict:
}
@validator("*")
- def not_empty_field(cls, v, field):
- if field.name in cls.common_short_configs() or field.name in cls.common_long_configs() or field.name in ["user", "password"]:
+ def not_empty_field(cls, v: any, field: any):
+ if (
+ field.name in cls.common_short_configs()
+ or field.name in cls.common_long_configs()
+ or field.name in ["user", "password"]
+ ):
return v
- if isinstance(v, (str, SecretStr)) and len(v) == 0:
+ if isinstance(v, str | SecretStr) and len(v) == 0:
raise ValueError("Empty string!")
return v
@@ -28,10 +33,14 @@ class MilvusIndexConfig(BaseModel):
index: IndexType
metric_type: MetricType | None = None
-
+
@property
def is_gpu_index(self) -> bool:
- return self.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ]
+ return self.index in [
+ IndexType.GPU_CAGRA,
+ IndexType.GPU_IVF_FLAT,
+ IndexType.GPU_IVF_PQ,
+ ]
def parse_metric(self) -> str:
if not self.metric_type:
@@ -113,7 +122,8 @@ def search_param(self) -> dict:
"metric_type": self.parse_metric(),
"params": {"nprobe": self.nprobe},
}
-
+
+
class IVFSQ8Config(MilvusIndexConfig, DBCaseConfig):
nlist: int
nprobe: int | None = None
@@ -210,7 +220,7 @@ class GPUCAGRAConfig(MilvusIndexConfig, DBCaseConfig):
search_width: int = 4
min_iterations: int = 0
max_iterations: int = 0
- build_algo: str = "IVF_PQ" # IVF_PQ; NN_DESCENT;
+ build_algo: str = "IVF_PQ" # IVF_PQ; NN_DESCENT;
cache_dataset_on_device: str
refine_ratio: float | None = None
index: IndexType = IndexType.GPU_CAGRA
diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py
index 251dee8ad..45fe7269b 100644
--- a/vectordb_bench/backend/clients/milvus/milvus.py
+++ b/vectordb_bench/backend/clients/milvus/milvus.py
@@ -2,19 +2,18 @@
import logging
import time
+from collections.abc import Iterable
from contextlib import contextmanager
-from typing import Iterable
-from pymilvus import Collection, utility
-from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException
+from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, utility
from ..api import VectorDB
from .config import MilvusIndexConfig
-
log = logging.getLogger(__name__)
-MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 *1024
+MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 * 1024
+
class Milvus(VectorDB):
def __init__(
@@ -32,7 +31,7 @@ def __init__(
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = collection_name
- self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim *4))
+ self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim * 4))
self._primary_field = "pk"
self._scalar_field = "id"
@@ -40,6 +39,7 @@ def __init__(
self._index_name = "vector_idx"
from pymilvus import connections
+
connections.connect(**self.db_config, timeout=30)
if drop_old and utility.has_collection(self.collection_name):
log.info(f"{self.name} client drop_old collection: {self.collection_name}")
@@ -49,7 +49,7 @@ def __init__(
fields = [
FieldSchema(self._primary_field, DataType.INT64, is_primary=True),
FieldSchema(self._scalar_field, DataType.INT64),
- FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim)
+ FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim),
]
log.info(f"{self.name} create collection: {self.collection_name}")
@@ -79,6 +79,7 @@ def init(self) -> None:
>>> self.search_embedding()
"""
from pymilvus import connections
+
self.col: Collection | None = None
connections.connect(**self.db_config, timeout=60)
@@ -108,6 +109,7 @@ def _post_insert(self):
)
utility.wait_for_index_building_complete(self.collection_name)
+
def wait_index():
while True:
progress = utility.index_building_progress(self.collection_name)
@@ -120,18 +122,17 @@ def wait_index():
# Skip compaction if use GPU indexType
if self.case_config.is_gpu_index:
log.debug("skip compaction for gpu index type.")
- else :
+ else:
try:
self.col.compact()
self.col.wait_for_compaction_completed()
except Exception as e:
log.warning(f"{self.name} compact error: {e}")
- if hasattr(e, 'code'):
- if e.code().name == 'PERMISSION_DENIED':
+ if hasattr(e, "code"):
+ if e.code().name == "PERMISSION_DENIED":
log.warning("Skip compact due to permission denied.")
- pass
else:
- raise e
+ raise e from e
wait_index()
except Exception as e:
log.warning(f"{self.name} optimize error: {e}")
@@ -156,7 +157,6 @@ def _pre_load(self, coll: Collection):
log.warning(f"{self.name} pre load error: {e}")
raise e from None
-
def optimize(self):
assert self.col, "Please call self.init() before"
self._optimize()
@@ -164,7 +164,7 @@ def optimize(self):
def need_normalize_cosine(self) -> bool:
"""Wheather this database need to normalize dataset to support COSINE"""
if self.case_config.is_gpu_index:
- log.info(f"current gpu_index only supports IP / L2, cosine dataset need normalize.")
+ log.info("current gpu_index only supports IP / L2, cosine dataset need normalize.")
return True
return False
@@ -184,9 +184,9 @@ def insert_embeddings(
for batch_start_offset in range(0, len(embeddings), self.batch_size):
batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
insert_data = [
- metadata[batch_start_offset : batch_end_offset],
- metadata[batch_start_offset : batch_end_offset],
- embeddings[batch_start_offset : batch_end_offset],
+ metadata[batch_start_offset:batch_end_offset],
+ metadata[batch_start_offset:batch_end_offset],
+ embeddings[batch_start_offset:batch_end_offset],
]
res = self.col.insert(insert_data)
insert_count += len(res.primary_keys)
@@ -217,5 +217,4 @@ def search_embedding(
)
# Organize results.
- ret = [result.id for result in res[0]]
- return ret
+ return [result.id for result in res[0]]
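
The batch-size arithmetic touched above is easy to miss in the whitespace-only hunk: insert requests are capped at roughly 1.5 MiB and each float32 dimension costs 4 bytes, so the batch size is the request budget divided by the per-vector payload. A small sketch of just that calculation (the 768-dim figure is only an example):

MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 * 1024  # ~1.5 MiB budget per insert request


def batch_size_for(dim: int) -> int:
    # 4 bytes per float32 component of each vector.
    return int(MILVUS_LOAD_REQS_SIZE / (dim * 4))


def iter_batches(embeddings: list[list[float]], dim: int):
    size = batch_size_for(dim)
    for start in range(0, len(embeddings), size):
        yield embeddings[start : start + size]


assert batch_size_for(768) == 512  # 1.5 MiB / (768 * 4 B) = 512 vectors per request
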
diff --git a/vectordb_bench/backend/clients/pgdiskann/cli.py b/vectordb_bench/backend/clients/pgdiskann/cli.py
index 18a9ecbd5..19f47988f 100644
--- a/vectordb_bench/backend/clients/pgdiskann/cli.py
+++ b/vectordb_bench/backend/clients/pgdiskann/cli.py
@@ -1,57 +1,63 @@
-import click
import os
+from typing import Annotated, Unpack
+
+import click
from pydantic import SecretStr
+from vectordb_bench.backend.clients import DB
+
from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
-from typing import Annotated, Optional, Unpack
-from vectordb_bench.backend.clients import DB
class PgDiskAnnTypedDict(CommonTypedDict):
user_name: Annotated[
- str, click.option("--user-name", type=str, help="Db username", required=True)
+ str,
+ click.option("--user-name", type=str, help="Db username", required=True),
]
password: Annotated[
str,
- click.option("--password",
- type=str,
- help="Postgres database password",
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
- show_default="$POSTGRES_PASSWORD",
- ),
+ click.option(
+ "--password",
+ type=str,
+ help="Postgres database password",
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+ show_default="$POSTGRES_PASSWORD",
+ ),
]
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
- db_name: Annotated[
- str, click.option("--db-name", type=str, help="Db name", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
max_neighbors: Annotated[
int,
click.option(
- "--max-neighbors", type=int, help="PgDiskAnn max neighbors",
+ "--max-neighbors",
+ type=int,
+ help="PgDiskAnn max neighbors",
),
]
l_value_ib: Annotated[
int,
click.option(
- "--l-value-ib", type=int, help="PgDiskAnn l_value_ib",
+ "--l-value-ib",
+ type=int,
+ help="PgDiskAnn l_value_ib",
),
]
l_value_is: Annotated[
float,
click.option(
- "--l-value-is", type=float, help="PgDiskAnn l_value_is",
+ "--l-value-is",
+ type=float,
+ help="PgDiskAnn l_value_is",
),
]
maintenance_work_mem: Annotated[
- Optional[str],
+ str | None,
click.option(
"--maintenance-work-mem",
type=str,
@@ -63,7 +69,7 @@ class PgDiskAnnTypedDict(CommonTypedDict):
),
]
max_parallel_workers: Annotated[
- Optional[int],
+ int | None,
click.option(
"--max-parallel-workers",
type=int,
@@ -72,6 +78,7 @@ class PgDiskAnnTypedDict(CommonTypedDict):
),
]
+
@cli.command()
@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict)
def PgDiskAnn(
@@ -96,4 +103,4 @@ def PgDiskAnn(
maintenance_work_mem=parameters["maintenance_work_mem"],
),
**parameters,
- )
\ No newline at end of file
+ )
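
The reformatted --password option above is a pattern worth calling out on its own: the default is read lazily from the environment, and show_default prints the variable name so --help never echoes the secret. A trimmed, runnable version (command name and help text are illustrative):

import os

import click


@click.command()
@click.option(
    "--password",
    type=str,
    help="Postgres database password",
    default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
    show_default="$POSTGRES_PASSWORD",
)
def connect(password: str) -> None:
    click.echo("password set" if password else "no password given")
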
diff --git a/vectordb_bench/backend/clients/pgdiskann/config.py b/vectordb_bench/backend/clients/pgdiskann/config.py
index 970720afa..ed478acc2 100644
--- a/vectordb_bench/backend/clients/pgdiskann/config.py
+++ b/vectordb_bench/backend/clients/pgdiskann/config.py
@@ -1,7 +1,9 @@
from abc import abstractmethod
-from typing import Any, Mapping, Optional, Sequence, TypedDict
+from collections.abc import Mapping, Sequence
+from typing import Any, LiteralString, TypedDict
+
from pydantic import BaseModel, SecretStr
-from typing_extensions import LiteralString
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +11,7 @@
class PgDiskANNConfigDict(TypedDict):
"""These keys will be directly used as kwargs in psycopg connection string,
- so the names must match exactly psycopg API"""
+ so the names must match exactly psycopg API"""
user: str
password: str
@@ -41,44 +43,43 @@ class PgDiskANNIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
create_index_before_load: bool = False
create_index_after_load: bool = True
- maintenance_work_mem: Optional[str]
- max_parallel_workers: Optional[int]
+ maintenance_work_mem: str | None
+ max_parallel_workers: int | None
def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "vector_ip_ops"
return "vector_cosine_ops"
def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.L2:
return "<->"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "<#>"
return "<=>"
def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_distance"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "max_inner_product"
return "cosine_distance"
-
+
@abstractmethod
- def index_param(self) -> dict:
- ...
+ def index_param(self) -> dict: ...
@abstractmethod
- def search_param(self) -> dict:
- ...
+ def search_param(self) -> dict: ...
@abstractmethod
- def session_param(self) -> dict:
- ...
+ def session_param(self) -> dict: ...
@staticmethod
- def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
+ def _optionally_build_with_options(
+ with_options: Mapping[str, Any],
+ ) -> Sequence[dict[str, Any]]:
"""Walk through mappings, creating a List of {key1 = value} pairs. That will be used to build a where clause"""
options = []
for option_name, value in with_options.items():
@@ -87,35 +88,36 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[
{
"option_name": option_name,
"val": str(value),
- }
+ },
)
return options
@staticmethod
def _optionally_build_set_options(
- set_mapping: Mapping[str, Any]
+ set_mapping: Mapping[str, Any],
) -> Sequence[dict[str, Any]]:
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
session_options = []
for setting_name, value in set_mapping.items():
if value:
session_options.append(
- {"parameter": {
+ {
+ "parameter": {
"setting_name": setting_name,
"val": str(value),
},
- }
+ },
)
return session_options
-
+
class PgDiskANNImplConfig(PgDiskANNIndexConfig):
index: IndexType = IndexType.DISKANN
max_neighbors: int | None
l_value_ib: int | None
l_value_is: float | None
- maintenance_work_mem: Optional[str] = None
- max_parallel_workers: Optional[int] = None
+ maintenance_work_mem: str | None = None
+ max_parallel_workers: int | None = None
def index_param(self) -> dict:
return {
@@ -128,18 +130,19 @@ def index_param(self) -> dict:
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
}
-
+
def search_param(self) -> dict:
return {
"metric": self.parse_metric(),
"metric_fun_op": self.parse_metric_fun_op(),
}
-
+
def session_param(self) -> dict:
return {
"diskann.l_value_is": self.l_value_is,
}
-
+
+
_pgdiskann_case_config = {
IndexType.DISKANN: PgDiskANNImplConfig,
}
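
The session-option helper reshaped above builds a list of {"parameter": {...}} entries that a later step renders into SET statements, skipping falsy values. A standalone sketch of that behaviour, using the diskann.l_value_is key from the config above:

from collections.abc import Mapping, Sequence
from typing import Any


def build_set_options(set_mapping: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
    # Falsy values (None, 0, "") are dropped, matching _optionally_build_set_options.
    return [
        {"parameter": {"setting_name": name, "val": str(value)}}
        for name, value in set_mapping.items()
        if value
    ]


options = build_set_options({"diskann.l_value_is": 100.0, "unset_parameter": None})
assert options == [{"parameter": {"setting_name": "diskann.l_value_is", "val": "100.0"}}]
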
diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
index c363490f7..c21972902 100644
--- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
+++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py
@@ -1,9 +1,9 @@
"""Wrapper around the pg_diskann vector database over VectorDB"""
import logging
-import pprint
+from collections.abc import Generator
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple
+from typing import Any
import numpy as np
import psycopg
@@ -44,20 +44,21 @@ def __init__(
self._primary_field = "id"
self._vector_field = "embedding"
- self.conn, self.cursor = self._create_connection(**self.db_config)
+ self.conn, self.cursor = self._create_connection(**self.db_config)
log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
if not any(
(
self.case_config.create_index_before_load,
self.case_config.create_index_after_load,
- )
+ ),
):
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
- log.error(err)
- raise RuntimeError(
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+ msg = (
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
)
+ log.error(msg)
+ raise RuntimeError(msg)
if drop_old:
self._drop_index()
@@ -72,7 +73,7 @@ def __init__(
self.conn = None
@staticmethod
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
conn = psycopg.connect(**kwargs)
conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE")
conn.commit()
@@ -101,25 +102,25 @@ def init(self) -> Generator[None, None, None]:
log.debug(command.as_string(self.cursor))
self.cursor.execute(command)
self.conn.commit()
-
+
self._filtered_search = sql.Composed(
[
sql.SQL(
- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
- ).format(table_name=sql.Identifier(self.table_name)),
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
+ ).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
- sql.Identifier(self.table_name)
+ sql.Identifier(self.table_name),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
try:
@@ -137,8 +138,8 @@ def _drop_table(self):
self.cursor.execute(
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
)
self.conn.commit()
@@ -160,7 +161,7 @@ def _drop_index(self):
log.info(f"{self.name} client drop index : {self._index_name}")
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
- index_name=sql.Identifier(self._index_name)
+ index_name=sql.Identifier(self._index_name),
)
log.debug(drop_index_sql.as_string(self.cursor))
self.cursor.execute(drop_index_sql)
@@ -175,64 +176,53 @@ def _set_parallel_index_build_param(self):
if index_param["maintenance_work_mem"] is not None:
self.cursor.execute(
sql.SQL("SET maintenance_work_mem TO {};").format(
- index_param["maintenance_work_mem"]
- )
+ index_param["maintenance_work_mem"],
+ ),
)
self.cursor.execute(
sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
sql.Identifier(self.db_config["user"]),
index_param["maintenance_work_mem"],
- )
+ ),
)
self.conn.commit()
if index_param["max_parallel_workers"] is not None:
self.cursor.execute(
sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
- index_param["max_parallel_workers"]
- )
+ index_param["max_parallel_workers"],
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
- ).format(
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
sql.Identifier(self.db_config["user"]),
index_param["max_parallel_workers"],
- )
+ ),
)
self.cursor.execute(
sql.SQL("SET max_parallel_workers TO '{}';").format(
- index_param["max_parallel_workers"]
- )
+ index_param["max_parallel_workers"],
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER USER {} SET max_parallel_workers TO '{}';"
- ).format(
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
sql.Identifier(self.db_config["user"]),
index_param["max_parallel_workers"],
- )
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER TABLE {} SET (parallel_workers = {});"
- ).format(
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
sql.Identifier(self.table_name),
index_param["max_parallel_workers"],
- )
+ ),
)
self.conn.commit()
- results = self.cursor.execute(
- sql.SQL("SHOW max_parallel_maintenance_workers;")
- ).fetchall()
- results.extend(
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
- )
- results.extend(
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
- )
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
log.info(f"{self.name} parallel index creation parameters: {results}")
+
def _create_index(self):
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
@@ -248,28 +238,23 @@ def _create_index(self):
sql.SQL("{option_name} = {val}").format(
option_name=sql.Identifier(option_name),
val=sql.Identifier(str(option_val)),
- )
+ ),
)
-
- if any(options):
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
- else:
- with_clause = sql.Composed(())
+
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
index_create_sql = sql.SQL(
"""
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
- """
+ """,
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"].lower()),
embedding_metric=sql.Identifier(index_param["metric"]),
)
- index_create_sql_with_with_clause = (
- index_create_sql + with_clause
- ).join(" ")
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
self.cursor.execute(index_create_sql_with_with_clause)
self.conn.commit()
@@ -283,14 +268,12 @@ def _create_table(self, dim: int):
self.cursor.execute(
sql.SQL(
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
)
self.conn.commit()
except Exception as e:
- log.warning(
- f"Failed to create pgdiskann table: {self.table_name} error: {e}"
- )
+ log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}")
raise e from None
def insert_embeddings(
@@ -298,7 +281,7 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
@@ -308,8 +291,8 @@ def insert_embeddings(
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
@@ -321,9 +304,7 @@ def insert_embeddings(
return len(metadata), None
except Exception as e:
- log.warning(
- f"Failed to insert data into table ({self.table_name}), error: {e}"
- )
+ log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}")
return 0, e
def search_embedding(
@@ -340,11 +321,12 @@ def search_embedding(
if filters:
gt = filters.get("id")
result = self.cursor.execute(
- self._filtered_search, (gt, q, k), prepare=True, binary=True
- )
+ self._filtered_search,
+ (gt, q, k),
+ prepare=True,
+ binary=True,
+ )
else:
- result = self.cursor.execute(
- self._unfiltered_search, (q, k), prepare=True, binary=True
- )
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
return [int(i[0]) for i in result.fetchall()]
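
The filtered and unfiltered statements reindented above are assembled from psycopg's sql helpers: the table name goes through sql.Identifier, the distance operator (<->, <#> or <=>) is spliced in as sql.SQL, and the id bound, query vector, and limit stay as %s placeholders. A reduced sketch of just the composition step; the table name is a placeholder and no connection is opened here.

from psycopg import sql

metric_fun_op = "<=>"  # what parse_metric_fun_op() returns for cosine
table_name = "example_table"  # placeholder

filtered_search = sql.Composed(
    [
        sql.SQL("SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ").format(
            table_name=sql.Identifier(table_name),
        ),
        sql.SQL(metric_fun_op),
        sql.SQL(" %s::vector LIMIT %s::int"),
    ],
)
# Later executed as: cursor.execute(filtered_search, (gt, query_vector, k))
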
diff --git a/vectordb_bench/backend/clients/pgvecto_rs/cli.py b/vectordb_bench/backend/clients/pgvecto_rs/cli.py
index 10dbff556..24d2cf800 100644
--- a/vectordb_bench/backend/clients/pgvecto_rs/cli.py
+++ b/vectordb_bench/backend/clients/pgvecto_rs/cli.py
@@ -1,9 +1,11 @@
-from typing import Annotated, Optional, Unpack
+import os
+from typing import Annotated, Unpack
import click
-import os
from pydantic import SecretStr
+from vectordb_bench.backend.clients import DB
+
from ....cli.cli import (
CommonTypedDict,
HNSWFlavor1,
@@ -12,12 +14,12 @@
click_parameter_decorators_from_typed_dict,
run,
)
-from vectordb_bench.backend.clients import DB
class PgVectoRSTypedDict(CommonTypedDict):
user_name: Annotated[
- str, click.option("--user-name", type=str, help="Db username", required=True)
+ str,
+ click.option("--user-name", type=str, help="Db username", required=True),
]
password: Annotated[
str,
@@ -30,14 +32,10 @@ class PgVectoRSTypedDict(CommonTypedDict):
),
]
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
- db_name: Annotated[
- str, click.option("--db-name", type=str, help="Db name", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
max_parallel_workers: Annotated[
- Optional[int],
+ int | None,
click.option(
"--max-parallel-workers",
type=int,
diff --git a/vectordb_bench/backend/clients/pgvecto_rs/config.py b/vectordb_bench/backend/clients/pgvecto_rs/config.py
index c671a236c..fbb7c5d81 100644
--- a/vectordb_bench/backend/clients/pgvecto_rs/config.py
+++ b/vectordb_bench/backend/clients/pgvecto_rs/config.py
@@ -1,11 +1,11 @@
from abc import abstractmethod
from typing import TypedDict
+from pgvecto_rs.types import Flat, Hnsw, IndexOption, Ivf, Quantization
+from pgvecto_rs.types.index import QuantizationRatio, QuantizationType
from pydantic import BaseModel, SecretStr
-from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization
-from pgvecto_rs.types.index import QuantizationType, QuantizationRatio
-from ..api import DBConfig, DBCaseConfig, IndexType, MetricType
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -52,14 +52,14 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig):
def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "vector_l2_ops"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "vector_dot_ops"
return "vector_cos_ops"
def parse_metric_fun_op(self) -> str:
if self.metric_type == MetricType.L2:
return "<->"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "<#>"
return "<=>"
@@ -85,9 +85,7 @@ def index_param(self) -> dict[str, str]:
if self.quantization_type is None:
quantization = None
else:
- quantization = Quantization(
- typ=self.quantization_type, ratio=self.quantization_ratio
- )
+ quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
option = IndexOption(
index=Hnsw(
@@ -115,9 +113,7 @@ def index_param(self) -> dict[str, str]:
if self.quantization_type is None:
quantization = None
else:
- quantization = Quantization(
- typ=self.quantization_type, ratio=self.quantization_ratio
- )
+ quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
option = IndexOption(
index=Ivf(nlist=self.lists, quantization=quantization),
@@ -139,9 +135,7 @@ def index_param(self) -> dict[str, str]:
if self.quantization_type is None:
quantization = None
else:
- quantization = Quantization(
- typ=self.quantization_type, ratio=self.quantization_ratio
- )
+ quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio)
option = IndexOption(
index=Flat(
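
The elif-to-if rewrites in this file follow ruff's preference for dropping elif after a return (RET505): once a branch returns, the remaining conditions can be plain if statements. A standalone copy of the pgvecto.rs operator-class mapping, with a local MetricType enum standing in for the project's:

from enum import Enum


class MetricType(str, Enum):
    L2 = "L2"
    IP = "IP"
    COSINE = "COSINE"


def parse_metric(metric_type: MetricType) -> str:
    if metric_type == MetricType.L2:
        return "vector_l2_ops"
    if metric_type == MetricType.IP:
        return "vector_dot_ops"
    return "vector_cos_ops"


assert parse_metric(MetricType.IP) == "vector_dot_ops"
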
diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
index bc042cc57..fc4f17807 100644
--- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
+++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py
@@ -1,14 +1,14 @@
"""Wrapper around the Pgvecto.rs vector database over VectorDB"""
import logging
-import pprint
+from collections.abc import Generator
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple
+from typing import Any
import numpy as np
import psycopg
-from psycopg import Connection, Cursor, sql
from pgvecto_rs.psycopg import register_vector
+from psycopg import Connection, Cursor, sql
from ..api import VectorDB
from .config import PgVectoRSConfig, PgVectoRSIndexConfig
@@ -33,7 +33,6 @@ def __init__(
drop_old: bool = False,
**kwargs,
):
-
self.name = "PgVectorRS"
self.db_config = db_config
self.case_config = db_case_config
@@ -52,13 +51,14 @@ def __init__(
(
self.case_config.create_index_before_load,
self.case_config.create_index_after_load,
- )
+ ),
):
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
- log.error(err)
- raise RuntimeError(
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+ msg = (
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
)
+ log.error(msg)
+ raise RuntimeError(msg)
if drop_old:
log.info(f"Pgvecto.rs client drop table : {self.table_name}")
@@ -74,7 +74,7 @@ def __init__(
self.conn = None
@staticmethod
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
conn = psycopg.connect(**kwargs)
# create vector extension
@@ -116,21 +116,21 @@ def init(self) -> Generator[None, None, None]:
self._filtered_search = sql.Composed(
[
sql.SQL(
- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding "
+ "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ",
).format(table_name=sql.Identifier(self.table_name)),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
self._unfiltered_search = sql.Composed(
[
- sql.SQL(
- "SELECT id FROM public.{table_name} ORDER BY embedding "
- ).format(table_name=sql.Identifier(self.table_name)),
+ sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format(
+ table_name=sql.Identifier(self.table_name),
+ ),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
try:
@@ -148,8 +148,8 @@ def _drop_table(self):
self.cursor.execute(
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
)
self.conn.commit()
@@ -171,7 +171,7 @@ def _drop_index(self):
log.info(f"{self.name} client drop index : {self._index_name}")
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
- index_name=sql.Identifier(self._index_name)
+ index_name=sql.Identifier(self._index_name),
)
log.debug(drop_index_sql.as_string(self.cursor))
self.cursor.execute(drop_index_sql)
@@ -186,9 +186,9 @@ def _create_index(self):
index_create_sql = sql.SQL(
"""
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING vectors (embedding {embedding_metric}) WITH (options = {index_options})
- """
+ """,
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
@@ -202,7 +202,7 @@ def _create_index(self):
except Exception as e:
log.warning(
f"Failed to create pgvecto.rs index {self._index_name} \
- at table {self.table_name} error: {e}"
+ at table {self.table_name} error: {e}",
)
raise e from None
@@ -214,7 +214,7 @@ def _create_table(self, dim: int):
"""
CREATE TABLE IF NOT EXISTS public.{table_name}
(id BIGINT PRIMARY KEY, embedding vector({dim}))
- """
+ """,
).format(
table_name=sql.Identifier(self.table_name),
dim=dim,
@@ -224,9 +224,7 @@ def _create_table(self, dim: int):
self.cursor.execute(table_create_sql)
self.conn.commit()
except Exception as e:
- log.warning(
- f"Failed to create pgvecto.rs table: {self.table_name} error: {e}"
- )
+ log.warning(f"Failed to create pgvecto.rs table: {self.table_name} error: {e}")
raise e from None
def insert_embeddings(
@@ -234,7 +232,7 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
@@ -247,8 +245,8 @@ def insert_embeddings(
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
@@ -261,7 +259,7 @@ def insert_embeddings(
return len(metadata), None
except Exception as e:
log.warning(
- f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}"
+ f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}",
)
return 0, e
@@ -281,12 +279,13 @@ def search_embedding(
log.debug(self._filtered_search.as_string(self.cursor))
gt = filters.get("id")
result = self.cursor.execute(
- self._filtered_search, (gt, q, k), prepare=True, binary=True
+ self._filtered_search,
+ (gt, q, k),
+ prepare=True,
+ binary=True,
)
else:
log.debug(self._unfiltered_search.as_string(self.cursor))
- result = self.cursor.execute(
- self._unfiltered_search, (q, k), prepare=True, binary=True
- )
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
return [int(i[0]) for i in result.fetchall()]
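
The insert path reformatted above streams rows over a binary COPY: psycopg 3's cursor.copy() context manager takes the composed COPY statement, set_types() pins the wire types, and write_row() sends one (id, embedding) pair at a time. A condensed sketch of that flow; the function below is illustrative only, and the "vector" wire type assumes the pgvecto.rs/pgvector adapter has already been registered on the connection, as the client does in _create_connection.

import numpy as np
import psycopg
from psycopg import sql


def bulk_insert(conn: psycopg.Connection, table: str, ids: list[int], embeddings: list[list[float]]) -> None:
    embeddings_arr = np.asarray(embeddings)
    with conn.cursor() as cur, cur.copy(
        sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
            table_name=sql.Identifier(table),
        ),
    ) as copy:
        copy.set_types(["bigint", "vector"])
        for i, row_id in enumerate(ids):
            copy.write_row((row_id, embeddings_arr[i]))
    conn.commit()
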
diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py
index ef8914be0..55a462055 100644
--- a/vectordb_bench/backend/clients/pgvector/cli.py
+++ b/vectordb_bench/backend/clients/pgvector/cli.py
@@ -1,9 +1,10 @@
-from typing import Annotated, Optional, TypedDict, Unpack
+import os
+from typing import Annotated, Unpack
import click
-import os
from pydantic import SecretStr
+from vectordb_bench.backend.clients import DB
from vectordb_bench.backend.clients.api import MetricType
from ....cli.cli import (
@@ -15,49 +16,48 @@
get_custom_case_config,
run,
)
-from vectordb_bench.backend.clients import DB
-def set_default_quantized_fetch_limit(ctx, param, value):
+# ruff: noqa
+def set_default_quantized_fetch_limit(ctx: any, param: any, value: any):
if ctx.params.get("reranking") and value is None:
# ef_search is the default value for quantized_fetch_limit as it's bound by ef_search.
# 100 is default value for quantized_fetch_limit for IVFFlat.
- default_value = ctx.params["ef_search"] if ctx.command.name == "pgvectorhnsw" else 100
- return default_value
+ return ctx.params["ef_search"] if ctx.command.name == "pgvectorhnsw" else 100
return value
+
class PgVectorTypedDict(CommonTypedDict):
user_name: Annotated[
- str, click.option("--user-name", type=str, help="Db username", required=True)
+ str,
+ click.option("--user-name", type=str, help="Db username", required=True),
]
password: Annotated[
str,
- click.option("--password",
- type=str,
- help="Postgres database password",
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
- show_default="$POSTGRES_PASSWORD",
- ),
+ click.option(
+ "--password",
+ type=str,
+ help="Postgres database password",
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+ show_default="$POSTGRES_PASSWORD",
+ ),
]
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
- port: Annotated[
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
+ port: Annotated[
int,
- click.option("--port",
- type=int,
- help="Postgres database port",
- default=5432,
- show_default=True,
- required=False
- ),
- ]
- db_name: Annotated[
- str, click.option("--db-name", type=str, help="Db name", required=True)
+ click.option(
+ "--port",
+ type=int,
+ help="Postgres database port",
+ default=5432,
+ show_default=True,
+ required=False,
+ ),
]
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
maintenance_work_mem: Annotated[
- Optional[str],
+ str | None,
click.option(
"--maintenance-work-mem",
type=str,
@@ -69,7 +69,7 @@ class PgVectorTypedDict(CommonTypedDict):
),
]
max_parallel_workers: Annotated[
- Optional[int],
+ int | None,
click.option(
"--max-parallel-workers",
type=int,
@@ -78,7 +78,7 @@ class PgVectorTypedDict(CommonTypedDict):
),
]
quantization_type: Annotated[
- Optional[str],
+ str | None,
click.option(
"--quantization-type",
type=click.Choice(["none", "bit", "halfvec"]),
@@ -87,7 +87,7 @@ class PgVectorTypedDict(CommonTypedDict):
),
]
reranking: Annotated[
- Optional[bool],
+ bool | None,
click.option(
"--reranking/--skip-reranking",
type=bool,
@@ -96,11 +96,11 @@ class PgVectorTypedDict(CommonTypedDict):
),
]
reranking_metric: Annotated[
- Optional[str],
+ str | None,
click.option(
"--reranking-metric",
type=click.Choice(
- [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]]
+ [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]],
),
help="Distance metric for reranking",
default="COSINE",
@@ -108,7 +108,7 @@ class PgVectorTypedDict(CommonTypedDict):
),
]
quantized_fetch_limit: Annotated[
- Optional[int],
+ int | None,
click.option(
"--quantized-fetch-limit",
type=int,
@@ -116,13 +116,11 @@ class PgVectorTypedDict(CommonTypedDict):
-- bound by ef_search",
required=False,
callback=set_default_quantized_fetch_limit,
- )
+ ),
]
-
-class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict):
- ...
+class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): ...
@cli.command()
@@ -156,8 +154,7 @@ def PgVectorIVFFlat(
)
-class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1):
- ...
+class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): ...
@cli.command()
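
The quantized-fetch-limit default above is resolved inside a click callback so it can depend on options parsed earlier (reranking, ef_search) and on which command is running. A minimal standalone sketch of that mechanism; the command and option set below are shortened stand-ins, not the real pgvectorhnsw command.

import click


def default_limit(ctx: click.Context, param: click.Parameter, value: int | None) -> int | None:
    if ctx.params.get("reranking") and value is None:
        # As in the diff: HNSW commands bound the limit by ef_search, IVFFlat falls back to 100.
        return ctx.params["ef_search"] if ctx.command.name == "hnsw" else 100
    return value


@click.command("hnsw")
@click.option("--reranking/--skip-reranking", default=False)
@click.option("--ef-search", type=int, default=40)
@click.option("--quantized-fetch-limit", type=int, callback=default_limit)
def hnsw(reranking: bool, ef_search: int, quantized_fetch_limit: int | None) -> None:
    click.echo(f"quantized_fetch_limit={quantized_fetch_limit}")
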
diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py
index 16d547445..c386d75ef 100644
--- a/vectordb_bench/backend/clients/pgvector/config.py
+++ b/vectordb_bench/backend/clients/pgvector/config.py
@@ -1,7 +1,9 @@
from abc import abstractmethod
-from typing import Any, Mapping, Optional, Sequence, TypedDict
+from collections.abc import Mapping, Sequence
+from typing import Any, LiteralString, TypedDict
+
from pydantic import BaseModel, SecretStr
-from typing_extensions import LiteralString
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +11,7 @@
class PgVectorConfigDict(TypedDict):
"""These keys will be directly used as kwargs in psycopg connection string,
- so the names must match exactly psycopg API"""
+ so the names must match exactly psycopg API"""
user: str
password: str
@@ -41,8 +43,8 @@ class PgVectorIndexParam(TypedDict):
metric: str
index_type: str
index_creation_with_options: Sequence[dict[str, Any]]
- maintenance_work_mem: Optional[str]
- max_parallel_workers: Optional[int]
+ maintenance_work_mem: str | None
+ max_parallel_workers: int | None
class PgVectorSearchParam(TypedDict):
@@ -59,61 +61,60 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig):
create_index_after_load: bool = True
def parse_metric(self) -> str:
- if self.quantization_type == "halfvec":
- if self.metric_type == MetricType.L2:
- return "halfvec_l2_ops"
- elif self.metric_type == MetricType.IP:
- return "halfvec_ip_ops"
- return "halfvec_cosine_ops"
- elif self.quantization_type == "bit":
- if self.metric_type == MetricType.JACCARD:
- return "bit_jaccard_ops"
- return "bit_hamming_ops"
- else:
- if self.metric_type == MetricType.L2:
- return "vector_l2_ops"
- elif self.metric_type == MetricType.IP:
- return "vector_ip_ops"
- return "vector_cosine_ops"
+ d = {
+ "halfvec": {
+ MetricType.L2: "halfvec_l2_ops",
+ MetricType.IP: "halfvec_ip_ops",
+ MetricType.COSINE: "halfvec_cosine_ops",
+ },
+ "bit": {
+ MetricType.JACCARD: "bit_jaccard_ops",
+ MetricType.HAMMING: "bit_hamming_ops",
+ },
+ "_fallback": {
+ MetricType.L2: "vector_l2_ops",
+ MetricType.IP: "vector_ip_ops",
+ MetricType.COSINE: "vector_cosine_ops",
+ },
+ }
+
+ if d.get(self.quantization_type) is None:
+ return d.get("_fallback").get(self.metric_type)
+ return d.get(self.quantization_type).get(self.metric_type)
def parse_metric_fun_op(self) -> LiteralString:
if self.quantization_type == "bit":
if self.metric_type == MetricType.JACCARD:
return "<%>"
return "<~>"
- else:
- if self.metric_type == MetricType.L2:
- return "<->"
- elif self.metric_type == MetricType.IP:
- return "<#>"
- return "<=>"
+ if self.metric_type == MetricType.L2:
+ return "<->"
+ if self.metric_type == MetricType.IP:
+ return "<#>"
+ return "<=>"
def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.L2:
return "l2_distance"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "max_inner_product"
return "cosine_distance"
-
+
def parse_reranking_metric_fun_op(self) -> LiteralString:
if self.reranking_metric == MetricType.L2:
return "<->"
- elif self.reranking_metric == MetricType.IP:
+ if self.reranking_metric == MetricType.IP:
return "<#>"
return "<=>"
-
@abstractmethod
- def index_param(self) -> PgVectorIndexParam:
- ...
+ def index_param(self) -> PgVectorIndexParam: ...
@abstractmethod
- def search_param(self) -> PgVectorSearchParam:
- ...
+ def search_param(self) -> PgVectorSearchParam: ...
@abstractmethod
- def session_param(self) -> PgVectorSessionCommands:
- ...
+ def session_param(self) -> PgVectorSessionCommands: ...
@staticmethod
def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
@@ -125,24 +126,23 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[
{
"option_name": option_name,
"val": str(value),
- }
+ },
)
return options
@staticmethod
- def _optionally_build_set_options(
- set_mapping: Mapping[str, Any]
- ) -> Sequence[dict[str, Any]]:
+ def _optionally_build_set_options(set_mapping: Mapping[str, Any]) -> Sequence[dict[str, Any]]:
"""Walk through options, creating 'SET 'key1 = "value1";' list"""
session_options = []
for setting_name, value in set_mapping.items():
if value:
session_options.append(
- {"parameter": {
+ {
+ "parameter": {
"setting_name": setting_name,
"val": str(value),
},
- }
+ },
)
return session_options
@@ -165,12 +165,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig):
lists: int | None
probes: int | None
index: IndexType = IndexType.ES_IVFFlat
- maintenance_work_mem: Optional[str] = None
- max_parallel_workers: Optional[int] = None
- quantization_type: Optional[str] = None
- reranking: Optional[bool] = None
- quantized_fetch_limit: Optional[int] = None
- reranking_metric: Optional[str] = None
+ maintenance_work_mem: str | None = None
+ max_parallel_workers: int | None = None
+ quantization_type: str | None = None
+ reranking: bool | None = None
+ quantized_fetch_limit: int | None = None
+ reranking_metric: str | None = None
def index_param(self) -> PgVectorIndexParam:
index_parameters = {"lists": self.lists}
@@ -179,9 +179,7 @@ def index_param(self) -> PgVectorIndexParam:
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
- "index_creation_with_options": self._optionally_build_with_options(
- index_parameters
- ),
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
@@ -197,9 +195,7 @@ def search_param(self) -> PgVectorSearchParam:
def session_param(self) -> PgVectorSessionCommands:
session_parameters = {"ivfflat.probes": self.probes}
- return {
- "session_options": self._optionally_build_set_options(session_parameters)
- }
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
class PgVectorHNSWConfig(PgVectorIndexConfig):
@@ -210,17 +206,15 @@ class PgVectorHNSWConfig(PgVectorIndexConfig):
"""
m: int | None # DETAIL: Valid values are between "2" and "100".
- ef_construction: (
- int | None
- ) # ef_construction must be greater than or equal to 2 * m
+ ef_construction: int | None # ef_construction must be greater than or equal to 2 * m
ef_search: int | None
index: IndexType = IndexType.ES_HNSW
- maintenance_work_mem: Optional[str] = None
- max_parallel_workers: Optional[int] = None
- quantization_type: Optional[str] = None
- reranking: Optional[bool] = None
- quantized_fetch_limit: Optional[int] = None
- reranking_metric: Optional[str] = None
+ maintenance_work_mem: str | None = None
+ max_parallel_workers: int | None = None
+ quantization_type: str | None = None
+ reranking: bool | None = None
+ quantized_fetch_limit: int | None = None
+ reranking_metric: str | None = None
def index_param(self) -> PgVectorIndexParam:
index_parameters = {"m": self.m, "ef_construction": self.ef_construction}
@@ -229,9 +223,7 @@ def index_param(self) -> PgVectorIndexParam:
return {
"metric": self.parse_metric(),
"index_type": self.index.value,
- "index_creation_with_options": self._optionally_build_with_options(
- index_parameters
- ),
+ "index_creation_with_options": self._optionally_build_with_options(index_parameters),
"maintenance_work_mem": self.maintenance_work_mem,
"max_parallel_workers": self.max_parallel_workers,
"quantization_type": self.quantization_type,
@@ -247,13 +239,11 @@ def search_param(self) -> PgVectorSearchParam:
def session_param(self) -> PgVectorSessionCommands:
session_parameters = {"hnsw.ef_search": self.ef_search}
- return {
- "session_options": self._optionally_build_set_options(session_parameters)
- }
+ return {"session_options": self._optionally_build_set_options(session_parameters)}
_pgvector_case_config = {
- IndexType.HNSW: PgVectorHNSWConfig,
- IndexType.ES_HNSW: PgVectorHNSWConfig,
- IndexType.IVFFlat: PgVectorIVFFlatConfig,
+ IndexType.HNSW: PgVectorHNSWConfig,
+ IndexType.ES_HNSW: PgVectorHNSWConfig,
+ IndexType.IVFFlat: PgVectorIVFFlatConfig,
}
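
The parse_metric rewrite above swaps the nested if/elif chains for a lookup table keyed by quantization type, with "_fallback" covering the plain vector column. One consequence of the table form is that a bit index paired with a metric other than JACCARD or HAMMING now yields None instead of defaulting to bit_hamming_ops. A trimmed standalone copy of the lookup, using plain strings in place of the MetricType enum:

OPS = {
    "halfvec": {"L2": "halfvec_l2_ops", "IP": "halfvec_ip_ops", "COSINE": "halfvec_cosine_ops"},
    "bit": {"JACCARD": "bit_jaccard_ops", "HAMMING": "bit_hamming_ops"},
    "_fallback": {"L2": "vector_l2_ops", "IP": "vector_ip_ops", "COSINE": "vector_cosine_ops"},
}


def parse_metric(quantization_type: str | None, metric: str) -> str | None:
    table = OPS.get(quantization_type)
    if table is None:
        return OPS["_fallback"].get(metric)
    return table.get(metric)


assert parse_metric(None, "L2") == "vector_l2_ops"
assert parse_metric("bit", "HAMMING") == "bit_hamming_ops"
assert parse_metric("bit", "L2") is None
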
diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py
index 069b89381..62a7971bb 100644
--- a/vectordb_bench/backend/clients/pgvector/pgvector.py
+++ b/vectordb_bench/backend/clients/pgvector/pgvector.py
@@ -1,9 +1,9 @@
"""Wrapper around the Pgvector vector database over VectorDB"""
import logging
-import pprint
+from collections.abc import Generator, Sequence
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple, Sequence
+from typing import Any
import numpy as np
import psycopg
@@ -11,7 +11,7 @@
from psycopg import Connection, Cursor, sql
from ..api import VectorDB
-from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig
+from .config import PgVectorConfigDict, PgVectorIndexConfig
log = logging.getLogger(__name__)
@@ -56,13 +56,14 @@ def __init__(
(
self.case_config.create_index_before_load,
self.case_config.create_index_after_load,
- )
+ ),
):
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
- log.error(err)
- raise RuntimeError(
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+ msg = (
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
)
+ log.error(msg)
+ raise RuntimeError(msg)
if drop_old:
self._drop_index()
@@ -77,7 +78,7 @@ def __init__(
self.conn = None
@staticmethod
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
conn = psycopg.connect(**kwargs)
register_vector(conn)
conn.autocommit = False
@@ -87,8 +88,8 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
assert cursor is not None, "Cursor is not initialized"
return conn, cursor
-
- def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
+
+ def _generate_search_query(self, filtered: bool = False) -> sql.Composed:
index_param = self.case_config.index_param()
reranking = self.case_config.search_param()["reranking"]
column_name = (
@@ -103,23 +104,25 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
)
# The following sections assume that the quantization_type value matches the quantization function name
- if index_param["quantization_type"] != None:
+ if index_param["quantization_type"] is not None:
if index_param["quantization_type"] == "bit" and reranking:
# Embeddings needs to be passed to binary_quantize function if quantization_type is bit
search_query = sql.Composed(
[
sql.SQL(
"""
- SELECT i.id
+ SELECT i.id
FROM (
- SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
+ SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance
FROM public.{table_name} {where_clause}
ORDER BY {column_name}::{quantization_type}({dim})
- """
+ """,
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
- reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]),
+ reranking_metric_fun_op=sql.SQL(
+ self.case_config.search_param()["reranking_metric_fun_op"],
+ ),
quantization_type=sql.SQL(index_param["quantization_type"]),
dim=sql.Literal(self.dim),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
@@ -127,25 +130,28 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(
"""
- {search_vector}
+ {search_vector}
LIMIT {quantized_fetch_limit}
) i
- ORDER BY i.distance
+ ORDER BY i.distance
LIMIT %s::int
- """
+ """,
).format(
search_vector=search_vector,
quantized_fetch_limit=sql.Literal(
- self.case_config.search_param()["quantized_fetch_limit"]
+ self.case_config.search_param()["quantized_fetch_limit"],
),
),
- ]
+ ],
)
else:
search_query = sql.Composed(
[
sql.SQL(
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) "
+ """
+ SELECT id FROM public.{table_name}
+ {where_clause} ORDER BY {column_name}::{quantization_type}({dim})
+ """,
).format(
table_name=sql.Identifier(self.table_name),
column_name=column_name,
@@ -154,25 +160,26 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed:
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
- sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector),
- ]
+ sql.SQL(" {search_vector} LIMIT %s::int").format(
+ search_vector=search_vector,
+ ),
+ ],
)
else:
search_query = sql.Composed(
[
sql.SQL(
- "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding "
+ "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ",
).format(
table_name=sql.Identifier(self.table_name),
where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
-
+
return search_query
-
@contextmanager
def init(self) -> Generator[None, None, None]:
@@ -191,8 +198,8 @@ def init(self) -> Generator[None, None, None]:
if len(session_options) > 0:
for setting in session_options:
command = sql.SQL("SET {setting_name} " + "= {val};").format(
- setting_name=sql.Identifier(setting['parameter']['setting_name']),
- val=sql.Identifier(str(setting['parameter']['val'])),
+ setting_name=sql.Identifier(setting["parameter"]["setting_name"]),
+ val=sql.Identifier(str(setting["parameter"]["val"])),
)
log.debug(command.as_string(self.cursor))
self.cursor.execute(command)
@@ -216,8 +223,8 @@ def _drop_table(self):
self.cursor.execute(
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
)
self.conn.commit()
@@ -239,7 +246,7 @@ def _drop_index(self):
log.info(f"{self.name} client drop index : {self._index_name}")
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
- index_name=sql.Identifier(self._index_name)
+ index_name=sql.Identifier(self._index_name),
)
log.debug(drop_index_sql.as_string(self.cursor))
self.cursor.execute(drop_index_sql)
@@ -254,63 +261,51 @@ def _set_parallel_index_build_param(self):
if index_param["maintenance_work_mem"] is not None:
self.cursor.execute(
sql.SQL("SET maintenance_work_mem TO {};").format(
- index_param["maintenance_work_mem"]
- )
+ index_param["maintenance_work_mem"],
+ ),
)
self.cursor.execute(
sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format(
sql.Identifier(self.db_config["user"]),
index_param["maintenance_work_mem"],
- )
+ ),
)
self.conn.commit()
if index_param["max_parallel_workers"] is not None:
self.cursor.execute(
sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(
- index_param["max_parallel_workers"]
- )
+ index_param["max_parallel_workers"],
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';"
- ).format(
+ sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format(
sql.Identifier(self.db_config["user"]),
index_param["max_parallel_workers"],
- )
+ ),
)
self.cursor.execute(
sql.SQL("SET max_parallel_workers TO '{}';").format(
- index_param["max_parallel_workers"]
- )
+ index_param["max_parallel_workers"],
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER USER {} SET max_parallel_workers TO '{}';"
- ).format(
+ sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format(
sql.Identifier(self.db_config["user"]),
index_param["max_parallel_workers"],
- )
+ ),
)
self.cursor.execute(
- sql.SQL(
- "ALTER TABLE {} SET (parallel_workers = {});"
- ).format(
+ sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(
sql.Identifier(self.table_name),
index_param["max_parallel_workers"],
- )
+ ),
)
self.conn.commit()
- results = self.cursor.execute(
- sql.SQL("SHOW max_parallel_maintenance_workers;")
- ).fetchall()
- results.extend(
- self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()
- )
- results.extend(
- self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()
- )
+ results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
+ results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall())
+ results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
log.info(f"{self.name} parallel index creation parameters: {results}")
def _create_index(self):
@@ -322,24 +317,21 @@ def _create_index(self):
self._set_parallel_index_build_param()
options = []
for option in index_param["index_creation_with_options"]:
- if option['val'] is not None:
+ if option["val"] is not None:
options.append(
sql.SQL("{option_name} = {val}").format(
- option_name=sql.Identifier(option['option_name']),
- val=sql.Identifier(str(option['val'])),
- )
+ option_name=sql.Identifier(option["option_name"]),
+ val=sql.Identifier(str(option["val"])),
+ ),
)
- if any(options):
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
- else:
- with_clause = sql.Composed(())
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
- if index_param["quantization_type"] != None:
+ if index_param["quantization_type"] is not None:
index_create_sql = sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric})
- """
+ """,
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
@@ -357,9 +349,9 @@ def _create_index(self):
else:
index_create_sql = sql.SQL(
"""
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
- """
+ """,
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
@@ -367,9 +359,7 @@ def _create_index(self):
embedding_metric=sql.Identifier(index_param["metric"]),
)
- index_create_sql_with_with_clause = (
- index_create_sql + with_clause
- ).join(" ")
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
self.cursor.execute(index_create_sql_with_with_clause)
self.conn.commit()
@@ -384,19 +374,17 @@ def _create_table(self, dim: int):
# create table
self.cursor.execute(
sql.SQL(
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
)
self.cursor.execute(
sql.SQL(
- "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;"
- ).format(table_name=sql.Identifier(self.table_name))
+ "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;",
+ ).format(table_name=sql.Identifier(self.table_name)),
)
self.conn.commit()
except Exception as e:
- log.warning(
- f"Failed to create pgvector table: {self.table_name} error: {e}"
- )
+ log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}")
raise e from None
def insert_embeddings(
@@ -404,7 +392,7 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
@@ -414,8 +402,8 @@ def insert_embeddings(
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
@@ -428,7 +416,7 @@ def insert_embeddings(
return len(metadata), None
except Exception as e:
log.warning(
- f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
+ f"Failed to insert data into pgvector table ({self.table_name}), error: {e}",
)
return 0, e
@@ -449,21 +437,32 @@ def search_embedding(
gt = filters.get("id")
if index_param["quantization_type"] == "bit" and search_param["reranking"]:
result = self.cursor.execute(
- self._filtered_search, (q, gt, q, k), prepare=True, binary=True
+ self._filtered_search,
+ (q, gt, q, k),
+ prepare=True,
+ binary=True,
)
else:
result = self.cursor.execute(
- self._filtered_search, (gt, q, k), prepare=True, binary=True
+ self._filtered_search,
+ (gt, q, k),
+ prepare=True,
+ binary=True,
)
-
+
+ elif index_param["quantization_type"] == "bit" and search_param["reranking"]:
+ result = self.cursor.execute(
+ self._unfiltered_search,
+ (q, q, k),
+ prepare=True,
+ binary=True,
+ )
else:
- if index_param["quantization_type"] == "bit" and search_param["reranking"]:
- result = self.cursor.execute(
- self._unfiltered_search, (q, q, k), prepare=True, binary=True
- )
- else:
- result = self.cursor.execute(
- self._unfiltered_search, (q, k), prepare=True, binary=True
- )
+ result = self.cursor.execute(
+ self._unfiltered_search,
+ (q, k),
+ prepare=True,
+ binary=True,
+ )
return [int(i[0]) for i in result.fetchall()]
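
The parallel index-build block reflowed above drives several levels of parallelism from the same setting (session, role, and table) and then reads the values back with SHOW for logging. A condensed sketch of the session and table part; the ALTER USER statements are omitted and the parameter values are examples only.

import psycopg
from psycopg import sql


def set_parallel_build(cur: psycopg.Cursor, table: str, workers: int, work_mem: str) -> list:
    cur.execute(sql.SQL("SET maintenance_work_mem TO {};").format(work_mem))
    cur.execute(sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format(workers))
    cur.execute(sql.SQL("SET max_parallel_workers TO '{}';").format(workers))
    cur.execute(
        sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format(sql.Identifier(table), workers),
    )
    # psycopg 3's execute() returns the cursor, so fetchall() can be chained.
    results = cur.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall()
    results.extend(cur.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall())
    return results
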
diff --git a/vectordb_bench/backend/clients/pgvectorscale/cli.py b/vectordb_bench/backend/clients/pgvectorscale/cli.py
index e5a161c6b..fca9d51d8 100644
--- a/vectordb_bench/backend/clients/pgvectorscale/cli.py
+++ b/vectordb_bench/backend/clients/pgvectorscale/cli.py
@@ -1,80 +1,94 @@
-import click
import os
+from typing import Annotated, Unpack
+
+import click
from pydantic import SecretStr
+from vectordb_bench.backend.clients import DB
+
from ....cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
-from typing import Annotated, Unpack
-from vectordb_bench.backend.clients import DB
class PgVectorScaleTypedDict(CommonTypedDict):
user_name: Annotated[
- str, click.option("--user-name", type=str, help="Db username", required=True)
+ str,
+ click.option("--user-name", type=str, help="Db username", required=True),
]
password: Annotated[
str,
- click.option("--password",
- type=str,
- help="Postgres database password",
- default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
- show_default="$POSTGRES_PASSWORD",
- ),
+ click.option(
+ "--password",
+ type=str,
+ help="Postgres database password",
+ default=lambda: os.environ.get("POSTGRES_PASSWORD", ""),
+ show_default="$POSTGRES_PASSWORD",
+ ),
]
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
- db_name: Annotated[
- str, click.option("--db-name", type=str, help="Db name", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
+ db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)]
class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict):
storage_layout: Annotated[
str,
click.option(
- "--storage-layout", type=str, help="Streaming DiskANN storage layout",
+ "--storage-layout",
+ type=str,
+ help="Streaming DiskANN storage layout",
),
]
num_neighbors: Annotated[
int,
click.option(
- "--num-neighbors", type=int, help="Streaming DiskANN num neighbors",
+ "--num-neighbors",
+ type=int,
+ help="Streaming DiskANN num neighbors",
),
]
search_list_size: Annotated[
int,
click.option(
- "--search-list-size", type=int, help="Streaming DiskANN search list size",
+ "--search-list-size",
+ type=int,
+ help="Streaming DiskANN search list size",
),
]
max_alpha: Annotated[
float,
click.option(
- "--max-alpha", type=float, help="Streaming DiskANN max alpha",
+ "--max-alpha",
+ type=float,
+ help="Streaming DiskANN max alpha",
),
]
num_dimensions: Annotated[
int,
click.option(
- "--num-dimensions", type=int, help="Streaming DiskANN num dimensions",
+ "--num-dimensions",
+ type=int,
+ help="Streaming DiskANN num dimensions",
),
]
query_search_list_size: Annotated[
int,
click.option(
- "--query-search-list-size", type=int, help="Streaming DiskANN query search list size",
+ "--query-search-list-size",
+ type=int,
+ help="Streaming DiskANN query search list size",
),
]
query_rescore: Annotated[
int,
click.option(
- "--query-rescore", type=int, help="Streaming DiskANN query rescore",
+ "--query-rescore",
+ type=int,
+ help="Streaming DiskANN query rescore",
),
]
@@ -105,4 +119,4 @@ def PgVectorScaleDiskAnn(
query_rescore=parameters["query_rescore"],
),
**parameters,
- )
\ No newline at end of file
+ )
diff --git a/vectordb_bench/backend/clients/pgvectorscale/config.py b/vectordb_bench/backend/clients/pgvectorscale/config.py
index bd9f6106b..e22c45c8d 100644
--- a/vectordb_bench/backend/clients/pgvectorscale/config.py
+++ b/vectordb_bench/backend/clients/pgvectorscale/config.py
@@ -1,7 +1,8 @@
from abc import abstractmethod
-from typing import TypedDict
+from typing import LiteralString, TypedDict
+
from pydantic import BaseModel, SecretStr
-from typing_extensions import LiteralString
+
from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s"
@@ -9,7 +10,7 @@
class PgVectorScaleConfigDict(TypedDict):
"""These keys will be directly used as kwargs in psycopg connection string,
- so the names must match exactly psycopg API"""
+    so the names must exactly match the psycopg API"""
user: str
password: str
@@ -46,7 +47,7 @@ def parse_metric(self) -> str:
if self.metric_type == MetricType.COSINE:
return "vector_cosine_ops"
return ""
-
+
def parse_metric_fun_op(self) -> LiteralString:
if self.metric_type == MetricType.COSINE:
return "<=>"
@@ -56,19 +57,16 @@ def parse_metric_fun_str(self) -> str:
if self.metric_type == MetricType.COSINE:
return "cosine_distance"
return ""
-
+
@abstractmethod
- def index_param(self) -> dict:
- ...
+ def index_param(self) -> dict: ...
@abstractmethod
- def search_param(self) -> dict:
- ...
+ def search_param(self) -> dict: ...
@abstractmethod
- def session_param(self) -> dict:
- ...
-
+ def session_param(self) -> dict: ...
+
class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig):
index: IndexType = IndexType.STREAMING_DISKANN
@@ -93,19 +91,20 @@ def index_param(self) -> dict:
"num_dimensions": self.num_dimensions,
},
}
-
+
def search_param(self) -> dict:
return {
"metric": self.parse_metric(),
"metric_fun_op": self.parse_metric_fun_op(),
}
-
+
def session_param(self) -> dict:
return {
"diskann.query_search_list_size": self.query_search_list_size,
"diskann.query_rescore": self.query_rescore,
}
-
+
+
_pgvectorscale_case_config = {
IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig,
}
diff --git a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py
index d8f26394c..981accc2e 100644
--- a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py
+++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py
@@ -1,9 +1,9 @@
"""Wrapper around the Pgvectorscale vector database over VectorDB"""
import logging
-import pprint
+from collections.abc import Generator
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple
+from typing import Any
import numpy as np
import psycopg
@@ -44,20 +44,21 @@ def __init__(
self._primary_field = "id"
self._vector_field = "embedding"
- self.conn, self.cursor = self._create_connection(**self.db_config)
+ self.conn, self.cursor = self._create_connection(**self.db_config)
log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}")
if not any(
(
self.case_config.create_index_before_load,
self.case_config.create_index_after_load,
- )
+ ),
):
- err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
- log.error(err)
- raise RuntimeError(
- f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}"
+ msg = (
+ f"{self.name} config must create an index using create_index_before_load or create_index_after_load"
+ f"{self.name} config values: {self.db_config}\n{self.case_config}"
)
+ log.error(msg)
+ raise RuntimeError(msg)
if drop_old:
self._drop_index()
@@ -72,7 +73,7 @@ def __init__(
self.conn = None
@staticmethod
- def _create_connection(**kwargs) -> Tuple[Connection, Cursor]:
+ def _create_connection(**kwargs) -> tuple[Connection, Cursor]:
conn = psycopg.connect(**kwargs)
conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE")
conn.commit()
@@ -101,25 +102,25 @@ def init(self) -> Generator[None, None, None]:
log.debug(command.as_string(self.cursor))
self.cursor.execute(command)
self.conn.commit()
-
+
self._filtered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format(
sql.Identifier(self.table_name),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
- sql.SQL(" %s::vector LIMIT %s::int")
- ]
+ sql.SQL(" %s::vector LIMIT %s::int"),
+ ],
)
-
+
self._unfiltered_search = sql.Composed(
[
sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format(
- sql.Identifier(self.table_name)
+ sql.Identifier(self.table_name),
),
sql.SQL(self.case_config.search_param()["metric_fun_op"]),
sql.SQL(" %s::vector LIMIT %s::int"),
- ]
+ ],
)
try:
@@ -137,8 +138,8 @@ def _drop_table(self):
self.cursor.execute(
sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
)
self.conn.commit()
@@ -160,7 +161,7 @@ def _drop_index(self):
log.info(f"{self.name} client drop index : {self._index_name}")
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format(
- index_name=sql.Identifier(self._index_name)
+ index_name=sql.Identifier(self._index_name),
)
log.debug(drop_index_sql.as_string(self.cursor))
self.cursor.execute(drop_index_sql)
@@ -180,36 +181,31 @@ def _create_index(self):
sql.SQL("{option_name} = {val}").format(
option_name=sql.Identifier(option_name),
val=sql.Identifier(str(option_val)),
- )
+ ),
)
-
+
num_bits_per_dimension = "2" if self.dim < 900 else "1"
options.append(
sql.SQL("{option_name} = {val}").format(
option_name=sql.Identifier("num_bits_per_dimension"),
val=sql.Identifier(num_bits_per_dimension),
- )
+ ),
)
- if any(options):
- with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options))
- else:
- with_clause = sql.Composed(())
+ with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(())
index_create_sql = sql.SQL(
"""
- CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
+ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name}
USING {index_type} (embedding {embedding_metric})
- """
+ """,
).format(
index_name=sql.Identifier(self._index_name),
table_name=sql.Identifier(self.table_name),
index_type=sql.Identifier(index_param["index_type"].lower()),
embedding_metric=sql.Identifier(index_param["metric"]),
)
- index_create_sql_with_with_clause = (
- index_create_sql + with_clause
- ).join(" ")
+ index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ")
log.debug(index_create_sql_with_with_clause.as_string(self.cursor))
self.cursor.execute(index_create_sql_with_with_clause)
self.conn.commit()
@@ -223,14 +219,12 @@ def _create_table(self, dim: int):
self.cursor.execute(
sql.SQL(
- "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));"
- ).format(table_name=sql.Identifier(self.table_name), dim=dim)
+ "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));",
+ ).format(table_name=sql.Identifier(self.table_name), dim=dim),
)
self.conn.commit()
except Exception as e:
- log.warning(
- f"Failed to create pgvectorscale table: {self.table_name} error: {e}"
- )
+ log.warning(f"Failed to create pgvectorscale table: {self.table_name} error: {e}")
raise e from None
def insert_embeddings(
@@ -238,7 +232,7 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
assert self.conn is not None, "Connection is not initialized"
assert self.cursor is not None, "Cursor is not initialized"
@@ -248,8 +242,8 @@ def insert_embeddings(
with self.cursor.copy(
sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format(
- table_name=sql.Identifier(self.table_name)
- )
+ table_name=sql.Identifier(self.table_name),
+ ),
) as copy:
copy.set_types(["bigint", "vector"])
for i, row in enumerate(metadata_arr):
@@ -262,7 +256,7 @@ def insert_embeddings(
return len(metadata), None
except Exception as e:
log.warning(
- f"Failed to insert data into pgvector table ({self.table_name}), error: {e}"
+ f"Failed to insert data into pgvector table ({self.table_name}), error: {e}",
)
return 0, e
@@ -280,11 +274,12 @@ def search_embedding(
if filters:
gt = filters.get("id")
result = self.cursor.execute(
- self._filtered_search, (gt, q, k), prepare=True, binary=True
+ self._filtered_search,
+ (gt, q, k),
+ prepare=True,
+ binary=True,
)
else:
- result = self.cursor.execute(
- self._unfiltered_search, (q, k), prepare=True, binary=True
- )
+ result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True)
return [int(i[0]) for i in result.fetchall()]
diff --git a/vectordb_bench/backend/clients/pinecone/config.py b/vectordb_bench/backend/clients/pinecone/config.py
index 2bbcbb350..fe1a039ed 100644
--- a/vectordb_bench/backend/clients/pinecone/config.py
+++ b/vectordb_bench/backend/clients/pinecone/config.py
@@ -1,4 +1,5 @@
from pydantic import SecretStr
+
from ..api import DBConfig
diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py
index c1351f7a9..c59ee8760 100644
--- a/vectordb_bench/backend/clients/pinecone/pinecone.py
+++ b/vectordb_bench/backend/clients/pinecone/pinecone.py
@@ -2,11 +2,11 @@
import logging
from contextlib import contextmanager
-from typing import Type
+
import pinecone
-from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType
-from .config import PineconeConfig
+from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB
+from .config import PineconeConfig
log = logging.getLogger(__name__)
@@ -17,7 +17,7 @@
class Pinecone(VectorDB):
def __init__(
self,
- dim,
+ dim: int,
db_config: dict,
db_case_config: DBCaseConfig,
drop_old: bool = False,
@@ -27,7 +27,7 @@ def __init__(
self.index_name = db_config.get("index_name", "")
self.api_key = db_config.get("api_key", "")
self.batch_size = int(
- min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH)
+ min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH),
)
pc = pinecone.Pinecone(api_key=self.api_key)
@@ -37,9 +37,8 @@ def __init__(
index_stats = index.describe_index_stats()
index_dim = index_stats["dimension"]
if index_dim != dim:
- raise ValueError(
- f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
- )
+ msg = f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}"
+ raise ValueError(msg)
for namespace in index_stats["namespaces"]:
log.info(f"Pinecone index delete namespace: {namespace}")
index.delete(delete_all=True, namespace=namespace)
@@ -47,11 +46,11 @@ def __init__(
self._metadata_key = "meta"
@classmethod
- def config_cls(cls) -> Type[DBConfig]:
+ def config_cls(cls) -> type[DBConfig]:
return PineconeConfig
@classmethod
- def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]:
+ def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]:
return EmptyDBCaseConfig
@contextmanager
@@ -76,9 +75,7 @@ def insert_embeddings(
insert_count = 0
try:
for batch_start_offset in range(0, len(embeddings), self.batch_size):
- batch_end_offset = min(
- batch_start_offset + self.batch_size, len(embeddings)
- )
+ batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings))
insert_datas = []
for i in range(batch_start_offset, batch_end_offset):
insert_data = (
@@ -100,10 +97,7 @@ def search_embedding(
filters: dict | None = None,
timeout: int | None = None,
) -> list[int]:
- if filters is None:
- pinecone_filters = {}
- else:
- pinecone_filters = {self._metadata_key: {"$gte": filters["id"]}}
+ pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}}
try:
res = self.index.query(
top_k=k,
@@ -111,7 +105,6 @@ def search_embedding(
filter=pinecone_filters,
)["matches"]
except Exception as e:
- print(f"Error querying index: {e}")
- raise e
- id_res = [int(one_res["id"]) for one_res in res]
- return id_res
+ log.warning(f"Error querying index: {e}")
+ raise e from e
+ return [int(one_res["id"]) for one_res in res]
diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py
index 5b1dd7f18..c1d6882c0 100644
--- a/vectordb_bench/backend/clients/qdrant_cloud/config.py
+++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py
@@ -1,7 +1,9 @@
-from pydantic import BaseModel, SecretStr
+from typing import Any
+
+from pydantic import BaseModel, SecretStr, validator
+
+from ..api import DBCaseConfig, DBConfig, MetricType
-from ..api import DBConfig, DBCaseConfig, MetricType
-from pydantic import validator
# Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant.
class QdrantConfig(DBConfig):
@@ -16,17 +16,19 @@ def to_dict(self) -> dict:
"api_key": self.api_key.get_secret_value(),
"prefer_grpc": True,
}
- else:
- return {"url": self.url.get_secret_value(),}
-
+ return {
+ "url": self.url.get_secret_value(),
+ }
+
@validator("*")
- def not_empty_field(cls, v, field):
+    def not_empty_field(cls, v: Any, field: Any):
if field.name in ["api_key", "db_label"]:
return v
- if isinstance(v, (str, SecretStr)) and len(v) == 0:
+ if isinstance(v, str | SecretStr) and len(v) == 0:
raise ValueError("Empty string!")
return v
+
class QdrantIndexConfig(BaseModel, DBCaseConfig):
metric_type: MetricType | None = None
@@ -40,8 +42,7 @@ def parse_metric(self) -> str:
return "Cosine"
def index_param(self) -> dict:
- params = {"distance": self.parse_metric()}
- return params
+ return {"distance": self.parse_metric()}
def search_param(self) -> dict:
return {}
diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py
index a51632bc6..0861e8938 100644
--- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py
+++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py
@@ -4,23 +4,26 @@
import time
from contextlib import contextmanager
-from ..api import VectorDB, DBCaseConfig
+from qdrant_client import QdrantClient
from qdrant_client.http.models import (
- CollectionStatus,
- VectorParams,
- PayloadSchemaType,
Batch,
- Filter,
+ CollectionStatus,
FieldCondition,
+ Filter,
+ PayloadSchemaType,
Range,
+ VectorParams,
)
-from qdrant_client import QdrantClient
-
+from ..api import DBCaseConfig, VectorDB
log = logging.getLogger(__name__)
+SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
+QDRANT_BATCH_SIZE = 500
+
+
class QdrantCloud(VectorDB):
def __init__(
self,
@@ -57,16 +60,14 @@ def init(self) -> None:
self.qdrant_client = QdrantClient(**self.db_config)
yield
self.qdrant_client = None
- del(self.qdrant_client)
+ del self.qdrant_client
def ready_to_load(self):
pass
-
def optimize(self):
assert self.qdrant_client, "Please call self.init() before"
# wait for vectors to be fully indexed
- SECONDS_WAITING_FOR_INDEXING_API_CALL = 5
try:
while True:
info = self.qdrant_client.get_collection(self.collection_name)
@@ -74,19 +75,26 @@ def optimize(self):
if info.status != CollectionStatus.GREEN:
continue
if info.status == CollectionStatus.GREEN:
- log.info(f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, Collection status: {info.indexed_vectors_count}")
+                    msg = (
+                        f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, "
+                        f"Collection status: {info.status}"
+                    )
+ log.info(msg)
return
except Exception as e:
log.warning(f"QdrantCloud ready to search error: {e}")
raise e from None
- def _create_collection(self, dim, qdrant_client: int):
+ def _create_collection(self, dim: int, qdrant_client: QdrantClient):
log.info(f"Create collection: {self.collection_name}")
try:
qdrant_client.create_collection(
collection_name=self.collection_name,
- vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"])
+ vectors_config=VectorParams(
+ size=dim,
+ distance=self.case_config.index_param()["distance"],
+ ),
)
qdrant_client.create_payload_index(
@@ -109,13 +117,12 @@ def insert_embeddings(
) -> (int, Exception):
"""Insert embeddings into Milvus. should call self.init() first"""
assert self.qdrant_client is not None
- QDRANT_BATCH_SIZE = 500
try:
# TODO: counts
for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE):
- vectors = embeddings[offset: offset + QDRANT_BATCH_SIZE]
- ids = metadata[offset: offset + QDRANT_BATCH_SIZE]
- payloads=[{self._primary_field: v} for v in ids]
+ vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE]
+ ids = metadata[offset : offset + QDRANT_BATCH_SIZE]
+ payloads = [{self._primary_field: v} for v in ids]
_ = self.qdrant_client.upsert(
collection_name=self.collection_name,
wait=True,
@@ -142,21 +149,21 @@ def search_embedding(
f = None
if filters:
f = Filter(
- must=[FieldCondition(
- key = self._primary_field,
- range = Range(
- gt=filters.get('id'),
+ must=[
+ FieldCondition(
+ key=self._primary_field,
+ range=Range(
+ gt=filters.get("id"),
+ ),
),
- )]
+ ],
)
- res = self.qdrant_client.search(
- collection_name=self.collection_name,
- query_vector=query,
- limit=k,
- query_filter=f,
- # with_payload=True,
- ),
+        res = self.qdrant_client.search(
+            collection_name=self.collection_name,
+            query_vector=query,
+            limit=k,
+            query_filter=f,
+        )
- ret = [result.id for result in res[0]]
- return ret
+        return [result.id for result in res]
diff --git a/vectordb_bench/backend/clients/redis/cli.py b/vectordb_bench/backend/clients/redis/cli.py
index eb86b4c00..69277c00f 100644
--- a/vectordb_bench/backend/clients/redis/cli.py
+++ b/vectordb_bench/backend/clients/redis/cli.py
@@ -3,9 +3,6 @@
import click
from pydantic import SecretStr
-from .config import RedisHNSWConfig
-
-
from ....cli.cli import (
CommonTypedDict,
HNSWFlavor2,
@@ -14,12 +11,11 @@
run,
)
from .. import DB
+from .config import RedisHNSWConfig
class RedisTypedDict(TypedDict):
- host: Annotated[
- str, click.option("--host", type=str, help="Db host", required=True)
- ]
+ host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)]
password: Annotated[str, click.option("--password", type=str, help="Db password")]
port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")]
ssl: Annotated[
@@ -52,27 +48,25 @@ class RedisTypedDict(TypedDict):
]
-class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2):
- ...
+class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): ...
@cli.command()
@click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict)
def Redis(**parameters: Unpack[RedisHNSWTypedDict]):
from .config import RedisConfig
+
run(
db=DB.Redis,
db_config=RedisConfig(
db_label=parameters["db_label"],
- password=SecretStr(parameters["password"])
- if parameters["password"]
- else None,
+ password=SecretStr(parameters["password"]) if parameters["password"] else None,
host=SecretStr(parameters["host"]),
port=parameters["port"],
ssl=parameters["ssl"],
ssl_ca_certs=parameters["ssl_ca_certs"],
cmd=parameters["cmd"],
- ),
+ ),
db_case_config=RedisHNSWConfig(
M=parameters["m"],
efConstruction=parameters["ef_construction"],
diff --git a/vectordb_bench/backend/clients/redis/config.py b/vectordb_bench/backend/clients/redis/config.py
index 55a7a4159..db0127f84 100644
--- a/vectordb_bench/backend/clients/redis/config.py
+++ b/vectordb_bench/backend/clients/redis/config.py
@@ -1,10 +1,12 @@
-from pydantic import SecretStr, BaseModel
-from ..api import DBConfig, DBCaseConfig, MetricType, IndexType
+from pydantic import BaseModel, SecretStr
+
+from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+
class RedisConfig(DBConfig):
password: SecretStr | None = None
host: SecretStr
- port: int | None = None
+ port: int | None = None
def to_dict(self) -> dict:
return {
@@ -12,7 +14,6 @@ def to_dict(self) -> dict:
"port": self.port,
"password": self.password.get_secret_value() if self.password is not None else None,
}
-
class RedisIndexConfig(BaseModel):
@@ -24,7 +25,8 @@ def parse_metric(self) -> str:
if not self.metric_type:
return ""
return self.metric_type.value
-
+
+
class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig):
M: int
efConstruction: int
diff --git a/vectordb_bench/backend/clients/redis/redis.py b/vectordb_bench/backend/clients/redis/redis.py
index cf51fcc48..139850d2f 100644
--- a/vectordb_bench/backend/clients/redis/redis.py
+++ b/vectordb_bench/backend/clients/redis/redis.py
@@ -1,36 +1,40 @@
import logging
from contextlib import contextmanager
from typing import Any
-from ..api import VectorDB, DBCaseConfig
+
+import numpy as np
import redis
-from redis.commands.search.field import TagField, VectorField, NumericField
+from redis.commands.search.field import NumericField, TagField, VectorField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.query import Query
-import numpy as np
+from ..api import DBCaseConfig, VectorDB
log = logging.getLogger(__name__)
-INDEX_NAME = "index" # Vector Index Name
+INDEX_NAME = "index" # Vector Index Name
+
class Redis(VectorDB):
def __init__(
- self,
- dim: int,
- db_config: dict,
- db_case_config: DBCaseConfig,
- drop_old: bool = False,
-
- **kwargs
- ):
-
+ self,
+ dim: int,
+ db_config: dict,
+ db_case_config: DBCaseConfig,
+ drop_old: bool = False,
+ **kwargs,
+ ):
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = INDEX_NAME
# Create a redis connection, if db has password configured, add it to the connection here and in init():
- password=self.db_config["password"]
- conn = redis.Redis(host=self.db_config["host"], port=self.db_config["port"], password=password, db=0)
-
+ password = self.db_config["password"]
+ conn = redis.Redis(
+ host=self.db_config["host"],
+ port=self.db_config["port"],
+ password=password,
+ db=0,
+ )
if drop_old:
try:
@@ -39,7 +43,7 @@ def __init__(
except redis.exceptions.ResponseError:
drop_old = False
log.info(f"Redis client drop_old collection: {self.collection_name}")
-
+
self.make_index(dim, conn)
conn.close()
conn = None
@@ -48,18 +52,20 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):
try:
# check to see if index exists
conn.ft(INDEX_NAME).info()
- except:
+ except Exception:
schema = (
- TagField("id"),
- NumericField("metadata"),
- VectorField("vector", # Vector Field Name
- "HNSW", { # Vector Index Type: FLAT or HNSW
- "TYPE": "FLOAT32", # FLOAT32 or FLOAT64
- "DIM": vector_dimensions, # Number of Vector Dimensions
- "DISTANCE_METRIC": "COSINE", # Vector Search Distance Metric
+ TagField("id"),
+ NumericField("metadata"),
+ VectorField(
+ "vector", # Vector Field Name
+ "HNSW", # Vector Index Type: FLAT or HNSW
+ {
+ "TYPE": "FLOAT32", # FLOAT32 or FLOAT64
+ "DIM": vector_dimensions, # Number of Vector Dimensions
+ "DISTANCE_METRIC": "COSINE", # Vector Search Distance Metric
"M": self.case_config.index_param()["params"]["M"],
"EF_CONSTRUCTION": self.case_config.index_param()["params"]["efConstruction"],
- }
+ },
),
)
@@ -69,23 +75,25 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis):
rs.create_index(schema, definition=definition)
@contextmanager
- def init(self):
- """ create and destory connections to database.
+ def init(self) -> None:
+ """create and destory connections to database.
Examples:
>>> with self.init():
>>> self.insert_embeddings()
"""
- self.conn = redis.Redis(host=self.db_config["host"], port=self.db_config["port"], password=self.db_config["password"], db=0)
+ self.conn = redis.Redis(
+ host=self.db_config["host"],
+ port=self.db_config["port"],
+ password=self.db_config["password"],
+ db=0,
+ )
yield
self.conn.close()
self.conn = None
-
def ready_to_search(self) -> bool:
"""Check if the database is ready to search."""
- pass
-
def ready_to_load(self) -> bool:
pass
@@ -93,7 +101,6 @@ def ready_to_load(self) -> bool:
def optimize(self) -> None:
pass
-
def insert_embeddings(
self,
embeddings: list[list[float]],
@@ -104,27 +111,30 @@ def insert_embeddings(
Should call self.init() first.
"""
- batch_size = 1000 # Adjust this as needed, but don't make too big
+ batch_size = 1000 # Adjust this as needed, but don't make too big
try:
with self.conn.pipeline(transaction=False) as pipe:
for i, embedding in enumerate(embeddings):
- embedding = np.array(embedding).astype(np.float32)
- pipe.hset(metadata[i], mapping = {
- "id": str(metadata[i]),
- "metadata": metadata[i],
- "vector": embedding.tobytes(),
- })
+ ndarr_emb = np.array(embedding).astype(np.float32)
+ pipe.hset(
+ metadata[i],
+ mapping={
+ "id": str(metadata[i]),
+ "metadata": metadata[i],
+ "vector": ndarr_emb.tobytes(),
+ },
+ )
# Execute the pipe so we don't keep too much in memory at once
if i % batch_size == 0:
- res = pipe.execute()
+ _ = pipe.execute()
- res = pipe.execute()
+ _ = pipe.execute()
result_len = i + 1
except Exception as e:
return 0, e
-
+
return result_len, None
-
+
def search_embedding(
self,
query: list[float],
@@ -132,26 +142,53 @@ def search_embedding(
filters: dict | None = None,
timeout: int | None = None,
**kwargs: Any,
- ) -> (list[int]):
+ ) -> list[int]:
assert self.conn is not None
-
+
query_vector = np.array(query).astype(np.float32).tobytes()
ef_runtime = self.case_config.search_param()["params"]["ef"]
- query_obj = Query(f"*=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2)
+ query_obj = (
+ Query(f"*=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]")
+ .sort_by("score")
+ .return_fields("id", "score")
+ .paging(0, k)
+ .dialect(2)
+ )
query_params = {"vec": query_vector}
-
+
if filters:
# benchmark test filters of format: {'metadata': '>=10000', 'id': 10000}
# gets exact match for id, and range for metadata if they exist in filters
id_value = filters.get("id")
metadata_value = filters.get("metadata")
if id_value and metadata_value:
- query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2)
+ query_obj = (
+ Query(
+ f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} ",
+ f"@vector $vec EF_RUNTIME {ef_runtime} as score]",
+ )
+ .sort_by("score")
+ .return_fields("id", "score")
+ .paging(0, k)
+ .dialect(2)
+ )
elif id_value:
- #gets exact match for id
- query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2)
- else: #metadata only case, greater than or equal to metadata value
- query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2)
+ # gets exact match for id
+ query_obj = (
+ Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]")
+ .sort_by("score")
+ .return_fields("id", "score")
+ .paging(0, k)
+ .dialect(2)
+ )
+ else: # metadata only case, greater than or equal to metadata value
+ query_obj = (
+ Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]")
+ .sort_by("score")
+ .return_fields("id", "score")
+ .paging(0, k)
+ .dialect(2)
+ )
res = self.conn.ft(INDEX_NAME).search(query_obj, query_params)
# doc in res of format {'id': '9831', 'payload': None, 'score': '1.19209289551e-07'}
return [int(doc["id"]) for doc in res.docs]
diff --git a/vectordb_bench/backend/clients/test/cli.py b/vectordb_bench/backend/clients/test/cli.py
index f06f33492..e5cd4c78b 100644
--- a/vectordb_bench/backend/clients/test/cli.py
+++ b/vectordb_bench/backend/clients/test/cli.py
@@ -10,8 +10,7 @@
from ..test.config import TestConfig, TestIndexConfig
-class TestTypedDict(CommonTypedDict):
- ...
+class TestTypedDict(CommonTypedDict): ...
@cli.command()
diff --git a/vectordb_bench/backend/clients/test/config.py b/vectordb_bench/backend/clients/test/config.py
index 01a77e000..351d7bcac 100644
--- a/vectordb_bench/backend/clients/test/config.py
+++ b/vectordb_bench/backend/clients/test/config.py
@@ -1,6 +1,6 @@
-from pydantic import BaseModel, SecretStr
+from pydantic import BaseModel
-from ..api import DBCaseConfig, DBConfig, IndexType, MetricType
+from ..api import DBCaseConfig, DBConfig, MetricType
class TestConfig(DBConfig):
diff --git a/vectordb_bench/backend/clients/test/test.py b/vectordb_bench/backend/clients/test/test.py
index 78732eb1e..ee5a523f3 100644
--- a/vectordb_bench/backend/clients/test/test.py
+++ b/vectordb_bench/backend/clients/test/test.py
@@ -1,6 +1,7 @@
import logging
+from collections.abc import Generator
from contextlib import contextmanager
-from typing import Any, Generator, Optional, Tuple
+from typing import Any
from ..api import DBCaseConfig, VectorDB
@@ -43,11 +44,10 @@ def insert_embeddings(
embeddings: list[list[float]],
metadata: list[int],
**kwargs: Any,
- ) -> Tuple[int, Optional[Exception]]:
+ ) -> tuple[int, Exception | None]:
"""Insert embeddings into the database.
Should call self.init() first.
"""
- raise RuntimeError("Not implemented")
return len(metadata), None
def search_embedding(
@@ -58,5 +58,4 @@ def search_embedding(
timeout: int | None = None,
**kwargs: Any,
) -> list[int]:
- raise NotImplementedError
- return [i for i in range(k)]
+ return list(range(k))
diff --git a/vectordb_bench/backend/clients/weaviate_cloud/cli.py b/vectordb_bench/backend/clients/weaviate_cloud/cli.py
index b6f16011b..181898c74 100644
--- a/vectordb_bench/backend/clients/weaviate_cloud/cli.py
+++ b/vectordb_bench/backend/clients/weaviate_cloud/cli.py
@@ -14,7 +14,8 @@
class WeaviateTypedDict(CommonTypedDict):
api_key: Annotated[
- str, click.option("--api-key", type=str, help="Weaviate api key", required=True)
+ str,
+ click.option("--api-key", type=str, help="Weaviate api key", required=True),
]
url: Annotated[
str,
@@ -34,8 +35,6 @@ def Weaviate(**parameters: Unpack[WeaviateTypedDict]):
api_key=SecretStr(parameters["api_key"]),
url=SecretStr(parameters["url"]),
),
- db_case_config=WeaviateIndexConfig(
- ef=256, efConstruction=256, maxConnections=16
- ),
+ db_case_config=WeaviateIndexConfig(ef=256, efConstruction=256, maxConnections=16),
**parameters,
)
diff --git a/vectordb_bench/backend/clients/weaviate_cloud/config.py b/vectordb_bench/backend/clients/weaviate_cloud/config.py
index 8b2f3ab81..4c58167d4 100644
--- a/vectordb_bench/backend/clients/weaviate_cloud/config.py
+++ b/vectordb_bench/backend/clients/weaviate_cloud/config.py
@@ -1,6 +1,6 @@
from pydantic import BaseModel, SecretStr
-from ..api import DBConfig, DBCaseConfig, MetricType
+from ..api import DBCaseConfig, DBConfig, MetricType
class WeaviateConfig(DBConfig):
@@ -23,7 +23,7 @@ class WeaviateIndexConfig(BaseModel, DBCaseConfig):
def parse_metric(self) -> str:
if self.metric_type == MetricType.L2:
return "l2-squared"
- elif self.metric_type == MetricType.IP:
+ if self.metric_type == MetricType.IP:
return "dot"
return "cosine"
diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
index d8fde5f09..b42f70af1 100644
--- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
+++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py
@@ -1,13 +1,13 @@
"""Wrapper around the Weaviate vector database over VectorDB"""
import logging
-from typing import Iterable
+from collections.abc import Iterable
from contextlib import contextmanager
import weaviate
from weaviate.exceptions import WeaviateBaseError
-from ..api import VectorDB, DBCaseConfig
+from ..api import DBCaseConfig, VectorDB
log = logging.getLogger(__name__)
@@ -23,7 +23,13 @@ def __init__(
**kwargs,
):
"""Initialize wrapper around the weaviate vector database."""
- db_config.update({"auth_client_secret": weaviate.AuthApiKey(api_key=db_config.get("auth_client_secret"))})
+ db_config.update(
+ {
+ "auth_client_secret": weaviate.AuthApiKey(
+ api_key=db_config.get("auth_client_secret"),
+ ),
+ },
+ )
self.db_config = db_config
self.case_config = db_case_config
self.collection_name = collection_name
@@ -33,6 +39,7 @@ def __init__(
self._index_name = "vector_idx"
from weaviate import Client
+
client = Client(**db_config)
if drop_old:
try:
@@ -40,7 +47,7 @@ def __init__(
log.info(f"weaviate client drop_old collection: {self.collection_name}")
client.schema.delete_class(self.collection_name)
except WeaviateBaseError as e:
- log.warning(f"Failed to drop collection: {self.collection_name} error: {str(e)}")
+ log.warning(f"Failed to drop collection: {self.collection_name} error: {e!s}")
raise e from None
self._create_collection(client)
client = None
@@ -54,20 +61,23 @@ def init(self) -> None:
>>> self.search_embedding()
"""
from weaviate import Client
+
self.client = Client(**self.db_config)
yield
self.client = None
- del(self.client)
+ del self.client
def ready_to_load(self):
"""Should call insert first, do nothing"""
- pass
def optimize(self):
assert self.client.schema.exists(self.collection_name)
- self.client.schema.update_config(self.collection_name, {"vectorIndexConfig": self.case_config.search_param() } )
+ self.client.schema.update_config(
+ self.collection_name,
+ {"vectorIndexConfig": self.case_config.search_param()},
+ )
- def _create_collection(self, client):
+ def _create_collection(self, client: weaviate.Client) -> None:
if not client.schema.exists(self.collection_name):
log.info(f"Create collection: {self.collection_name}")
class_obj = {
@@ -78,13 +88,13 @@ def _create_collection(self, client):
"dataType": ["int"],
"name": self._scalar_field,
},
- ]
+ ],
}
class_obj["vectorIndexConfig"] = self.case_config.index_param()
try:
client.schema.create_class(class_obj)
except WeaviateBaseError as e:
- log.warning(f"Failed to create collection: {self.collection_name} error: {str(e)}")
+ log.warning(f"Failed to create collection: {self.collection_name} error: {e!s}")
raise e from None
def insert_embeddings(
@@ -102,15 +112,17 @@ def insert_embeddings(
batch.dynamic = True
res = []
for i in range(len(metadata)):
- res.append(batch.add_data_object(
- {self._scalar_field: metadata[i]},
- class_name=self.collection_name,
- vector=embeddings[i]
- ))
+ res.append(
+ batch.add_data_object(
+ {self._scalar_field: metadata[i]},
+ class_name=self.collection_name,
+ vector=embeddings[i],
+ ),
+ )
insert_count += 1
return (len(res), None)
except WeaviateBaseError as e:
- log.warning(f"Failed to insert data, error: {str(e)}")
+ log.warning(f"Failed to insert data, error: {e!s}")
return (insert_count, e)
def search_embedding(
@@ -125,12 +137,17 @@ def search_embedding(
"""
assert self.client.schema.exists(self.collection_name)
- query_obj = self.client.query.get(self.collection_name, [self._scalar_field]).with_additional("distance").with_near_vector({"vector": query}).with_limit(k)
+ query_obj = (
+ self.client.query.get(self.collection_name, [self._scalar_field])
+ .with_additional("distance")
+ .with_near_vector({"vector": query})
+ .with_limit(k)
+ )
if filters:
where_filter = {
"path": "key",
"operator": "GreaterThanEqual",
- "valueInt": filters.get('id')
+ "valueInt": filters.get("id"),
}
query_obj = query_obj.with_where(where_filter)
@@ -138,7 +155,4 @@ def search_embedding(
res = query_obj.do()
# Organize results.
- ret = [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]]
-
- return ret
-
+ return [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]]
diff --git a/vectordb_bench/backend/clients/zilliz_cloud/cli.py b/vectordb_bench/backend/clients/zilliz_cloud/cli.py
index 31618f4ec..810a4175e 100644
--- a/vectordb_bench/backend/clients/zilliz_cloud/cli.py
+++ b/vectordb_bench/backend/clients/zilliz_cloud/cli.py
@@ -1,33 +1,36 @@
+import os
from typing import Annotated, Unpack
import click
-import os
from pydantic import SecretStr
+from vectordb_bench.backend.clients import DB
from vectordb_bench.cli.cli import (
CommonTypedDict,
cli,
click_parameter_decorators_from_typed_dict,
run,
)
-from vectordb_bench.backend.clients import DB
class ZillizTypedDict(CommonTypedDict):
uri: Annotated[
- str, click.option("--uri", type=str, help="uri connection string", required=True)
+ str,
+ click.option("--uri", type=str, help="uri connection string", required=True),
]
user_name: Annotated[
- str, click.option("--user-name", type=str, help="Db username", required=True)
+ str,
+ click.option("--user-name", type=str, help="Db username", required=True),
]
password: Annotated[
str,
- click.option("--password",
- type=str,
- help="Zilliz password",
- default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""),
- show_default="$ZILLIZ_PASSWORD",
- ),
+ click.option(
+ "--password",
+ type=str,
+ help="Zilliz password",
+ default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""),
+ show_default="$ZILLIZ_PASSWORD",
+ ),
]
level: Annotated[
str,
@@ -38,7 +41,7 @@ class ZillizTypedDict(CommonTypedDict):
@cli.command()
@click_parameter_decorators_from_typed_dict(ZillizTypedDict)
def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]):
- from .config import ZillizCloudConfig, AutoIndexConfig
+ from .config import AutoIndexConfig, ZillizCloudConfig
run(
db=DB.ZillizCloud,
diff --git a/vectordb_bench/backend/clients/zilliz_cloud/config.py b/vectordb_bench/backend/clients/zilliz_cloud/config.py
index ee60b397f..9f113dbda 100644
--- a/vectordb_bench/backend/clients/zilliz_cloud/config.py
+++ b/vectordb_bench/backend/clients/zilliz_cloud/config.py
@@ -1,7 +1,7 @@
from pydantic import SecretStr
from ..api import DBCaseConfig, DBConfig
-from ..milvus.config import MilvusIndexConfig, IndexType
+from ..milvus.config import IndexType, MilvusIndexConfig
class ZillizCloudConfig(DBConfig):
@@ -33,7 +33,5 @@ def search_param(self) -> dict:
"metric_type": self.parse_metric(),
"params": {
"level": self.level,
- }
+ },
}
-
-
diff --git a/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py b/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py
index 36f7fb204..4ce15545a 100644
--- a/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py
+++ b/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py
@@ -1,7 +1,7 @@
"""Wrapper around the ZillizCloud vector database over VectorDB"""
-from ..milvus.milvus import Milvus
from ..api import DBCaseConfig
+from ..milvus.milvus import Milvus
class ZillizCloud(Milvus):
diff --git a/vectordb_bench/backend/data_source.py b/vectordb_bench/backend/data_source.py
index 9e2f172b4..b98dc7d7a 100644
--- a/vectordb_bench/backend/data_source.py
+++ b/vectordb_bench/backend/data_source.py
@@ -1,12 +1,12 @@
import logging
import pathlib
import typing
+from abc import ABC, abstractmethod
from enum import Enum
+
from tqdm import tqdm
-import os
-from abc import ABC, abstractmethod
-from .. import config
+from vectordb_bench import config
logging.getLogger("s3fs").setLevel(logging.CRITICAL)
@@ -14,6 +14,7 @@
DatasetReader = typing.TypeVar("DatasetReader")
+
class DatasetSource(Enum):
S3 = "S3"
AliyunOSS = "AliyunOSS"
@@ -25,6 +26,8 @@ def reader(self) -> DatasetReader:
if self == DatasetSource.AliyunOSS:
return AliyunOSSReader()
+ return None
+
class DatasetReader(ABC):
source: DatasetSource
@@ -39,7 +42,6 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
files(list[str]): all filenames of the dataset
local_ds_root(pathlib.Path): whether to write the remote data.
"""
- pass
@abstractmethod
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
@@ -52,15 +54,18 @@ class AliyunOSSReader(DatasetReader):
def __init__(self):
import oss2
+
self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True)
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
info = self.bucket.get_object_meta(remote.as_posix())
# check size equal
- remote_size, local_size = info.content_length, os.path.getsize(local)
+ remote_size, local_size = info.content_length, local.stat().st_size
if remote_size != local_size:
- log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
+ log.info(
+ f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]",
+ )
return False
return True
@@ -70,7 +75,13 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
if not local_ds_root.exists():
log.info(f"local dataset root path not exist, creating it: {local_ds_root}")
local_ds_root.mkdir(parents=True)
- downloads = [(pathlib.PurePosixPath("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files]
+ downloads = [
+ (
+ pathlib.PurePosixPath("benchmark", dataset, f),
+ local_ds_root.joinpath(f),
+ )
+ for f in files
+ ]
else:
for file in files:
@@ -78,7 +89,9 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
local_file = local_ds_root.joinpath(file)
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
- log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
+ log.info(
+ f"local file: {local_file} not match with remote: {remote_file}; add to downloading list",
+ )
downloads.append((remote_file, local_file))
if len(downloads) == 0:
@@ -92,17 +105,14 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
-
class AwsS3Reader(DatasetReader):
source: DatasetSource = DatasetSource.S3
remote_root: str = config.AWS_S3_URL
def __init__(self):
import s3fs
- self.fs = s3fs.S3FileSystem(
- anon=True,
- client_kwargs={'region_name': 'us-west-2'}
- )
+
+ self.fs = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-west-2"})
def ls_all(self, dataset: str):
dataset_root_dir = pathlib.Path(self.remote_root, dataset)
@@ -112,7 +122,6 @@ def ls_all(self, dataset: str):
log.info(n)
return names
-
def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
downloads = []
if not local_ds_root.exists():
@@ -126,7 +135,9 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
local_file = local_ds_root.joinpath(file)
if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)):
- log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list")
+ log.info(
+ f"local file: {local_file} not match with remote: {remote_file}; add to downloading list",
+ )
downloads.append(remote_file)
if len(downloads) == 0:
@@ -139,15 +150,16 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path):
log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}")
-
def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool:
# info() uses ls() inside, maybe we only need to ls once
info = self.fs.info(remote)
# check size equal
- remote_size, local_size = info.get("size"), os.path.getsize(local)
+ remote_size, local_size = info.get("size"), local.stat().st_size
if remote_size != local_size:
- log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]")
+ log.info(
+ f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]",
+ )
return False
return True
diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py
index d62d96684..62700b0fa 100644
--- a/vectordb_bench/backend/dataset.py
+++ b/vectordb_bench/backend/dataset.py
@@ -4,25 +4,30 @@
>>> Dataset.Cohere.get(100_000)
"""
-from collections import namedtuple
import logging
import pathlib
+import typing
from enum import Enum
+
import pandas as pd
-from pydantic import validator, PrivateAttr
import polars as pl
from pyarrow.parquet import ParquetFile
+from pydantic import PrivateAttr, validator
+
+from vectordb_bench import config
+from vectordb_bench.base import BaseModel
-from ..base import BaseModel
-from .. import config
-from ..backend.clients import MetricType
from . import utils
-from .data_source import DatasetSource, DatasetReader
+from .clients import MetricType
+from .data_source import DatasetReader, DatasetSource
log = logging.getLogger(__name__)
-SizeLabel = namedtuple('SizeLabel', ['size', 'label', 'file_count'])
+class SizeLabel(typing.NamedTuple):
+ size: int
+ label: str
+ file_count: int
class BaseDataset(BaseModel):
@@ -33,12 +38,13 @@ class BaseDataset(BaseModel):
use_shuffled: bool
with_gt: bool = False
_size_label: dict[int, SizeLabel] = PrivateAttr()
- isCustom: bool = False
+ is_custom: bool = False
@validator("size")
- def verify_size(cls, v):
+ def verify_size(cls, v: int):
if v not in cls._size_label:
- raise ValueError(f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}")
+ msg = f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}"
+ raise ValueError(msg)
return v
@property
@@ -53,13 +59,14 @@ def dir_name(self) -> str:
def file_count(self) -> int:
return self._size_label.get(self.size).file_count
+
class CustomDataset(BaseDataset):
dir: str
file_num: int
- isCustom: bool = True
+ is_custom: bool = True
@validator("size")
- def verify_size(cls, v):
+ def verify_size(cls, v: int):
return v
@property
@@ -102,7 +109,7 @@ class Cohere(BaseDataset):
dim: int = 768
metric_type: MetricType = MetricType.COSINE
use_shuffled: bool = config.USE_SHUFFLED_DATA
- with_gt: bool = True,
+    with_gt: bool = True
_size_label: dict = {
100_000: SizeLabel(100_000, "SMALL", 1),
1_000_000: SizeLabel(1_000_000, "MEDIUM", 1),
@@ -124,7 +131,11 @@ class SIFT(BaseDataset):
metric_type: MetricType = MetricType.L2
use_shuffled: bool = False
_size_label: dict = {
- 500_000: SizeLabel(500_000, "SMALL", 1,),
+ 500_000: SizeLabel(
+ 500_000,
+ "SMALL",
+ 1,
+ ),
5_000_000: SizeLabel(5_000_000, "MEDIUM", 1),
# 50_000_000: SizeLabel(50_000_000, "LARGE", 50),
}
@@ -135,7 +146,7 @@ class OpenAI(BaseDataset):
dim: int = 1536
metric_type: MetricType = MetricType.COSINE
use_shuffled: bool = config.USE_SHUFFLED_DATA
- with_gt: bool = True,
+    with_gt: bool = True
_size_label: dict = {
50_000: SizeLabel(50_000, "SMALL", 1),
500_000: SizeLabel(500_000, "MEDIUM", 1),
@@ -153,13 +164,14 @@ class DatasetManager(BaseModel):
>>> for data in cohere:
>>> print(data.columns)
"""
- data: BaseDataset
+
+ data: BaseDataset
test_data: pd.DataFrame | None = None
gt_data: pd.DataFrame | None = None
- train_files : list[str] = []
+ train_files: list[str] = []
reader: DatasetReader | None = None
- def __eq__(self, obj):
+    def __eq__(self, obj: typing.Any):
if isinstance(obj, DatasetManager):
return self.data.name == obj.data.name and self.data.label == obj.data.label
return False
@@ -169,22 +181,27 @@ def set_reader(self, reader: DatasetReader):
@property
def data_dir(self) -> pathlib.Path:
- """ data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}
+ """data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname}
Examples:
>>> sift_s = Dataset.SIFT.manager(500_000)
>>> sift_s.relative_path
'/tmp/vectordb_bench/dataset/sift/sift_small_500k/'
"""
- return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower())
+ return pathlib.Path(
+ config.DATASET_LOCAL_DIR,
+ self.data.name.lower(),
+ self.data.dir_name.lower(),
+ )
def __iter__(self):
return DataSetIterator(self)
# TODO passing use_shuffle from outside
- def prepare(self,
- source: DatasetSource=DatasetSource.S3,
- filters: int | float | str | None = None,
+ def prepare(
+ self,
+ source: DatasetSource = DatasetSource.S3,
+ filters: float | str | None = None,
) -> bool:
"""Download the dataset from DatasetSource
url = f"{source}/{self.data.dir_name}"
@@ -208,7 +225,7 @@ def prepare(self,
gt_file, test_file = utils.compose_gt_file(filters), "test.parquet"
all_files.extend([gt_file, test_file])
- if not self.data.isCustom:
+ if not self.data.is_custom:
source.reader().read(
dataset=self.data.dir_name.lower(),
files=all_files,
@@ -220,7 +237,7 @@ def prepare(self,
self.gt_data = self._read_file(gt_file)
prefix = "shuffle_train" if use_shuffled else "train"
- self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')])
+ self.train_files = sorted([f.name for f in self.data_dir.glob(f"{prefix}*.parquet")])
log.debug(f"{self.data.name}: available train files {self.train_files}")
return True
@@ -241,7 +258,7 @@ def __init__(self, dataset: DatasetManager):
self._ds = dataset
self._idx = 0 # file number
self._cur = None
- self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file
+ self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file
def __iter__(self):
return self
@@ -250,7 +267,9 @@ def _get_iter(self, file_name: str):
p = pathlib.Path(self._ds.data_dir, file_name)
log.info(f"Get iterator for {p.name}")
if not p.exists():
- raise IndexError(f"No such file {p}")
+ msg = f"No such file: {p}"
+ log.warning(msg)
+ raise IndexError(msg)
return ParquetFile(p, memory_map=True, pre_buffer=True).iter_batches(config.NUM_PER_BATCH)
def __next__(self) -> pd.DataFrame:
@@ -281,6 +300,7 @@ class Dataset(Enum):
>>> Dataset.COHERE.manager(100_000)
>>> Dataset.COHERE.get(100_000)
"""
+
LAION = LAION
GIST = GIST
COHERE = Cohere
diff --git a/vectordb_bench/backend/result_collector.py b/vectordb_bench/backend/result_collector.py
index d4c073c1a..a4119a5fd 100644
--- a/vectordb_bench/backend/result_collector.py
+++ b/vectordb_bench/backend/result_collector.py
@@ -1,7 +1,7 @@
+import logging
import pathlib
-from ..models import TestResult
-import logging
+from vectordb_bench.models import TestResult
log = logging.getLogger(__name__)
@@ -14,7 +14,6 @@ def collect(cls, result_dir: pathlib.Path) -> list[TestResult]:
if not result_dir.exists() or len(list(result_dir.rglob(reg))) == 0:
return []
-
for json_file in result_dir.rglob(reg):
file_result = TestResult.read_file(json_file, trans_unit=True)
diff --git a/vectordb_bench/backend/runner/__init__.py b/vectordb_bench/backend/runner/__init__.py
index 77bb25d67..b83df6f99 100644
--- a/vectordb_bench/backend/runner/__init__.py
+++ b/vectordb_bench/backend/runner/__init__.py
@@ -1,12 +1,10 @@
from .mp_runner import (
MultiProcessingSearchRunner,
)
-
-from .serial_runner import SerialSearchRunner, SerialInsertRunner
-
+from .serial_runner import SerialInsertRunner, SerialSearchRunner
__all__ = [
- 'MultiProcessingSearchRunner',
- 'SerialSearchRunner',
- 'SerialInsertRunner',
+ "MultiProcessingSearchRunner",
+ "SerialInsertRunner",
+ "SerialSearchRunner",
]
diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py
index 8f35dcd8d..5b69b5481 100644
--- a/vectordb_bench/backend/runner/mp_runner.py
+++ b/vectordb_bench/backend/runner/mp_runner.py
@@ -1,27 +1,29 @@
-import time
-import traceback
import concurrent
+import logging
import multiprocessing as mp
import random
-import logging
-from typing import Iterable
+import time
+import traceback
+from collections.abc import Iterable
+
import numpy as np
-from ..clients import api
-from ... import config
+from ... import config
+from ..clients import api
NUM_PER_BATCH = config.NUM_PER_BATCH
log = logging.getLogger(__name__)
class MultiProcessingSearchRunner:
- """ multiprocessing search runner
+ """multiprocessing search runner
Args:
k(int): search topk, default to 100
concurrency(Iterable): concurrencies, default [1, 5, 10, 15, 20, 25, 30, 35]
duration(int): duration for each concurency, default to 30s
"""
+
def __init__(
self,
db: api.VectorDB,
@@ -40,7 +42,12 @@ def __init__(
self.test_data = test_data
log.debug(f"test dataset columns: {len(test_data)}")
- def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> tuple[int, float]:
+ def search(
+ self,
+ test_data: list[list[float]],
+ q: mp.Queue,
+ cond: mp.Condition,
+ ) -> tuple[int, float]:
# sync all process
q.put(1)
with cond:
@@ -71,24 +78,27 @@ def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition)
idx = idx + 1 if idx < num - 1 else 0
if count % 500 == 0:
- log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")
+ log.debug(
+ f"({mp.current_process().name:16}) ",
+ f"search_count: {count}, latest_latency={time.perf_counter()-s}",
+ )
total_dur = round(time.perf_counter() - start_time, 4)
log.info(
f"{mp.current_process().name:16} search {self.duration}s: "
- f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
- )
+ f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}",
+ )
return (count, total_dur, latencies)
@staticmethod
def get_mp_context():
mp_start_method = "spawn"
- log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}")
+ log.debug(
+ f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}",
+ )
return mp.get_context(mp_start_method)
-
-
def _run_all_concurrencies_mem_efficient(self):
max_qps = 0
conc_num_list = []
@@ -99,8 +109,13 @@ def _run_all_concurrencies_mem_efficient(self):
for conc in self.concurrencies:
with mp.Manager() as m:
q, cond = m.Queue(), m.Condition()
- with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
- log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}")
+ with concurrent.futures.ProcessPoolExecutor(
+ mp_context=self.get_mp_context(),
+ max_workers=conc,
+ ) as executor:
+ log.info(
+ f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}",
+ )
future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)]
# Sync all processes
while q.qsize() < conc:
@@ -109,7 +124,9 @@ def _run_all_concurrencies_mem_efficient(self):
with cond:
cond.notify_all()
- log.info(f"Syncing all process and start concurrency search, concurrency={conc}")
+ log.info(
+ f"Syncing all process and start concurrency search, concurrency={conc}",
+ )
start = time.perf_counter()
all_count = sum([r.result()[0] for r in future_iter])
@@ -123,13 +140,19 @@ def _run_all_concurrencies_mem_efficient(self):
conc_qps_list.append(qps)
conc_latency_p99_list.append(latency_p99)
conc_latency_avg_list.append(latency_avg)
- log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
+ log.info(
+ f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}",
+ )
if qps > max_qps:
max_qps = qps
- log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
+ log.info(
+ f"Update largest qps with concurrency {conc}: current max_qps={max_qps}",
+ )
except Exception as e:
- log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
+ log.warning(
+ f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}",
+ )
traceback.print_exc()
# No results available, raise exception
@@ -139,7 +162,13 @@ def _run_all_concurrencies_mem_efficient(self):
finally:
self.stop()
- return max_qps, conc_num_list, conc_qps_list, conc_latency_p99_list, conc_latency_avg_list
+ return (
+ max_qps,
+ conc_num_list,
+ conc_qps_list,
+ conc_latency_p99_list,
+ conc_latency_avg_list,
+ )
def run(self) -> float:
"""
@@ -160,9 +189,16 @@ def _run_by_dur(self, duration: int) -> float:
for conc in self.concurrencies:
with mp.Manager() as m:
q, cond = m.Queue(), m.Condition()
- with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor:
- log.info(f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}")
- future_iter = [executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)]
+ with concurrent.futures.ProcessPoolExecutor(
+ mp_context=self.get_mp_context(),
+ max_workers=conc,
+ ) as executor:
+ log.info(
+ f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}",
+ )
+ future_iter = [
+ executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)
+ ]
# Sync all processes
while q.qsize() < conc:
sleep_t = conc if conc < 10 else 10
@@ -170,20 +206,28 @@ def _run_by_dur(self, duration: int) -> float:
with cond:
cond.notify_all()
- log.info(f"Syncing all process and start concurrency search, concurrency={conc}")
+ log.info(
+ f"Syncing all process and start concurrency search, concurrency={conc}",
+ )
start = time.perf_counter()
all_count = sum([r.result() for r in future_iter])
cost = time.perf_counter() - start
qps = round(all_count / cost, 4)
- log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}")
+ log.info(
+ f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}",
+ )
if qps > max_qps:
max_qps = qps
- log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}")
+ log.info(
+ f"Update largest qps with concurrency {conc}: current max_qps={max_qps}",
+ )
except Exception as e:
- log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}")
+ log.warning(
+ f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}",
+ )
traceback.print_exc()
# No results available, raise exception
@@ -195,8 +239,13 @@ def _run_by_dur(self, duration: int) -> float:
return max_qps
-
- def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> int:
+ def search_by_dur(
+ self,
+ dur: int,
+ test_data: list[list[float]],
+ q: mp.Queue,
+ cond: mp.Condition,
+ ) -> int:
# sync all process
q.put(1)
with cond:
@@ -225,13 +274,15 @@ def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, con
idx = idx + 1 if idx < num - 1 else 0
if count % 500 == 0:
- log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}")
+ log.debug(
+ f"({mp.current_process().name:16}) search_count: {count}, ",
+ f"latest_latency={time.perf_counter()-s}",
+ )
total_dur = round(time.perf_counter() - start_time, 4)
log.debug(
f"{mp.current_process().name:16} search {self.duration}s: "
- f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}"
- )
+ f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}",
+ )
return count
-
diff --git a/vectordb_bench/backend/runner/rate_runner.py b/vectordb_bench/backend/runner/rate_runner.py
index d77c0fd15..0145af4ce 100644
--- a/vectordb_bench/backend/runner/rate_runner.py
+++ b/vectordb_bench/backend/runner/rate_runner.py
@@ -1,36 +1,36 @@
+import concurrent
import logging
+import multiprocessing as mp
import time
-import concurrent
from concurrent.futures import ThreadPoolExecutor
-import multiprocessing as mp
-
+from vectordb_bench import config
from vectordb_bench.backend.clients import api
from vectordb_bench.backend.dataset import DataSetIterator
from vectordb_bench.backend.utils import time_it
-from vectordb_bench import config
from .util import get_data
+
log = logging.getLogger(__name__)
class RatedMultiThreadingInsertRunner:
def __init__(
self,
- rate: int, # numRows per second
+ rate: int, # numRows per second
db: api.VectorDB,
dataset_iter: DataSetIterator,
normalize: bool = False,
timeout: float | None = None,
):
- self.timeout = timeout if isinstance(timeout, (int, float)) else None
+ self.timeout = timeout if isinstance(timeout, int | float) else None
self.dataset = dataset_iter
self.db = db
self.normalize = normalize
self.insert_rate = rate
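+ # batches submitted per second: NUM_PER_BATCH rows per batch, integer division floors any remainder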
self.batch_rate = rate // config.NUM_PER_BATCH
- def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]):
+ def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]):
db.insert_embeddings(emb, metadata)
@time_it
@@ -43,7 +43,9 @@ def submit_by_rate() -> bool:
rate = self.batch_rate
for data in self.dataset:
emb, metadata = get_data(data, self.normalize)
- executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata))
+ executing_futures.append(
+ executor.submit(self.send_insert_task, self.db, emb, metadata),
+ )
rate -= 1
if rate == 0:
@@ -66,19 +68,26 @@ def submit_by_rate() -> bool:
done, not_done = concurrent.futures.wait(
executing_futures,
timeout=wait_interval,
- return_when=concurrent.futures.FIRST_EXCEPTION)
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
if len(not_done) > 0:
- log.warning(f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round")
+ log.warning(
+ f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] ",
+ f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round",
+ )
executing_futures = list(not_done)
else:
- log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}")
+ log.debug(
+ f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} ",
+ f"task in 1s, wait_interval={wait_interval:.2f}",
+ )
executing_futures = []
except Exception as e:
- log.warn(f"task error, terminating, err={e}")
- q.put(None, block=True)
- executor.shutdown(wait=True, cancel_futures=True)
- raise e
+ log.warning(f"task error, terminating, err={e}")
+ q.put(None, block=True)
+ executor.shutdown(wait=True, cancel_futures=True)
+ raise e from e
dur = time.perf_counter() - start_time
if dur < 1:
@@ -87,10 +96,12 @@ def submit_by_rate() -> bool:
# wait for all tasks in executing_futures to complete
if len(executing_futures) > 0:
try:
- done, _ = concurrent.futures.wait(executing_futures,
- return_when=concurrent.futures.FIRST_EXCEPTION)
+ done, _ = concurrent.futures.wait(
+ executing_futures,
+ return_when=concurrent.futures.FIRST_EXCEPTION,
+ )
except Exception as e:
- log.warn(f"task error, terminating, err={e}")
+ log.warning(f"task error, terminating, err={e}")
q.put(None, block=True)
executor.shutdown(wait=True, cancel_futures=True)
- raise e
+ raise e from e
diff --git a/vectordb_bench/backend/runner/read_write_runner.py b/vectordb_bench/backend/runner/read_write_runner.py
index fd425b227..e916f45d6 100644
--- a/vectordb_bench/backend/runner/read_write_runner.py
+++ b/vectordb_bench/backend/runner/read_write_runner.py
@@ -1,16 +1,18 @@
+import concurrent
import logging
-from typing import Iterable
+import math
import multiprocessing as mp
-import concurrent
+from collections.abc import Iterable
+
import numpy as np
-import math
-from .mp_runner import MultiProcessingSearchRunner
-from .serial_runner import SerialSearchRunner
-from .rate_runner import RatedMultiThreadingInsertRunner
from vectordb_bench.backend.clients import api
from vectordb_bench.backend.dataset import DatasetManager
+from .mp_runner import MultiProcessingSearchRunner
+from .rate_runner import RatedMultiThreadingInsertRunner
+from .serial_runner import SerialSearchRunner
+
log = logging.getLogger(__name__)
@@ -24,8 +26,14 @@ def __init__(
k: int = 100,
filters: dict | None = None,
concurrencies: Iterable[int] = (1, 15, 50),
- search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9), # search from insert portion, 0.0 means search from the start
- read_dur_after_write: int = 300, # seconds, search duration when insertion is done
+ search_stage: Iterable[float] = (
+ 0.5,
+ 0.6,
+ 0.7,
+ 0.8,
+ 0.9,
+ ), # search from insert portion, 0.0 means search from the start
+ read_dur_after_write: int = 300, # seconds, search duration when insertion is done
timeout: float | None = None,
):
self.insert_rate = insert_rate
@@ -36,7 +44,10 @@ def __init__(
self.search_stage = sorted(search_stage)
self.read_dur_after_write = read_dur_after_write
- log.info(f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, stage_search_dur={read_dur_after_write}")
+ log.info(
+ f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, ",
+ f"stage_search_dur={read_dur_after_write}",
+ )
test_emb = np.stack(dataset.test_data["emb"])
if normalize:
@@ -76,8 +87,13 @@ def run_search(self):
log.info("Search after write - Serial search start")
res, ssearch_dur = self.serial_search_runner.run()
recall, ndcg, p99_latency = res
- log.info(f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
- log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}")
+ log.info(
+ f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, ",
+ f"dur={ssearch_dur:.4f}",
+ )
+ log.info(
+ f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}",
+ )
max_qps = self.run_by_dur(self.read_dur_after_write)
log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}")
@@ -86,7 +102,10 @@ def run_search(self):
def run_read_write(self):
with mp.Manager() as m:
q = m.Queue()
- with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor:
+ with concurrent.futures.ProcessPoolExecutor(
+ mp_context=mp.get_context("spawn"),
+ max_workers=2,
+ ) as executor:
read_write_futures = []
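+ # one worker drives the rate-limited inserts, the other consumes its queue signals to run the staged searches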
read_write_futures.append(executor.submit(self.run_with_rate, q))
read_write_futures.append(executor.submit(self.run_search_by_sig, q))
@@ -107,10 +126,10 @@ def run_read_write(self):
except Exception as e:
log.warning(f"Read and write error: {e}")
executor.shutdown(wait=True, cancel_futures=True)
- raise e
+ raise e from e
log.info("Concurrent read write all done")
- def run_search_by_sig(self, q):
+ def run_search_by_sig(self, q: mp.Queue):
"""
Args:
q: multiprocessing queue
@@ -122,15 +141,14 @@ def run_search_by_sig(self, q):
total_batch = math.ceil(self.data_volume / self.insert_rate)
recall, ndcg, p99_latency = None, None, None
- def wait_next_target(start, target_batch) -> bool:
+ def wait_next_target(start: int, target_batch: int) -> bool:
"""Return False when receive True or None"""
while start < target_batch:
sig = q.get(block=True)
if sig is None or sig is True:
return False
- else:
- start += 1
+ start += 1
return True
for idx, stage in enumerate(self.search_stage):
@@ -139,19 +157,24 @@ def wait_next_target(start, target_batch) -> bool:
got = wait_next_target(start_batch, target_batch)
if got is False:
- log.warning(f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}")
- return
+ log.warning(
+ f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}",
+ )
+ return None
log.info(f"Insert {perc}% done, total batch={total_batch}")
log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start")
res, ssearch_dur = self.serial_search_runner.run()
recall, ndcg, p99_latency = res
- log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}")
+ log.info(
+ f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ",
+ f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}",
+ )
# Search duration for non-last search stage is carefully calculated.
# If duration for each concurrency is less than 30s, runner will raise error.
if idx < len(self.search_stage) - 1:
- total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
+ total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate
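+ # time budget until the next stage: rows to insert between the two stages divided by the insert rate (seconds)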
csearch_dur = total_dur_between_stages - ssearch_dur
# Try to leave room for init process executors
@@ -159,14 +182,19 @@ def wait_next_target(start, target_batch) -> bool:
each_conc_search_dur = csearch_dur / len(self.concurrencies)
if each_conc_search_dur < 30:
- warning_msg = f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}."
+ warning_msg = (
+ f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, ",
+ f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}.",
+ )
log.warning(warning_msg)
# The last stage
else:
each_conc_search_dur = 60
- log.info(f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}")
+ log.info(
+ f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}",
+ )
max_qps = self.run_by_dur(each_conc_search_dur)
result.append((perc, max_qps, recall, ndcg, p99_latency))
diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py
index 13270e21a..7eb59432b 100644
--- a/vectordb_bench/backend/runner/serial_runner.py
+++ b/vectordb_bench/backend/runner/serial_runner.py
@@ -1,20 +1,21 @@
-import time
-import logging
-import traceback
import concurrent
-import multiprocessing as mp
+import logging
import math
-import psutil
+import multiprocessing as mp
+import time
+import traceback
import numpy as np
import pandas as pd
+import psutil
-from ..clients import api
+from vectordb_bench.backend.dataset import DatasetManager
+
+from ... import config
from ...metric import calc_ndcg, calc_recall, get_ideal_dcg
from ...models import LoadTimeoutError, PerformanceTimeoutError
from .. import utils
-from ... import config
-from vectordb_bench.backend.dataset import DatasetManager
+from ..clients import api
NUM_PER_BATCH = config.NUM_PER_BATCH
LOAD_MAX_TRY_COUNT = 10
@@ -22,9 +23,16 @@
log = logging.getLogger(__name__)
+
class SerialInsertRunner:
- def __init__(self, db: api.VectorDB, dataset: DatasetManager, normalize: bool, timeout: float | None = None):
- self.timeout = timeout if isinstance(timeout, (int, float)) else None
+ def __init__(
+ self,
+ db: api.VectorDB,
+ dataset: DatasetManager,
+ normalize: bool,
+ timeout: float | None = None,
+ ):
+ self.timeout = timeout if isinstance(timeout, int | float) else None
self.dataset = dataset
self.db = db
self.normalize = normalize
@@ -32,18 +40,20 @@ def __init__(self, db: api.VectorDB, dataset: DatasetManager, normalize: bool, t
def task(self) -> int:
count = 0
with self.db.init():
- log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}")
+ log.info(
+ f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}",
+ )
start = time.perf_counter()
for data_df in self.dataset:
- all_metadata = data_df['id'].tolist()
+ all_metadata = data_df["id"].tolist()
- emb_np = np.stack(data_df['emb'])
+ emb_np = np.stack(data_df["emb"])
if self.normalize:
log.debug("normalize the 100k train data")
all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
else:
all_embeddings = emb_np.tolist()
- del(emb_np)
+ del emb_np
log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}")
insert_count, error = self.db.insert_embeddings(
@@ -56,30 +66,41 @@ def task(self) -> int:
assert insert_count == len(all_metadata)
count += insert_count
if count % 100_000 == 0:
- log.info(f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB")
+ log.info(
+ f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB",
+ )
- log.info(f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, dur={time.perf_counter()-start}")
+ log.info(
+ f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, ",
+ f"dur={time.perf_counter()-start}",
+ )
return count
- def endless_insert_data(self, all_embeddings, all_metadata, left_id: int = 0) -> int:
+ def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: int = 0) -> int:
with self.db.init():
# unique id for endlessness insertion
- all_metadata = [i+left_id for i in all_metadata]
+ all_metadata = [i + left_id for i in all_metadata]
- NUM_BATCHES = math.ceil(len(all_embeddings)/NUM_PER_BATCH)
- log.info(f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
+ num_batches = math.ceil(len(all_embeddings) / NUM_PER_BATCH)
+ log.info(
+ f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} ",
+ f"embeddings in batch {NUM_PER_BATCH}",
+ )
count = 0
- for batch_id in range(NUM_BATCHES):
+ for batch_id in range(num_batches):
retry_count = 0
already_insert_count = 0
- metadata = all_metadata[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH]
- embeddings = all_embeddings[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH]
+ metadata = all_metadata[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH]
+ embeddings = all_embeddings[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH]
- log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Start inserting {len(metadata)} embeddings")
+ log.debug(
+ f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ",
+ f"Start inserting {len(metadata)} embeddings",
+ )
while retry_count < LOAD_MAX_TRY_COUNT:
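+ # on retry, resume from the rows that already made it in instead of re-sending the whole batch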
insert_count, error = self.db.insert_embeddings(
- embeddings=embeddings[already_insert_count :],
- metadata=metadata[already_insert_count :],
+ embeddings=embeddings[already_insert_count:],
+ metadata=metadata[already_insert_count:],
)
already_insert_count += insert_count
if error is not None:
@@ -91,17 +112,26 @@ def endless_insert_data(self, all_embeddings, all_metadata, left_id: int = 0) ->
raise error
else:
break
- log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Finish inserting {len(metadata)} embeddings")
+ log.debug(
+ f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ",
+ f"Finish inserting {len(metadata)} embeddings",
+ )
assert already_insert_count == len(metadata)
count += already_insert_count
- log.info(f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}")
+ log.info(
+ f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in ",
+ f"batch {NUM_PER_BATCH}",
+ )
return count
@utils.time_it
def _insert_all_batches(self) -> int:
"""Performance case only"""
- with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context('spawn'), max_workers=1) as executor:
+ with concurrent.futures.ProcessPoolExecutor(
+ mp_context=mp.get_context("spawn"),
+ max_workers=1,
+ ) as executor:
future = executor.submit(self.task)
try:
count = future.result(timeout=self.timeout)
@@ -121,8 +151,11 @@ def run_endlessness(self) -> int:
"""run forever util DB raises exception or crash"""
# datasets for load tests are quite small, can fit into memory
# only 1 file
- data_df = [data_df for data_df in self.dataset][0]
- all_embeddings, all_metadata = np.stack(data_df["emb"]).tolist(), data_df['id'].tolist()
+ data_df = next(iter(self.dataset))
+ all_embeddings, all_metadata = (
+ np.stack(data_df["emb"]).tolist(),
+ data_df["id"].tolist(),
+ )
start_time = time.perf_counter()
max_load_count, times = 0, 0
@@ -130,18 +163,26 @@ def run_endlessness(self) -> int:
with self.db.init():
self.db.ready_to_load()
while time.perf_counter() - start_time < self.timeout:
- count = self.endless_insert_data(all_embeddings, all_metadata, left_id=max_load_count)
+ count = self.endless_insert_data(
+ all_embeddings,
+ all_metadata,
+ left_id=max_load_count,
+ )
max_load_count += count
times += 1
- log.info(f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, {max_load_count}")
+ log.info(
+ f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, ",
+ f"{max_load_count}",
+ )
except Exception as e:
- log.info(f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, {max_load_count}, err={e}")
+ log.info(
+ f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, ",
+ f"{max_load_count}, err={e}",
+ )
traceback.print_exc()
return max_load_count
else:
- msg = f"capacity case load timeout in {self.timeout}s"
- log.info(msg)
- raise LoadTimeoutError(msg)
+ raise LoadTimeoutError(self.timeout)
def run(self) -> int:
count, dur = self._insert_all_batches()
@@ -168,7 +209,9 @@ def __init__(
self.ground_truth = ground_truth
def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
- log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency")
+ log.info(
+ f"{mp.current_process().name:14} start search the entire test_data to get recall and latency",
+ )
with self.db.init():
test_data, ground_truth = args
ideal_dcg = get_ideal_dcg(self.k)
@@ -193,13 +236,15 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
latencies.append(time.perf_counter() - s)
- gt = ground_truth['neighbors_id'][idx]
- recalls.append(calc_recall(self.k, gt[:self.k], results))
- ndcgs.append(calc_ndcg(gt[:self.k], results, ideal_dcg))
-
+ gt = ground_truth["neighbors_id"][idx]
+ recalls.append(calc_recall(self.k, gt[: self.k], results))
+ ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg))
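+ # recall and NDCG are scored per query against its top-k ground-truth neighbors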
if len(latencies) % 100 == 0:
- log.debug(f"({mp.current_process().name:14}) search_count={len(latencies):3}, latest_latency={latencies[-1]}, latest recall={recalls[-1]}")
+ log.debug(
+ f"({mp.current_process().name:14}) search_count={len(latencies):3}, ",
+ f"latest_latency={latencies[-1]}, latest recall={recalls[-1]}",
+ )
avg_latency = round(np.mean(latencies), 4)
avg_recall = round(np.mean(recalls), 4)
@@ -213,16 +258,14 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]:
f"avg_recall={avg_recall}, "
f"avg_ndcg={avg_ndcg},"
f"avg_latency={avg_latency}, "
- f"p99={p99}"
- )
+ f"p99={p99}",
+ )
return (avg_recall, avg_ndcg, p99)
-
def _run_in_subprocess(self) -> tuple[float, float]:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
future = executor.submit(self.search, (self.test_data, self.ground_truth))
- result = future.result()
- return result
+ return future.result()
@utils.time_it
def run(self) -> tuple[float, float, float]:
diff --git a/vectordb_bench/backend/runner/util.py b/vectordb_bench/backend/runner/util.py
index ba1888167..50e91b04b 100644
--- a/vectordb_bench/backend/runner/util.py
+++ b/vectordb_bench/backend/runner/util.py
@@ -1,13 +1,14 @@
import logging
-from pandas import DataFrame
import numpy as np
+from pandas import DataFrame
log = logging.getLogger(__name__)
+
def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], list[str]]:
- all_metadata = data_df['id'].tolist()
- emb_np = np.stack(data_df['emb'])
+ all_metadata = data_df["id"].tolist()
+ emb_np = np.stack(data_df["emb"])
if normalize:
log.debug("normalize the 100k train data")
all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist()
diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py
index 568152ae0..e24d74f03 100644
--- a/vectordb_bench/backend/task_runner.py
+++ b/vectordb_bench/backend/task_runner.py
@@ -1,24 +1,20 @@
+import concurrent
import logging
-import psutil
import traceback
-import concurrent
-import numpy as np
from enum import Enum, auto
-from . import utils
-from .cases import Case, CaseLabel
-from ..base import BaseModel
-from ..models import TaskConfig, PerformanceTimeoutError, TaskStage
+import numpy as np
+import psutil
-from .clients import (
- api,
- MetricType
-)
-from ..metric import Metric
-from .runner import MultiProcessingSearchRunner
-from .runner import SerialSearchRunner, SerialInsertRunner
-from .data_source import DatasetSource
+from vectordb_bench.base import BaseModel
+from vectordb_bench.metric import Metric
+from vectordb_bench.models import PerformanceTimeoutError, TaskConfig, TaskStage
+from . import utils
+from .cases import Case, CaseLabel
+from .clients import MetricType, api
+from .data_source import DatasetSource
+from .runner import MultiProcessingSearchRunner, SerialInsertRunner, SerialSearchRunner
log = logging.getLogger(__name__)
@@ -53,24 +49,39 @@ class CaseRunner(BaseModel):
search_runner: MultiProcessingSearchRunner | None = None
final_search_runner: MultiProcessingSearchRunner | None = None
- def __eq__(self, obj):
+ def __eq__(self, obj: object):
if isinstance(obj, CaseRunner):
- return self.ca.label == CaseLabel.Performance and \
- self.config.db == obj.config.db and \
- self.config.db_case_config == obj.config.db_case_config and \
- self.ca.dataset == obj.ca.dataset
+ return (
+ self.ca.label == CaseLabel.Performance
+ and self.config.db == obj.config.db
+ and self.config.db_case_config == obj.config.db_case_config
+ and self.ca.dataset == obj.ca.dataset
+ )
return False
def display(self) -> dict:
- c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} })
- c_dict['db'] = self.config.db_name
+ c_dict = self.ca.dict(
+ include={
+ "label": True,
+ "filters": True,
+ "dataset": {
+ "data": {
+ "name": True,
+ "size": True,
+ "dim": True,
+ "metric_type": True,
+ "label": True,
+ },
+ },
+ },
+ )
+ c_dict["db"] = self.config.db_name
return c_dict
@property
def normalize(self) -> bool:
assert self.db
- return self.db.need_normalize_cosine() and \
- self.ca.dataset.data.metric_type == MetricType.COSINE
+ return self.db.need_normalize_cosine() and self.ca.dataset.data.metric_type == MetricType.COSINE
def init_db(self, drop_old: bool = True) -> None:
db_cls = self.config.db.init_cls
@@ -80,8 +91,7 @@ def init_db(self, drop_old: bool = True) -> None:
db_config=self.config.db_config.to_dict(),
db_case_config=self.config.db_case_config,
drop_old=drop_old,
- ) # type:ignore
-
+ )
def _pre_run(self, drop_old: bool = True):
try:
@@ -89,12 +99,9 @@ def _pre_run(self, drop_old: bool = True):
self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate)
except ModuleNotFoundError as e:
log.warning(
- f"pre run case error: please install client for db: {self.config.db}, error={e}"
+ f"pre run case error: please install client for db: {self.config.db}, error={e}",
)
raise e from None
- except Exception as e:
- log.warning(f"pre run case error: {e}")
- raise e from None
def run(self, drop_old: bool = True) -> Metric:
log.info("Starting run")
@@ -103,12 +110,11 @@ def run(self, drop_old: bool = True) -> Metric:
if self.ca.label == CaseLabel.Load:
return self._run_capacity_case()
- elif self.ca.label == CaseLabel.Performance:
+ if self.ca.label == CaseLabel.Performance:
return self._run_perf_case(drop_old)
- else:
- msg = f"unknown case type: {self.ca.label}"
- log.warning(msg)
- raise ValueError(msg)
+ msg = f"unknown case type: {self.ca.label}"
+ log.warning(msg)
+ raise ValueError(msg)
def _run_capacity_case(self) -> Metric:
"""run capacity cases
@@ -120,7 +126,10 @@ def _run_capacity_case(self) -> Metric:
log.info("Start capacity case")
try:
runner = SerialInsertRunner(
- self.db, self.ca.dataset, self.normalize, self.ca.load_timeout
+ self.db,
+ self.ca.dataset,
+ self.normalize,
+ self.ca.load_timeout,
)
count = runner.run_endlessness()
except Exception as e:
@@ -128,7 +137,7 @@ def _run_capacity_case(self) -> Metric:
raise e from None
else:
log.info(
- f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}"
+ f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}",
)
return Metric(max_load_count=count)
@@ -138,7 +147,7 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
Returns:
Metric: load_duration, recall, serial_latency_p99, and, qps
"""
- '''
+ """
if drop_old:
_, load_dur = self._load_train_data()
build_dur = self._optimize()
@@ -153,38 +162,40 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = self._conc_search()
m.recall, m.serial_latency_p99 = self._serial_search()
- '''
+ """
log.info("Start performance case")
try:
m = Metric()
if drop_old:
if TaskStage.LOAD in self.config.stages:
- # self._load_train_data()
_, load_dur = self._load_train_data()
build_dur = self._optimize()
m.load_duration = round(load_dur + build_dur, 4)
log.info(
f"Finish loading the entire dataset into VectorDB,"
f" insert_duration={load_dur}, optimize_duration={build_dur}"
- f" load_duration(insert + optimize) = {m.load_duration}"
+ f" load_duration(insert + optimize) = {m.load_duration}",
)
else:
log.info("Data loading skipped")
- if (
- TaskStage.SEARCH_SERIAL in self.config.stages
- or TaskStage.SEARCH_CONCURRENT in self.config.stages
- ):
+ if TaskStage.SEARCH_SERIAL in self.config.stages or TaskStage.SEARCH_CONCURRENT in self.config.stages:
self._init_search_runner()
if TaskStage.SEARCH_CONCURRENT in self.config.stages:
search_results = self._conc_search()
- m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list, m.conc_latency_avg_list = search_results
+ (
+ m.qps,
+ m.conc_num_list,
+ m.conc_qps_list,
+ m.conc_latency_p99_list,
+ m.conc_latency_avg_list,
+ ) = search_results
if TaskStage.SEARCH_SERIAL in self.config.stages:
search_results = self._serial_search()
- '''
+ """
m.recall = search_results.recall
m.serial_latencies = search_results.serial_latencies
- '''
+ """
m.recall, m.ndcg, m.serial_latency_p99 = search_results
except Exception as e:
@@ -199,7 +210,12 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric:
def _load_train_data(self):
"""Insert train data and get the insert_duration"""
try:
- runner = SerialInsertRunner(self.db, self.ca.dataset, self.normalize, self.ca.load_timeout)
+ runner = SerialInsertRunner(
+ self.db,
+ self.ca.dataset,
+ self.normalize,
+ self.ca.load_timeout,
+ )
runner.run()
except Exception as e:
raise e from None
@@ -215,11 +231,12 @@ def _serial_search(self) -> tuple[float, float, float]:
"""
try:
results, _ = self.serial_search_runner.run()
- return results
except Exception as e:
- log.warning(f"search error: {str(e)}, {e}")
+ log.warning(f"search error: {e!s}, {e}")
self.stop()
- raise e from None
+ raise e from e
+ else:
+ return results
def _conc_search(self):
"""Performance concurrency tests, search the test data endlessness
@@ -231,7 +248,7 @@ def _conc_search(self):
try:
return self.search_runner.run()
except Exception as e:
- log.warning(f"search error: {str(e)}, {e}")
+ log.warning(f"search error: {e!s}, {e}")
raise e from None
finally:
self.stop()
@@ -250,7 +267,7 @@ def _optimize(self) -> float:
log.warning(f"VectorDB optimize timeout in {self.ca.optimize_timeout}")
for pid, _ in executor._processes.items():
psutil.Process(pid).kill()
- raise PerformanceTimeoutError("Performance case optimize timeout") from e
+ raise PerformanceTimeoutError from e
except Exception as e:
log.warning(f"VectorDB optimize error: {e}")
raise e from None
@@ -286,6 +303,16 @@ def stop(self):
self.search_runner.stop()
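+# fixed-width column layout shared by the header and data rows of TaskRunner.display()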
+DATA_FORMAT = " %-14s | %-12s %-20s %7s | %-10s"
+TITLE_FORMAT = (" %-14s | %-12s %-20s %7s | %-10s") % (
+ "DB",
+ "CaseType",
+ "Dataset",
+ "Filter",
+ "task_label",
+)
+
+
class TaskRunner(BaseModel):
run_id: str
task_label: str
@@ -304,18 +331,8 @@ def _get_num_by_status(self, status: RunningStatus) -> int:
return sum([1 for c in self.case_runners if c.status == status])
def display(self) -> None:
- DATA_FORMAT = (" %-14s | %-12s %-20s %7s | %-10s")
- TITLE_FORMAT = (" %-14s | %-12s %-20s %7s | %-10s") % (
- "DB", "CaseType", "Dataset", "Filter", "task_label")
-
fmt = [TITLE_FORMAT]
- fmt.append(DATA_FORMAT%(
- "-"*11,
- "-"*12,
- "-"*20,
- "-"*7,
- "-"*7
- ))
+ fmt.append(DATA_FORMAT % ("-" * 11, "-" * 12, "-" * 20, "-" * 7, "-" * 7))
for f in self.case_runners:
if f.ca.filter_rate != 0.0:
@@ -326,13 +343,16 @@ def display(self) -> None:
filters = "None"
ds_str = f"{f.ca.dataset.data.name}-{f.ca.dataset.data.label}-{utils.numerize(f.ca.dataset.data.size)}"
- fmt.append(DATA_FORMAT%(
- f.config.db_name,
- f.ca.label.name,
- ds_str,
- filters,
- self.task_label,
- ))
+ fmt.append(
+ DATA_FORMAT
+ % (
+ f.config.db_name,
+ f.ca.label.name,
+ ds_str,
+ filters,
+ self.task_label,
+ ),
+ )
tmp_logger = logging.getLogger("no_color")
for f in fmt:
diff --git a/vectordb_bench/backend/utils.py b/vectordb_bench/backend/utils.py
index ea11e6461..86c4faf5e 100644
--- a/vectordb_bench/backend/utils.py
+++ b/vectordb_bench/backend/utils.py
@@ -2,7 +2,7 @@
from functools import wraps
-def numerize(n) -> str:
+def numerize(n: int) -> str:
"""display positive number n for readability
Examples:
@@ -16,32 +16,34 @@ def numerize(n) -> str:
"K": 1e6,
"M": 1e9,
"B": 1e12,
- "END": float('inf'),
+ "END": float("inf"),
}
display_n, sufix = n, ""
for s, base in sufix2upbound.items():
# number >= 1000B will alway have sufix 'B'
if s == "END":
- display_n = int(n/1e9)
+ display_n = int(n / 1e9)
sufix = "B"
break
if n < base:
sufix = "" if s == "EMPTY" else s
- display_n = int(n/(base/1e3))
+ display_n = int(n / (base / 1e3))
break
return f"{display_n}{sufix}"
-def time_it(func):
- """ returns result and elapsed time"""
+def time_it(func: any):
+ """returns result and elapsed time"""
+
@wraps(func)
def inner(*args, **kwargs):
pref = time.perf_counter()
result = func(*args, **kwargs)
delta = time.perf_counter() - pref
return result, delta
+
return inner
@@ -62,14 +64,19 @@ def compose_train_files(train_count: int, use_shuffled: bool) -> list[str]:
return train_files
-def compose_gt_file(filters: int | float | str | None = None) -> str:
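+# the only filter rates with dedicated ground-truth files (head / tail 1%)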
+ONE_PERCENT = 0.01
+NINETY_NINE_PERCENT = 0.99
+
+
+def compose_gt_file(filters: float | str | None = None) -> str:
if filters is None:
return "neighbors.parquet"
- if filters == 0.01:
+ if filters == ONE_PERCENT:
return "neighbors_head_1p.parquet"
- if filters == 0.99:
+ if filters == NINETY_NINE_PERCENT:
return "neighbors_tail_1p.parquet"
- raise ValueError(f"Filters not supported: {filters}")
+ msg = f"Filters not supported: {filters}"
+ raise ValueError(msg)
diff --git a/vectordb_bench/base.py b/vectordb_bench/base.py
index 3c71fb5a7..502d5fa49 100644
--- a/vectordb_bench/base.py
+++ b/vectordb_bench/base.py
@@ -3,4 +3,3 @@
class BaseModel(PydanticBaseModel, arbitrary_types_allowed=True):
pass
-
diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py
index e5b9a5fe2..3bb7763d8 100644
--- a/vectordb_bench/cli/cli.py
+++ b/vectordb_bench/cli/cli.py
@@ -1,27 +1,27 @@
import logging
+import os
import time
+from collections.abc import Callable
from concurrent.futures import wait
from datetime import datetime
from pprint import pformat
from typing import (
Annotated,
- Callable,
- List,
- Optional,
- Type,
+ Any,
TypedDict,
Unpack,
get_origin,
get_type_hints,
- Dict,
- Any,
)
+
import click
+from yaml import load
from vectordb_bench.backend.clients.api import MetricType
+
from .. import config
from ..backend.clients import DB
-from ..interface import benchMarkRunner, global_result_future
+from ..interface import benchmark_runner, global_result_future
from ..models import (
CaseConfig,
CaseType,
@@ -31,8 +31,7 @@
TaskConfig,
TaskStage,
)
-import os
-from yaml import load
+
try:
from yaml import CLoader as Loader
except ImportError:
@@ -46,8 +45,8 @@ def click_get_defaults_from_file(ctx, param, value):
else:
input_file = os.path.join(config.CONFIG_LOCAL_DIR, value)
try:
- with open(input_file, 'r') as f:
- _config: Dict[str, Dict[str, Any]] = load(f.read(), Loader=Loader)
+ with open(input_file) as f:
+ _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader)
ctx.default_map = _config.get(ctx.command.name, {})
except Exception as e:
raise click.BadParameter(f"Failed to load config file: {e}")
@@ -55,7 +54,7 @@ def click_get_defaults_from_file(ctx, param, value):
def click_parameter_decorators_from_typed_dict(
- typed_dict: Type,
+ typed_dict: type,
) -> Callable[[click.decorators.FC], click.decorators.FC]:
"""A convenience method decorator that will read in a TypedDict with parameters defined by Annotated types.
from .models import CaseConfig, CaseType, DBCaseConfig, DBConfig, TaskConfig, TaskStage
@@ -91,15 +90,12 @@ def foo(**parameters: Unpack[FooTypedDict]):
decorators = []
for _, t in get_type_hints(typed_dict, include_extras=True).items():
assert get_origin(t) is Annotated
- if (
- len(t.__metadata__) == 1
- and t.__metadata__[0].__module__ == "click.decorators"
- ):
+ if len(t.__metadata__) == 1 and t.__metadata__[0].__module__ == "click.decorators":
# happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1)
decorators.append(t.__metadata__[0])
else:
raise RuntimeError(
- "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring"
+ "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring",
)
def deco(f):
@@ -132,11 +128,11 @@ def parse_task_stages(
load: bool,
search_serial: bool,
search_concurrent: bool,
-) -> List[TaskStage]:
+) -> list[TaskStage]:
stages = []
if load and not drop_old:
raise RuntimeError("Dropping old data cannot be skipped if loading data")
- elif drop_old and not load:
+ if drop_old and not load:
raise RuntimeError("Load cannot be skipped if dropping old data")
if drop_old:
stages.append(TaskStage.DROP_OLD)
@@ -149,12 +145,19 @@ def parse_task_stages(
return stages
-def check_custom_case_parameters(ctx, param, value):
- if ctx.params.get("case_type") == "PerformanceCustomDataset":
- if value is None:
- raise click.BadParameter("Custom case parameters\
- \n--custom-case-name\n--custom-dataset-name\n--custom-dataset-dir\n--custom-dataset-size \
- \n--custom-dataset-dim\n--custom-dataset-file-count\n are required")
+# ruff: noqa
+def check_custom_case_parameters(ctx: Any, param: Any, value: Any):
+ if ctx.params.get("case_type") == "PerformanceCustomDataset" and value is None:
+ raise click.BadParameter(
+ """ Custom case parameters
+--custom-case-name
+--custom-dataset-name
+--custom-dataset-dir
+--custom-dataset-size
+--custom-dataset-dim
+--custom-dataset-file-count
+are required """,
+ )
return value
@@ -175,7 +178,7 @@ def get_custom_case_config(parameters: dict) -> dict:
"file_count": parameters["custom_dataset_file_count"],
"use_shuffled": parameters["custom_dataset_use_shuffled"],
"with_gt": parameters["custom_dataset_with_gt"],
- }
+ },
}
return custom_case_config
@@ -186,12 +189,14 @@ def get_custom_case_config(parameters: dict) -> dict:
class CommonTypedDict(TypedDict):
config_file: Annotated[
bool,
- click.option('--config-file',
- type=click.Path(),
- callback=click_get_defaults_from_file,
- is_eager=True,
- expose_value=False,
- help='Read configuration from yaml file'),
+ click.option(
+ "--config-file",
+ type=click.Path(),
+ callback=click_get_defaults_from_file,
+ is_eager=True,
+ expose_value=False,
+ help="Read configuration from yaml file",
+ ),
]
drop_old: Annotated[
bool,
@@ -246,9 +251,11 @@ class CommonTypedDict(TypedDict):
db_label: Annotated[
str,
click.option(
- "--db-label", type=str, help="Db label, default: date in ISO format",
+ "--db-label",
+ type=str,
+ help="Db label, default: date in ISO format",
show_default=True,
- default=datetime.now().isoformat()
+ default=datetime.now().isoformat(),
),
]
dry_run: Annotated[
@@ -282,7 +289,7 @@ class CommonTypedDict(TypedDict):
),
]
num_concurrency: Annotated[
- List[str],
+ list[str],
click.option(
"--num-concurrency",
type=str,
@@ -298,7 +305,7 @@ class CommonTypedDict(TypedDict):
"--custom-case-name",
help="Custom dataset case name",
callback=check_custom_case_parameters,
- )
+ ),
]
custom_case_description: Annotated[
str,
@@ -307,7 +314,7 @@ class CommonTypedDict(TypedDict):
help="Custom dataset case description",
default="This is a customized dataset.",
show_default=True,
- )
+ ),
]
custom_case_load_timeout: Annotated[
int,
@@ -316,7 +323,7 @@ class CommonTypedDict(TypedDict):
help="Custom dataset case load timeout",
default=36000,
show_default=True,
- )
+ ),
]
custom_case_optimize_timeout: Annotated[
int,
@@ -325,7 +332,7 @@ class CommonTypedDict(TypedDict):
help="Custom dataset case optimize timeout",
default=36000,
show_default=True,
- )
+ ),
]
custom_dataset_name: Annotated[
str,
@@ -397,60 +404,60 @@ class CommonTypedDict(TypedDict):
class HNSWBaseTypedDict(TypedDict):
- m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m")]
+ m: Annotated[int | None, click.option("--m", type=int, help="hnsw m")]
ef_construction: Annotated[
- Optional[int],
+ int | None,
click.option("--ef-construction", type=int, help="hnsw ef-construction"),
]
class HNSWBaseRequiredTypedDict(TypedDict):
- m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m", required=True)]
+ m: Annotated[int | None, click.option("--m", type=int, help="hnsw m", required=True)]
ef_construction: Annotated[
- Optional[int],
+ int | None,
click.option("--ef-construction", type=int, help="hnsw ef-construction", required=True),
]
class HNSWFlavor1(HNSWBaseTypedDict):
ef_search: Annotated[
- Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", is_eager=True)
+ int | None,
+ click.option("--ef-search", type=int, help="hnsw ef-search", is_eager=True),
]
class HNSWFlavor2(HNSWBaseTypedDict):
ef_runtime: Annotated[
- Optional[int], click.option("--ef-runtime", type=int, help="hnsw ef-runtime")
+ int | None,
+ click.option("--ef-runtime", type=int, help="hnsw ef-runtime"),
]
class HNSWFlavor3(HNSWBaseRequiredTypedDict):
ef_search: Annotated[
- Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", required=True)
+ int | None,
+ click.option("--ef-search", type=int, help="hnsw ef-search", required=True),
]
class IVFFlatTypedDict(TypedDict):
- lists: Annotated[
- Optional[int], click.option("--lists", type=int, help="ivfflat lists")
- ]
- probes: Annotated[
- Optional[int], click.option("--probes", type=int, help="ivfflat probes")
- ]
+ lists: Annotated[int | None, click.option("--lists", type=int, help="ivfflat lists")]
+ probes: Annotated[int | None, click.option("--probes", type=int, help="ivfflat probes")]
class IVFFlatTypedDictN(TypedDict):
nlist: Annotated[
- Optional[int], click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True)
+ int | None,
+ click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True),
]
nprobe: Annotated[
- Optional[int], click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True)
+ int | None,
+ click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True),
]
@click.group()
-def cli():
- ...
+def cli(): ...
def run(
@@ -482,9 +489,7 @@ def run(
custom_case=get_custom_case_config(parameters),
),
stages=parse_task_stages(
- (
- False if not parameters["load"] else parameters["drop_old"]
- ), # only drop old data if loading new data
+ (False if not parameters["load"] else parameters["drop_old"]), # only drop old data if loading new data
parameters["load"],
parameters["search_serial"],
parameters["search_concurrent"],
@@ -493,7 +498,7 @@ def run(
log.info(f"Task:\n{pformat(task)}\n")
if not parameters["dry_run"]:
- benchMarkRunner.run([task])
+ benchmark_runner.run([task])
time.sleep(5)
if global_result_future:
wait([global_result_future])
diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py
index f9ad69ceb..5e3798691 100644
--- a/vectordb_bench/cli/vectordbbench.py
+++ b/vectordb_bench/cli/vectordbbench.py
@@ -1,16 +1,15 @@
-from ..backend.clients.pgvector.cli import PgVectorHNSW
+from ..backend.clients.alloydb.cli import AlloyDBScaNN
+from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
+from ..backend.clients.memorydb.cli import MemoryDB
+from ..backend.clients.milvus.cli import MilvusAutoIndex
+from ..backend.clients.pgdiskann.cli import PgDiskAnn
from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat
+from ..backend.clients.pgvector.cli import PgVectorHNSW
from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn
-from ..backend.clients.pgdiskann.cli import PgDiskAnn
from ..backend.clients.redis.cli import Redis
-from ..backend.clients.memorydb.cli import MemoryDB
from ..backend.clients.test.cli import Test
from ..backend.clients.weaviate_cloud.cli import Weaviate
from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex
-from ..backend.clients.milvus.cli import MilvusAutoIndex
-from ..backend.clients.aws_opensearch.cli import AWSOpenSearch
-from ..backend.clients.alloydb.cli import AlloyDBScaNN
-
from .cli import cli
cli.add_command(PgVectorHNSW)
diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py
index 9e869b479..0e74d2752 100644
--- a/vectordb_bench/frontend/components/check_results/charts.py
+++ b/vectordb_bench/frontend/components/check_results/charts.py
@@ -1,8 +1,7 @@
-from vectordb_bench.backend.cases import Case
from vectordb_bench.frontend.components.check_results.expanderStyle import (
initMainExpanderStyle,
)
-from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap
+from vectordb_bench.metric import metric_order, isLowerIsBetterMetric, metric_unit_map
from vectordb_bench.frontend.config.styles import *
from vectordb_bench.models import ResultLabel
import plotly.express as px
@@ -21,9 +20,7 @@ def drawCharts(st, allData, failedTasks, caseNames: list[str]):
def showFailedDBs(st, errorDBs):
failedDBs = [db for db, label in errorDBs.items() if label == ResultLabel.FAILED]
- timeoutDBs = [
- db for db, label in errorDBs.items() if label == ResultLabel.OUTOFRANGE
- ]
+ timeoutDBs = [db for db, label in errorDBs.items() if label == ResultLabel.OUTOFRANGE]
showFailedText(st, "Failed", failedDBs)
showFailedText(st, "Timeout", timeoutDBs)
@@ -41,7 +38,7 @@ def drawChart(data, st, key_prefix: str):
metricsSet = set()
for d in data:
metricsSet = metricsSet.union(d["metricsSet"])
- showMetrics = [metric for metric in metricOrder if metric in metricsSet]
+ showMetrics = [metric for metric in metric_order if metric in metricsSet]
for i, metric in enumerate(showMetrics):
container = st.container()
@@ -72,9 +69,7 @@ def getLabelToShapeMap(data):
else:
usedShapes.add(labelIndexMap[label] % len(PATTERN_SHAPES))
- labelToShapeMap = {
- label: getPatternShape(index) for label, index in labelIndexMap.items()
- }
+ labelToShapeMap = {label: getPatternShape(index) for label, index in labelIndexMap.items()}
return labelToShapeMap
@@ -96,11 +91,9 @@ def drawMetricChart(data, metric, st, key: str):
xpadding = (xmax - xmin) / 16
xpadding_multiplier = 1.8
xrange = [xmin, xmax + xpadding * xpadding_multiplier]
- unit = metricUnitMap.get(metric, "")
+ unit = metric_unit_map.get(metric, "")
labelToShapeMap = getLabelToShapeMap(dataWithMetric)
- categoryorder = (
- "total descending" if isLowerIsBetterMetric(metric) else "total ascending"
- )
+ categoryorder = "total descending" if isLowerIsBetterMetric(metric) else "total ascending"
fig = px.bar(
dataWithMetric,
x=metric,
@@ -137,18 +130,14 @@ def drawMetricChart(data, metric, st, key: str):
color="#333",
size=12,
),
- marker=dict(
- pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)
- ),
+ marker=dict(pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)),
texttemplate="%{x:,.4~r}" + unit,
)
fig.update_layout(
margin=dict(l=0, r=0, t=48, b=12, pad=8),
bargap=0.25,
showlegend=False,
- legend=dict(
- orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""
- ),
+ legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""),
# legend=dict(orientation="v", title=""),
yaxis={"categoryorder": categoryorder},
title=dict(
diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py
index b3cac21e1..94d1b4eab 100644
--- a/vectordb_bench/frontend/components/check_results/data.py
+++ b/vectordb_bench/frontend/components/check_results/data.py
@@ -1,6 +1,5 @@
from collections import defaultdict
from dataclasses import asdict
-from vectordb_bench.backend.cases import Case
from vectordb_bench.metric import isLowerIsBetterMetric
from vectordb_bench.models import CaseResult, ResultLabel
@@ -24,10 +23,7 @@ def getFilterTasks(
task
for task in tasks
if task.task_config.db_name in dbNames
- and task.task_config.case_config.case_id.case_cls(
- task.task_config.case_config.custom_case
- ).name
- in caseNames
+ and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames
]
return filterTasks
@@ -39,9 +35,7 @@ def mergeTasks(tasks: list[CaseResult]):
db = task.task_config.db.value
db_label = task.task_config.db_config.db_label or ""
version = task.task_config.db_config.version or ""
- case = task.task_config.case_config.case_id.case_cls(
- task.task_config.case_config.custom_case
- )
+ case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case)
dbCaseMetricsMap[db_name][case.name] = {
"db": db,
"db_label": db_label,
@@ -86,9 +80,7 @@ def mergeTasks(tasks: list[CaseResult]):
def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict:
metrics = {**metrics_1}
for key, value in metrics_2.items():
- metrics[key] = (
- getBetterMetric(key, value, metrics[key]) if key in metrics else value
- )
+ metrics[key] = getBetterMetric(key, value, metrics[key]) if key in metrics else value
return metrics
@@ -99,11 +91,7 @@ def getBetterMetric(metric, value_1, value_2):
return value_2
if value_2 < 1e-7:
return value_1
- return (
- min(value_1, value_2)
- if isLowerIsBetterMetric(metric)
- else max(value_1, value_2)
- )
+ return min(value_1, value_2) if isLowerIsBetterMetric(metric) else max(value_1, value_2)
except Exception:
return value_1
diff --git a/vectordb_bench/frontend/components/check_results/filters.py b/vectordb_bench/frontend/components/check_results/filters.py
index e60efb2e1..129c1d5ae 100644
--- a/vectordb_bench/frontend/components/check_results/filters.py
+++ b/vectordb_bench/frontend/components/check_results/filters.py
@@ -20,23 +20,17 @@ def getshownData(results: list[TestResult], st):
shownResults = getshownResults(results, st)
showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st)
- shownData, failedTasks = getChartData(
- shownResults, showDBNames, showCaseNames)
+ shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames)
return shownData, failedTasks, showCaseNames
def getshownResults(results: list[TestResult], st) -> list[CaseResult]:
resultSelectOptions = [
- result.task_label
- if result.task_label != result.run_id
- else f"res-{result.run_id[:4]}"
- for result in results
+ result.task_label if result.task_label != result.run_id else f"res-{result.run_id[:4]}" for result in results
]
if len(resultSelectOptions) == 0:
- st.write(
- "There are no results to display. Please wait for the task to complete or run a new task."
- )
+ st.write("There are no results to display. Please wait for the task to complete or run a new task.")
return []
selectedResultSelectedOptions = st.multiselect(
@@ -58,13 +52,12 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[st
allDbNames = list(set({res.task_config.db_name for res in result}))
allDbNames.sort()
allCases: list[Case] = [
- res.task_config.case_config.case_id.case_cls(
- res.task_config.case_config.custom_case)
- for res in result
+ res.task_config.case_config.case_id.case_cls(res.task_config.case_config.custom_case) for res in result
]
allCaseNameSet = set({case.name for case in allCases})
- allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \
- [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER]
+ allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [
+ case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER
+ ]
# DB Filter
dbFilterContainer = st.container()
@@ -120,8 +113,7 @@ def filterView(container, header, options, col, optionLables=None):
)
if optionLables is None:
optionLables = options
- isActive = {option: st.session_state[selectAllState]
- for option in optionLables}
+ isActive = {option: st.session_state[selectAllState] for option in optionLables}
for i, option in enumerate(optionLables):
isActive[option] = columns[i % col].checkbox(
optionLables[i],
diff --git a/vectordb_bench/frontend/components/check_results/nav.py b/vectordb_bench/frontend/components/check_results/nav.py
index ad070ab35..f95e43d7a 100644
--- a/vectordb_bench/frontend/components/check_results/nav.py
+++ b/vectordb_bench/frontend/components/check_results/nav.py
@@ -7,15 +7,15 @@ def NavToRunTest(st):
navClick = st.button("Run Your Test >")
if navClick:
switch_page("run test")
-
-
+
+
def NavToQuriesPerDollar(st):
st.subheader("Compare qps with price.")
navClick = st.button("QP$ (Quries per Dollar) >")
if navClick:
switch_page("quries_per_dollar")
-
-
+
+
def NavToResults(st, key="nav-to-results"):
navClick = st.button("< Back to Results", key=key)
if navClick:
diff --git a/vectordb_bench/frontend/components/check_results/priceTable.py b/vectordb_bench/frontend/components/check_results/priceTable.py
index 06d7a6a6b..f2c0ae001 100644
--- a/vectordb_bench/frontend/components/check_results/priceTable.py
+++ b/vectordb_bench/frontend/components/check_results/priceTable.py
@@ -7,9 +7,7 @@
def priceTable(container, data):
- dbAndLabelSet = {
- (d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value
- }
+ dbAndLabelSet = {(d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value}
dbAndLabelList = list(dbAndLabelSet)
dbAndLabelList.sort()
diff --git a/vectordb_bench/frontend/components/check_results/stPageConfig.py b/vectordb_bench/frontend/components/check_results/stPageConfig.py
index c03cceab4..9521be285 100644
--- a/vectordb_bench/frontend/components/check_results/stPageConfig.py
+++ b/vectordb_bench/frontend/components/check_results/stPageConfig.py
@@ -9,10 +9,11 @@ def initResultsPageConfig(st):
# initial_sidebar_state="collapsed",
)
+
def initRunTestPageConfig(st):
st.set_page_config(
page_title=PAGE_TITLE,
page_icon=FAVICON,
# layout="wide",
initial_sidebar_state="collapsed",
- )
\ No newline at end of file
+ )
diff --git a/vectordb_bench/frontend/components/concurrent/charts.py b/vectordb_bench/frontend/components/concurrent/charts.py
index 11379c4b3..83e2961d6 100644
--- a/vectordb_bench/frontend/components/concurrent/charts.py
+++ b/vectordb_bench/frontend/components/concurrent/charts.py
@@ -14,24 +14,24 @@ def drawChartsByCase(allData, showCaseNames: list[str], st, latency_type: str):
data = [
{
"conc_num": caseData["conc_num_list"][i],
- "qps": caseData["conc_qps_list"][i]
- if 0 <= i < len(caseData["conc_qps_list"])
- else 0,
- "latency_p99": caseData["conc_latency_p99_list"][i] * 1000
- if 0 <= i < len(caseData["conc_latency_p99_list"])
- else 0,
- "latency_avg": caseData["conc_latency_avg_list"][i] * 1000
- if 0 <= i < len(caseData["conc_latency_avg_list"])
- else 0,
+ "qps": (caseData["conc_qps_list"][i] if 0 <= i < len(caseData["conc_qps_list"]) else 0),
+ "latency_p99": (
+ caseData["conc_latency_p99_list"][i] * 1000
+ if 0 <= i < len(caseData["conc_latency_p99_list"])
+ else 0
+ ),
+ "latency_avg": (
+ caseData["conc_latency_avg_list"][i] * 1000
+ if 0 <= i < len(caseData["conc_latency_avg_list"])
+ else 0
+ ),
"db_name": caseData["db_name"],
"db": caseData["db"],
}
for caseData in caseDataList
for i in range(len(caseData["conc_num_list"]))
]
- drawChart(
- data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type
- )
+ drawChart(data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type)
def getRange(metric, data, padding_multipliers):
diff --git a/vectordb_bench/frontend/components/custom/displayCustomCase.py b/vectordb_bench/frontend/components/custom/displayCustomCase.py
index 3f5266051..ac111c883 100644
--- a/vectordb_bench/frontend/components/custom/displayCustomCase.py
+++ b/vectordb_bench/frontend/components/custom/displayCustomCase.py
@@ -1,4 +1,3 @@
-
from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig
@@ -6,26 +5,33 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key):
columns = st.columns([1, 2])
customCase.dataset_config.name = columns[0].text_input(
- "Name", key=f"{key}_name", value=customCase.dataset_config.name)
+ "Name", key=f"{key}_name", value=customCase.dataset_config.name
+ )
customCase.name = f"{customCase.dataset_config.name} (Performace Case)"
customCase.dataset_config.dir = columns[1].text_input(
- "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir)
+ "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir
+ )
columns = st.columns(4)
customCase.dataset_config.dim = columns[0].number_input(
- "dim", key=f"{key}_dim", value=customCase.dataset_config.dim)
+ "dim", key=f"{key}_dim", value=customCase.dataset_config.dim
+ )
customCase.dataset_config.size = columns[1].number_input(
- "size", key=f"{key}_size", value=customCase.dataset_config.size)
+ "size", key=f"{key}_size", value=customCase.dataset_config.size
+ )
customCase.dataset_config.metric_type = columns[2].selectbox(
- "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"])
+ "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]
+ )
customCase.dataset_config.file_count = columns[3].number_input(
- "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count)
+ "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count
+ )
columns = st.columns(4)
customCase.dataset_config.use_shuffled = columns[0].checkbox(
- "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled)
+ "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled
+ )
customCase.dataset_config.with_gt = columns[1].checkbox(
- "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt)
+ "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt
+ )
- customCase.description = st.text_area(
- "description", key=f"{key}_description", value=customCase.description)
+ customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description)
diff --git a/vectordb_bench/frontend/components/custom/displaypPrams.py b/vectordb_bench/frontend/components/custom/displaypPrams.py
index a300b45e1..b677e5909 100644
--- a/vectordb_bench/frontend/components/custom/displaypPrams.py
+++ b/vectordb_bench/frontend/components/custom/displaypPrams.py
@@ -1,5 +1,6 @@
def displayParams(st):
- st.markdown("""
+ st.markdown(
+ """
- `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format.
- Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
- Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`.
@@ -8,4 +9,5 @@ def displayParams(st):
- `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files.
- `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order.
-""")
+"""
+ )
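
The markdown above documents the on-disk layout a custom dataset must follow (`train.parquet` / `test.parquet`, each with an incrementing `int` `id` column and an `emb` column of `float32` arrays). As a minimal sketch only — assuming pandas and pyarrow are installed, which this patch does not itself require — generating a tiny conforming dataset could look like this; the ground-truth and shuffled variants mentioned in the docs are omitted:

```python
import numpy as np
import pandas as pd  # assumes pandas + pyarrow are available


def write_toy_dataset(folder: str, dim: int = 8, n_train: int = 1000, n_test: int = 10) -> None:
    """Write train.parquet / test.parquet in the layout described by displayParams."""
    rng = np.random.default_rng(0)
    train = pd.DataFrame({
        "id": np.arange(n_train, dtype=np.int64),
        # keep rows as float32 numpy arrays so the parquet child type stays float32
        "emb": list(rng.random((n_train, dim), dtype=np.float32)),
    })
    test = pd.DataFrame({
        "id": np.arange(n_test, dtype=np.int64),
        "emb": list(rng.random((n_test, dim), dtype=np.float32)),
    })
    train.to_parquet(f"{folder}/train.parquet", index=False)
    test.to_parquet(f"{folder}/test.parquet", index=False)
```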
diff --git a/vectordb_bench/frontend/components/custom/getCustomConfig.py b/vectordb_bench/frontend/components/custom/getCustomConfig.py
index ede664b6a..668bfb6d5 100644
--- a/vectordb_bench/frontend/components/custom/getCustomConfig.py
+++ b/vectordb_bench/frontend/components/custom/getCustomConfig.py
@@ -32,8 +32,7 @@ def get_custom_configs():
def save_custom_configs(custom_configs: list[CustomDatasetConfig]):
with open(config.CUSTOM_CONFIG_DIR, "w") as f:
- json.dump([custom_config.dict()
- for custom_config in custom_configs], f, indent=4)
+ json.dump([custom_config.dict() for custom_config in custom_configs], f, indent=4)
def generate_custom_case():
diff --git a/vectordb_bench/frontend/components/custom/initStyle.py b/vectordb_bench/frontend/components/custom/initStyle.py
index 0e5129248..7ecbbf24e 100644
--- a/vectordb_bench/frontend/components/custom/initStyle.py
+++ b/vectordb_bench/frontend/components/custom/initStyle.py
@@ -12,4 +12,4 @@ def initStyle(st):
*/
""",
unsafe_allow_html=True,
- )
\ No newline at end of file
+ )
diff --git a/vectordb_bench/frontend/components/get_results/saveAsImage.py b/vectordb_bench/frontend/components/get_results/saveAsImage.py
index 9820e9575..4be2baec1 100644
--- a/vectordb_bench/frontend/components/get_results/saveAsImage.py
+++ b/vectordb_bench/frontend/components/get_results/saveAsImage.py
@@ -9,10 +9,12 @@
def load_unpkg(src: str) -> str:
return requests.get(src).text
+
def getResults(container, pageName="vectordb_bench"):
container.subheader("Get results")
saveAsImage(container, pageName)
+
def saveAsImage(container, pageName):
html2canvasJS = load_unpkg(HTML_2_CANVAS_URL)
container.write()
diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py
index b25618271..e3d28238d 100644
--- a/vectordb_bench/frontend/components/run_test/caseSelector.py
+++ b/vectordb_bench/frontend/components/run_test/caseSelector.py
@@ -1,6 +1,4 @@
-
from vectordb_bench.frontend.config.styles import *
-from vectordb_bench.backend.cases import CaseType
from vectordb_bench.frontend.config.dbCaseConfigs import *
from collections import defaultdict
@@ -23,8 +21,7 @@ def caseSelector(st, activedDbList: list[DB]):
dbToCaseConfigs = defaultdict(lambda: defaultdict(dict))
caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()]
for caseCluster in caseClusters:
- activedCaseList += caseClusterExpander(
- st, caseCluster, dbToCaseClusterConfigs, activedDbList)
+ activedCaseList += caseClusterExpander(st, caseCluster, dbToCaseClusterConfigs, activedDbList)
for db in dbToCaseClusterConfigs:
for uiCaseItem in dbToCaseClusterConfigs[db]:
for case in uiCaseItem.cases:
@@ -40,8 +37,7 @@ def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfi
if uiCaseItem.isLine:
addHorizontalLine(expander)
else:
- activedCases += caseItemCheckbox(expander,
- dbToCaseClusterConfigs, uiCaseItem, activedDbList)
+ activedCases += caseItemCheckbox(expander, dbToCaseClusterConfigs, uiCaseItem, activedDbList)
return activedCases
@@ -53,9 +49,7 @@ def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, actived
)
if selected:
- caseConfigSetting(
- st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList
- )
+ caseConfigSetting(st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList)
return uiCaseItem.cases if selected else []
diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
index 257608413..800e6dede 100644
--- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
+++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py
@@ -42,10 +42,7 @@ def dbConfigSettingItem(st, activeDb: DB):
# db config (unique)
for key, property in properties.items():
- if (
- key not in dbConfigClass.common_short_configs()
- and key not in dbConfigClass.common_long_configs()
- ):
+ if key not in dbConfigClass.common_short_configs() and key not in dbConfigClass.common_long_configs():
column = columns[idx % DB_CONFIG_SETTING_COLUMNS]
idx += 1
dbConfig[key] = column.text_input(
diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py
index ccf0168c6..e20ee5059 100644
--- a/vectordb_bench/frontend/components/run_test/dbSelector.py
+++ b/vectordb_bench/frontend/components/run_test/dbSelector.py
@@ -22,7 +22,7 @@ def dbSelector(st):
dbIsActived[db] = column.checkbox(db.name)
try:
column.image(DB_TO_ICON.get(db, ""))
- except MediaFileStorageError as e:
+ except MediaFileStorageError:
column.warning(f"{db.name} image not available")
pass
activedDbList = [db for db in DB_LIST if dbIsActived[db]]
diff --git a/vectordb_bench/frontend/components/run_test/generateTasks.py b/vectordb_bench/frontend/components/run_test/generateTasks.py
index 828913f30..d8a678ffc 100644
--- a/vectordb_bench/frontend/components/run_test/generateTasks.py
+++ b/vectordb_bench/frontend/components/run_test/generateTasks.py
@@ -7,13 +7,13 @@ def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[Cas
for db in activedDbList:
for case in activedCaseList:
task = TaskConfig(
- db=db.value,
- db_config=dbConfigs[db],
- case_config=case,
- db_case_config=db.case_config_cls(
- allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None)
- )(**{key.value: value for key, value in allCaseConfigs[db][case].items()}),
- )
+ db=db.value,
+ db_config=dbConfigs[db],
+ case_config=case,
+ db_case_config=db.case_config_cls(allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None))(
+ **{key.value: value for key, value in allCaseConfigs[db][case].items()}
+ ),
+ )
tasks.append(task)
-
+
return tasks
diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py
index f824dd9d9..426095397 100644
--- a/vectordb_bench/frontend/components/run_test/submitTask.py
+++ b/vectordb_bench/frontend/components/run_test/submitTask.py
@@ -1,6 +1,6 @@
from datetime import datetime
-from vectordb_bench.frontend.config.styles import *
-from vectordb_bench.interface import benchMarkRunner
+from vectordb_bench.frontend.config import styles
+from vectordb_bench.interface import benchmark_runner
def submitTask(st, tasks, isAllValid):
@@ -27,10 +27,8 @@ def submitTask(st, tasks, isAllValid):
def taskLabelInput(st):
defaultTaskLabel = datetime.now().strftime("%Y%m%d%H")
- columns = st.columns(TASK_LABEL_INPUT_COLUMNS)
- taskLabel = columns[0].text_input(
- "task_label", defaultTaskLabel, label_visibility="collapsed"
- )
+ columns = st.columns(styles.TASK_LABEL_INPUT_COLUMNS)
+ taskLabel = columns[0].text_input("task_label", defaultTaskLabel, label_visibility="collapsed")
return taskLabel
@@ -46,10 +44,8 @@ def advancedSettings(st):
)
container = st.columns([1, 2])
- k = container[0].number_input("k",min_value=1, value=100, label_visibility="collapsed")
- container[1].caption(
- "K value for number of nearest neighbors to search"
- )
+ k = container[0].number_input("k", min_value=1, value=100, label_visibility="collapsed")
+ container[1].caption("K value for number of nearest neighbors to search")
return index_already_exists, use_aliyun, k
@@ -58,20 +54,20 @@ def controlPanel(st, tasks, taskLabel, isAllValid):
index_already_exists, use_aliyun, k = advancedSettings(st)
def runHandler():
- benchMarkRunner.set_drop_old(not index_already_exists)
+ benchmark_runner.set_drop_old(not index_already_exists)
for task in tasks:
task.case_config.k = k
- benchMarkRunner.set_download_address(use_aliyun)
- benchMarkRunner.run(tasks, taskLabel)
+ benchmark_runner.set_download_address(use_aliyun)
+ benchmark_runner.run(tasks, taskLabel)
def stopHandler():
- benchMarkRunner.stop_running()
+ benchmark_runner.stop_running()
- isRunning = benchMarkRunner.has_running()
+ isRunning = benchmark_runner.has_running()
if isRunning:
- currentTaskId = benchMarkRunner.get_current_task_id()
- tasksCount = benchMarkRunner.get_tasks_count()
+ currentTaskId = benchmark_runner.get_current_task_id()
+ tasksCount = benchmark_runner.get_tasks_count()
text = f":running: Running Task {currentTaskId} / {tasksCount}"
st.progress(currentTaskId / tasksCount, text=text)
@@ -89,7 +85,7 @@ def stopHandler():
)
else:
- errorText = benchMarkRunner.latest_error or ""
+ errorText = benchmark_runner.latest_error or ""
if len(errorText) > 0:
st.error(errorText)
disabled = True if len(tasks) == 0 or not isAllValid else False
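
For reference, the renamed `benchmark_runner` singleton is driven here through `set_drop_old`, `set_download_address`, `run`, `has_running`, and the progress getters. A headless sketch of the same flow, using only the methods visible in this diff (this is illustrative, not a supported entry point of the package):

```python
import time

from vectordb_bench.interface import benchmark_runner
from vectordb_bench.models import TaskConfig


def run_headless(tasks: list[TaskConfig], label: str, *, drop_old: bool = True, use_aliyun: bool = False) -> None:
    benchmark_runner.set_drop_old(drop_old)           # mirrors "index already exists" toggle
    benchmark_runner.set_download_address(use_aliyun)  # S3 vs Aliyun OSS dataset source
    if not benchmark_runner.run(tasks, label):
        print(benchmark_runner.latest_error or "task submission failed")
        return
    while benchmark_runner.has_running():
        print(f"running task {benchmark_runner.get_current_task_id()} / {benchmark_runner.get_tasks_count()}")
        time.sleep(5)
```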
diff --git a/vectordb_bench/frontend/components/tables/data.py b/vectordb_bench/frontend/components/tables/data.py
index 96134c7ff..fbe83f197 100644
--- a/vectordb_bench/frontend/components/tables/data.py
+++ b/vectordb_bench/frontend/components/tables/data.py
@@ -1,12 +1,11 @@
from dataclasses import asdict
-from vectordb_bench.backend.cases import CaseType
-from vectordb_bench.interface import benchMarkRunner
+from vectordb_bench.interface import benchmark_runner
from vectordb_bench.models import CaseResult, ResultLabel
import pandas as pd
def getNewResults():
- allResults = benchMarkRunner.get_results()
+ allResults = benchmark_runner.get_results()
newResults: list[CaseResult] = []
for res in allResults:
@@ -14,7 +13,6 @@ def getNewResults():
for result in results:
if result.label == ResultLabel.NORMAL:
newResults.append(result)
-
df = pd.DataFrame(formatData(newResults))
return df
@@ -26,7 +24,6 @@ def formatData(caseResults: list[CaseResult]):
db = caseResult.task_config.db.value
db_label = caseResult.task_config.db_config.db_label
case_config = caseResult.task_config.case_config
- db_case_config = caseResult.task_config.db_case_config
case = case_config.case_id.case_cls()
filter_rate = case.filter_rate
dataset = case.dataset.data.name
@@ -41,4 +38,4 @@ def formatData(caseResults: list[CaseResult]):
**metrics,
}
)
- return data
\ No newline at end of file
+ return data
diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py
index 7f076a2dd..e004f2ba7 100644
--- a/vectordb_bench/frontend/config/dbCaseConfigs.py
+++ b/vectordb_bench/frontend/config/dbCaseConfigs.py
@@ -33,9 +33,9 @@ class UICaseItem(BaseModel):
def __init__(
self,
isLine: bool = False,
- case_id: CaseType = None,
- custom_case: dict = {},
- cases: list[CaseConfig] = [],
+ case_id: CaseType | None = None,
+ custom_case: dict | None = None,
+ cases: list[CaseConfig] | None = None,
label: str = "",
description: str = "",
caseLabel: CaseLabel = CaseLabel.Performance,
@@ -70,17 +70,13 @@ class UICaseItemCluster(BaseModel):
def get_custom_case_items() -> list[UICaseItem]:
custom_configs = get_custom_configs()
return [
- UICaseItem(
- case_id=CaseType.PerformanceCustomDataset, custom_case=custom_config.dict()
- )
+ UICaseItem(case_id=CaseType.PerformanceCustomDataset, custom_case=custom_config.dict())
for custom_config in custom_configs
]
def get_custom_case_cluter() -> UICaseItemCluster:
- return UICaseItemCluster(
- label="Custom Search Performance Test", uiCaseItems=get_custom_case_items()
- )
+ return UICaseItemCluster(label="Custom Search Performance Test", uiCaseItems=get_custom_case_items())
UI_CASE_CLUSTERS: list[UICaseItemCluster] = [
@@ -224,8 +220,7 @@ class CaseConfigInput(BaseModel):
"max": 300,
"value": 32,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value,
)
CaseConfigParamInput_l_value_ib = CaseConfigInput(
@@ -236,8 +231,7 @@ class CaseConfigInput(BaseModel):
"max": 300,
"value": 50,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value,
)
CaseConfigParamInput_l_value_is = CaseConfigInput(
@@ -248,8 +242,7 @@ class CaseConfigInput(BaseModel):
"max": 300,
"value": 40,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value,
)
CaseConfigParamInput_num_neighbors = CaseConfigInput(
@@ -260,8 +253,7 @@ class CaseConfigInput(BaseModel):
"max": 300,
"value": 50,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.STREAMING_DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value,
)
CaseConfigParamInput_search_list_size = CaseConfigInput(
@@ -272,8 +264,7 @@ class CaseConfigInput(BaseModel):
"max": 300,
"value": 100,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.STREAMING_DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value,
)
CaseConfigParamInput_max_alpha = CaseConfigInput(
@@ -284,8 +275,7 @@ class CaseConfigInput(BaseModel):
"max": 2.0,
"value": 1.2,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.STREAMING_DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value,
)
CaseConfigParamInput_num_dimensions = CaseConfigInput(
@@ -296,8 +286,7 @@ class CaseConfigInput(BaseModel):
"max": 2000,
"value": 0,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.STREAMING_DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value,
)
CaseConfigParamInput_query_search_list_size = CaseConfigInput(
@@ -308,8 +297,7 @@ class CaseConfigInput(BaseModel):
"max": 150,
"value": 100,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.STREAMING_DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value,
)
@@ -321,8 +309,7 @@ class CaseConfigInput(BaseModel):
"max": 150,
"value": 50,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.STREAMING_DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value,
)
CaseConfigParamInput_IndexType_PgVector = CaseConfigInput(
@@ -358,8 +345,7 @@ class CaseConfigInput(BaseModel):
"max": 64,
"value": 30,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
)
CaseConfigParamInput_m = CaseConfigInput(
@@ -370,8 +356,7 @@ class CaseConfigInput(BaseModel):
"max": 64,
"value": 16,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
)
@@ -383,8 +368,7 @@ class CaseConfigInput(BaseModel):
"max": 512,
"value": 360,
},
- isDisplayed=lambda config: config[CaseConfigParamType.IndexType]
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value,
)
CaseConfigParamInput_EFConstruction_Weaviate = CaseConfigInput(
@@ -480,8 +464,7 @@ class CaseConfigInput(BaseModel):
"max": 2000,
"value": 300,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
)
CaseConfigParamInput_EFSearch_PgVectoRS = CaseConfigInput(
@@ -492,8 +475,7 @@ class CaseConfigInput(BaseModel):
"max": 65535,
"value": 100,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
)
CaseConfigParamInput_EFConstruction_PgVector = CaseConfigInput(
@@ -504,8 +486,7 @@ class CaseConfigInput(BaseModel):
"max": 1024,
"value": 256,
},
- isDisplayed=lambda config: config[CaseConfigParamType.IndexType]
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value,
)
@@ -537,8 +518,7 @@ class CaseConfigInput(BaseModel):
"max": MAX_STREAMLIT_INT,
"value": 100,
},
- isDisplayed=lambda config: config[CaseConfigParamType.IndexType]
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value,
)
CaseConfigParamInput_EF_Weaviate = CaseConfigInput(
@@ -565,8 +545,7 @@ class CaseConfigInput(BaseModel):
"max": MAX_STREAMLIT_INT,
"value": 100,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.DISKANN.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value,
)
CaseConfigParamInput_Nlist = CaseConfigInput(
@@ -611,8 +590,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 0,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_IVF_PQ.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value],
)
@@ -624,8 +602,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 8,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_IVF_PQ.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value],
)
CaseConfigParamInput_intermediate_graph_degree = CaseConfigInput(
@@ -636,8 +613,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 64,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_graph_degree = CaseConfigInput(
@@ -648,8 +624,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 32,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_itopk_size = CaseConfigInput(
@@ -660,8 +635,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 128,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_team_size = CaseConfigInput(
@@ -672,8 +646,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 0,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_search_width = CaseConfigInput(
@@ -684,8 +657,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 4,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_min_iterations = CaseConfigInput(
@@ -696,8 +668,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 0,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_max_iterations = CaseConfigInput(
@@ -708,8 +679,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 0,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
CaseConfigParamInput_build_algo = CaseConfigInput(
@@ -718,8 +688,7 @@ class CaseConfigInput(BaseModel):
inputConfig={
"options": ["IVF_PQ", "NN_DESCENT"],
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.GPU_CAGRA.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value],
)
@@ -762,8 +731,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 10,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- in [IndexType.IVFFlat.value],
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.IVFFlat.value],
)
CaseConfigParamInput_Probes = CaseConfigInput(
@@ -784,8 +752,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 10,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.IVFFlat.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFFlat.value,
)
CaseConfigParamInput_Probes_PgVector = CaseConfigInput(
@@ -796,8 +763,7 @@ class CaseConfigInput(BaseModel):
"max": 65536,
"value": 1,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.IVFFlat.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFFlat.value,
)
CaseConfigParamInput_EFSearch_PgVector = CaseConfigInput(
@@ -808,8 +774,7 @@ class CaseConfigInput(BaseModel):
"max": 2048,
"value": 256,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None)
- == IndexType.HNSW.value,
+ isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value,
)
@@ -845,8 +810,7 @@ class CaseConfigInput(BaseModel):
inputConfig={
"options": ["x4", "x8", "x16", "x32", "x64"],
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
- == "product"
+ isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "product"
and config.get(CaseConfigParamType.IndexType, None)
in [
IndexType.HNSW.value,
@@ -885,8 +849,7 @@ class CaseConfigInput(BaseModel):
inputConfig={
"value": False,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
- == "bit"
+ isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "bit",
)
CaseConfigParamInput_quantized_fetch_limit_PgVector = CaseConfigInput(
@@ -899,8 +862,8 @@ class CaseConfigInput(BaseModel):
"max": 1000,
"value": 200,
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
- == "bit" and config.get(CaseConfigParamType.reranking, False)
+ isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "bit"
+ and config.get(CaseConfigParamType.reranking, False),
)
@@ -908,12 +871,10 @@ class CaseConfigInput(BaseModel):
label=CaseConfigParamType.rerankingMetric,
inputType=InputType.Option,
inputConfig={
- "options": [
- metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]
- ],
+ "options": [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]],
},
- isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None)
- == "bit" and config.get(CaseConfigParamType.reranking, False)
+ isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "bit"
+ and config.get(CaseConfigParamType.reranking, False),
)
@@ -1131,7 +1092,10 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_NumCandidates_ES,
]
-AWSOpensearchLoadingConfig = [CaseConfigParamInput_EFConstruction_AWSOpensearch, CaseConfigParamInput_M_AWSOpensearch]
+AWSOpensearchLoadingConfig = [
+ CaseConfigParamInput_EFConstruction_AWSOpensearch,
+ CaseConfigParamInput_M_AWSOpensearch,
+]
AWSOpenSearchPerformanceConfig = [
CaseConfigParamInput_EFConstruction_AWSOpensearch,
CaseConfigParamInput_M_AWSOpensearch,
@@ -1250,7 +1214,10 @@ class CaseConfigInput(BaseModel):
CaseConfigParamInput_max_parallel_workers_AlloyDB,
]
-AliyunElasticsearchLoadingConfig = [CaseConfigParamInput_EFConstruction_AliES, CaseConfigParamInput_M_AliES]
+AliyunElasticsearchLoadingConfig = [
+ CaseConfigParamInput_EFConstruction_AliES,
+ CaseConfigParamInput_M_AliES,
+]
AliyunElasticsearchPerformanceConfig = [
CaseConfigParamInput_EFConstruction_AliES,
CaseConfigParamInput_M_AliES,
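
One behavioral change in this file worth noting: `UICaseItem.__init__` now defaults `case_id`, `custom_case`, and `cases` to `None` instead of `{}` / `[]`, avoiding mutable default arguments shared across calls. A generic illustration of the pitfall (not code from this repo):

```python
def add_item_buggy(item, bucket=[]):  # one list object shared by every call (ruff B006)
    bucket.append(item)
    return bucket


def add_item_fixed(item, bucket=None):  # allocate a fresh list per call
    if bucket is None:
        bucket = []
    bucket.append(item)
    return bucket


assert add_item_buggy("a") == ["a"]
assert add_item_buggy("b") == ["a", "b"]  # state leaks across calls
assert add_item_fixed("a") == ["a"]
assert add_item_fixed("b") == ["b"]
```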
diff --git a/vectordb_bench/frontend/pages/concurrent.py b/vectordb_bench/frontend/pages/concurrent.py
index 941675436..045239494 100644
--- a/vectordb_bench/frontend/pages/concurrent.py
+++ b/vectordb_bench/frontend/pages/concurrent.py
@@ -9,7 +9,7 @@
from vectordb_bench.frontend.components.concurrent.charts import drawChartsByCase
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
from vectordb_bench.frontend.config.styles import FAVICON
-from vectordb_bench.interface import benchMarkRunner
+from vectordb_bench.interface import benchmark_runner
from vectordb_bench.models import TestResult
@@ -25,7 +25,7 @@ def main():
# header
drawHeaderIcon(st)
- allResults = benchMarkRunner.get_results()
+ allResults = benchmark_runner.get_results()
def check_conc_data(res: TestResult):
case_results = res.results
@@ -57,9 +57,7 @@ def check_conc_data(res: TestResult):
# main
latency_type = st.radio("Latency Type", options=["latency_p99", "latency_avg"])
- drawChartsByCase(
- shownData, showCaseNames, st.container(), latency_type=latency_type
- )
+ drawChartsByCase(shownData, showCaseNames, st.container(), latency_type=latency_type)
# footer
footer(st.container())
diff --git a/vectordb_bench/frontend/pages/custom.py b/vectordb_bench/frontend/pages/custom.py
index 28c249f78..4f6beed91 100644
--- a/vectordb_bench/frontend/pages/custom.py
+++ b/vectordb_bench/frontend/pages/custom.py
@@ -1,13 +1,21 @@
+from functools import partial
import streamlit as st
from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon
-from vectordb_bench.frontend.components.custom.displayCustomCase import displayCustomCase
+from vectordb_bench.frontend.components.custom.displayCustomCase import (
+ displayCustomCase,
+)
from vectordb_bench.frontend.components.custom.displaypPrams import displayParams
-from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig, generate_custom_case, get_custom_configs, save_custom_configs
+from vectordb_bench.frontend.components.custom.getCustomConfig import (
+ CustomCaseConfig,
+ generate_custom_case,
+ get_custom_configs,
+ save_custom_configs,
+)
from vectordb_bench.frontend.components.custom.initStyle import initStyle
from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE
-class CustomCaseManager():
+class CustomCaseManager:
customCaseItems: list[CustomCaseConfig]
def __init__(self):
@@ -52,12 +60,25 @@ def main():
columns = expander.columns(8)
columns[0].button(
- "Save", key=f"{key}_", type="secondary", on_click=lambda: customCaseManager.save())
- columns[1].button(":red[Delete]", key=f"{key}_delete", type="secondary",
- on_click=lambda: customCaseManager.deleteCase(idx))
-
- st.button("\+ New Dataset", key=f"add_custom_configs",
- type="primary", on_click=lambda: customCaseManager.addCase())
+ "Save",
+ key=f"{key}_",
+ type="secondary",
+ on_click=lambda: customCaseManager.save(),
+ )
+ columns[1].button(
+ ":red[Delete]",
+ key=f"{key}_delete",
+ type="secondary",
+ # B023
+ on_click=partial(lambda idx: customCaseManager.deleteCase(idx), idx=idx),
+ )
+
+ st.button(
+ "\+ New Dataset",
+ key="add_custom_configs",
+ type="primary",
+ on_click=lambda: customCaseManager.addCase(),
+ )
if __name__ == "__main__":
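
The `# B023` note above refers to ruff's function-defined-in-loop check: a plain `lambda: customCaseManager.deleteCase(idx)` captures `idx` by reference, so every Delete button would act on the last row rendered. Binding the value eagerly — via `functools.partial` as done here, or a default argument — avoids that. A standalone illustration of the difference:

```python
from functools import partial


def make_callbacks_buggy(n: int):
    # Every lambda closes over the same `i`; all of them see its final value (ruff B023).
    return [lambda: i for i in range(n)]


def make_callbacks_fixed(n: int):
    # partial() (or a default argument, `lambda i=i: i`) freezes the value per iteration.
    return [partial(lambda i: i, i) for i in range(n)]


assert [f() for f in make_callbacks_buggy(3)] == [2, 2, 2]
assert [f() for f in make_callbacks_fixed(3)] == [0, 1, 2]
```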
diff --git a/vectordb_bench/frontend/pages/quries_per_dollar.py b/vectordb_bench/frontend/pages/quries_per_dollar.py
index 4a45181de..2822f2864 100644
--- a/vectordb_bench/frontend/pages/quries_per_dollar.py
+++ b/vectordb_bench/frontend/pages/quries_per_dollar.py
@@ -15,8 +15,8 @@
from vectordb_bench.frontend.components.check_results.charts import drawMetricChart
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.config.styles import *
-from vectordb_bench.interface import benchMarkRunner
+
+from vectordb_bench.interface import benchmark_runner
from vectordb_bench.metric import QURIES_PER_DOLLAR_METRIC
@@ -27,7 +27,7 @@ def main():
# header
drawHeaderIcon(st)
- allResults = benchMarkRunner.get_results()
+ allResults = benchmark_runner.get_results()
st.title("Vector DB Benchmark (QP$)")
diff --git a/vectordb_bench/frontend/pages/run_test.py b/vectordb_bench/frontend/pages/run_test.py
index 1297743ae..3da8ea2c0 100644
--- a/vectordb_bench/frontend/pages/run_test.py
+++ b/vectordb_bench/frontend/pages/run_test.py
@@ -15,10 +15,10 @@
def main():
# set page config
initRunTestPageConfig(st)
-
+
# init style
initStyle(st)
-
+
# header
drawHeaderIcon(st)
@@ -48,11 +48,7 @@ def main():
activedCaseList, allCaseConfigs = caseSelector(caseSelectorContainer, activedDbList)
# generate tasks
- tasks = (
- generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs)
- if isAllValid
- else []
- )
+ tasks = generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs) if isAllValid else []
# submit
submitContainer = st.container()
diff --git a/vectordb_bench/frontend/utils.py b/vectordb_bench/frontend/utils.py
index 787b67d03..407dd497d 100644
--- a/vectordb_bench/frontend/utils.py
+++ b/vectordb_bench/frontend/utils.py
@@ -18,5 +18,5 @@ def addHorizontalLine(st):
def generate_random_string(length):
letters = string.ascii_letters + string.digits
- result = ''.join(random.choice(letters) for _ in range(length))
+ result = "".join(random.choice(letters) for _ in range(length))
return result
diff --git a/vectordb_bench/frontend/vdb_benchmark.py b/vectordb_bench/frontend/vdb_benchmark.py
index c76a2f3de..cf261bf1d 100644
--- a/vectordb_bench/frontend/vdb_benchmark.py
+++ b/vectordb_bench/frontend/vdb_benchmark.py
@@ -11,8 +11,8 @@
from vectordb_bench.frontend.components.check_results.charts import drawCharts
from vectordb_bench.frontend.components.check_results.filters import getshownData
from vectordb_bench.frontend.components.get_results.saveAsImage import getResults
-from vectordb_bench.frontend.config.styles import *
-from vectordb_bench.interface import benchMarkRunner
+
+from vectordb_bench.interface import benchmark_runner
def main():
@@ -22,7 +22,7 @@ def main():
# header
drawHeaderIcon(st)
- allResults = benchMarkRunner.get_results()
+ allResults = benchmark_runner.get_results()
st.title("Vector Database Benchmark")
st.caption(
@@ -32,9 +32,7 @@ def main():
# results selector and filter
resultSelectorContainer = st.sidebar.container()
- shownData, failedTasks, showCaseNames = getshownData(
- allResults, resultSelectorContainer
- )
+ shownData, failedTasks, showCaseNames = getshownData(allResults, resultSelectorContainer)
resultSelectorContainer.divider()
diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py
index c765d0d63..615a9600d 100644
--- a/vectordb_bench/interface.py
+++ b/vectordb_bench/interface.py
@@ -5,6 +5,7 @@
import signal
import traceback
import uuid
+from collections.abc import Callable
from enum import Enum
from multiprocessing.connection import Connection
@@ -16,8 +17,15 @@
from .backend.result_collector import ResultCollector
from .backend.task_runner import TaskRunner
from .metric import Metric
-from .models import (CaseResult, LoadTimeoutError, PerformanceTimeoutError,
- ResultLabel, TaskConfig, TaskStage, TestResult)
+from .models import (
+ CaseResult,
+ LoadTimeoutError,
+ PerformanceTimeoutError,
+ ResultLabel,
+ TaskConfig,
+ TaskStage,
+ TestResult,
+)
log = logging.getLogger(__name__)
@@ -37,11 +45,9 @@ def __init__(self):
self.drop_old: bool = True
self.dataset_source: DatasetSource = DatasetSource.S3
-
def set_drop_old(self, drop_old: bool):
self.drop_old = drop_old
-
def set_download_address(self, use_aliyun: bool):
if use_aliyun:
self.dataset_source = DatasetSource.AliyunOSS
@@ -59,7 +65,9 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool:
log.warning("Empty tasks submitted")
return False
- log.debug(f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}")
+ log.debug(
+ f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}",
+ )
# Generate run_id
run_id = uuid.uuid4().hex
@@ -70,7 +78,12 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool:
self.latest_error = ""
try:
- self.running_task = Assembler.assemble_all(run_id, task_label, tasks, self.dataset_source)
+ self.running_task = Assembler.assemble_all(
+ run_id,
+ task_label,
+ tasks,
+ self.dataset_source,
+ )
self.running_task.display()
except ModuleNotFoundError as e:
msg = f"Please install client for database, error={e}"
@@ -119,7 +132,7 @@ def get_tasks_count(self) -> int:
return 0
def get_current_task_id(self) -> int:
- """ the index of current running task
+ """the index of current running task
return -1 if not running
"""
if not self.running_task:
@@ -153,18 +166,18 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non
task_config=runner.config,
)
- # drop_old = False if latest_runner and runner == latest_runner else config.DROP_OLD
- # drop_old = config.DROP_OLD
drop_old = TaskStage.DROP_OLD in runner.config.stages
- if latest_runner and runner == latest_runner:
- drop_old = False
- elif not self.drop_old:
+ if (latest_runner and runner == latest_runner) or not self.drop_old:
drop_old = False
try:
- log.info(f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}")
+ log.info(
+ f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}",
+ )
case_res.metrics = runner.run(drop_old)
- log.info(f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, "
- f"result={case_res.metrics}, label={case_res.label}")
+ log.info(
+ f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, "
+ f"result={case_res.metrics}, label={case_res.label}",
+ )
# cache the latest succeeded runner
latest_runner = runner
@@ -176,12 +189,16 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non
if not drop_old:
case_res.metrics.load_duration = cached_load_duration if cached_load_duration else 0.0
except (LoadTimeoutError, PerformanceTimeoutError) as e:
- log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}")
+ log.warning(
+ f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}",
+ )
case_res.label = ResultLabel.OUTOFRANGE
continue
except Exception as e:
- log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}")
+ log.warning(
+ f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}",
+ )
traceback.print_exc()
case_res.label = ResultLabel.FAILED
continue
@@ -200,10 +217,14 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non
send_conn.send((SIGNAL.SUCCESS, None))
send_conn.close()
- log.info(f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}")
+ log.info(
+ f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}",
+ )
except Exception as e:
- err_msg = f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}"
+ err_msg = (
+ f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}"
+ )
traceback.print_exc()
log.warning(err_msg)
send_conn.send((SIGNAL.ERROR, err_msg))
@@ -226,16 +247,26 @@ def _clear_running_task(self):
self.receive_conn.close()
self.receive_conn = None
-
def _run_async(self, conn: Connection) -> bool:
- log.info(f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, case number: {len(self.running_task.case_runners)}")
+ log.info(
+ f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, ",
+ f"case number: {len(self.running_task.case_runners)}",
+ )
global global_result_future
- executor = concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=mp.get_context("spawn"))
+ executor = concurrent.futures.ProcessPoolExecutor(
+ max_workers=1,
+ mp_context=mp.get_context("spawn"),
+ )
global_result_future = executor.submit(self._async_task_v2, self.running_task, conn)
return True
- def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None):
+ def kill_proc_tree(
+ self,
+ sig: int = signal.SIGTERM,
+ timeout: float | None = None,
+ on_terminate: Callable | None = None,
+ ):
"""Kill a process tree (including grandchildren) with signal
"sig" and return a (gone, still_alive) tuple.
"on_terminate", if specified, is a callback function which is
@@ -248,12 +279,11 @@ def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None):
p.send_signal(sig)
except psutil.NoSuchProcess:
pass
- gone, alive = psutil.wait_procs(children, timeout=timeout,
- callback=on_terminate)
+ gone, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate)
for p in alive:
log.warning(f"force killing child process: {p}")
p.kill()
-benchMarkRunner = BenchMarkRunner()
+benchmark_runner = BenchMarkRunner()
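
A detail from the `task submitted` log call earlier in this hunk: the multi-line form relies on implicit concatenation of adjacent f-strings. Separating them with a comma instead would pass the second string as a `%`-format argument to `log.info`, which mangles the record. A tiny standalone illustration (not repo code):

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
log = logging.getLogger("demo")

run_id, cases = "abc123", 3

# Adjacent f-strings are concatenated into ONE message string:
log.info(
    f"task submitted: id={run_id}, "
    f"case number: {cases}",
)

# With a comma between the two strings, the second becomes a %-format argument.
# Since the first string has no %s placeholder, logging hits
# "TypeError: not all arguments converted during string formatting" in the
# handler and prints "--- Logging error ---" instead of the intended record.
```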
diff --git a/vectordb_bench/log_util.py b/vectordb_bench/log_util.py
index b923bdcd2..d75688137 100644
--- a/vectordb_bench/log_util.py
+++ b/vectordb_bench/log_util.py
@@ -1,102 +1,97 @@
import logging
from logging import config
-def init(log_level):
- LOGGING = {
- 'version': 1,
- 'disable_existing_loggers': False,
- 'formatters': {
- 'default': {
- 'format': '%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)',
+
+def init(log_level: str):
+ log_config = {
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "default": {
+ "format": "%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)",
},
- 'colorful_console': {
- 'format': '%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)',
- '()': ColorfulFormatter,
+ "colorful_console": {
+ "format": "%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)",
+ "()": ColorfulFormatter,
},
},
- 'handlers': {
- 'console': {
- 'class': 'logging.StreamHandler',
- 'formatter': 'colorful_console',
+ "handlers": {
+ "console": {
+ "class": "logging.StreamHandler",
+ "formatter": "colorful_console",
},
- 'no_color_console': {
- 'class': 'logging.StreamHandler',
- 'formatter': 'default',
+ "no_color_console": {
+ "class": "logging.StreamHandler",
+ "formatter": "default",
},
},
- 'loggers': {
- 'vectordb_bench': {
- 'handlers': ['console'],
- 'level': log_level,
- 'propagate': False
+ "loggers": {
+ "vectordb_bench": {
+ "handlers": ["console"],
+ "level": log_level,
+ "propagate": False,
},
- 'no_color': {
- 'handlers': ['no_color_console'],
- 'level': log_level,
- 'propagate': False
+ "no_color": {
+ "handlers": ["no_color_console"],
+ "level": log_level,
+ "propagate": False,
},
},
- 'propagate': False,
+ "propagate": False,
}
- config.dictConfig(LOGGING)
+ config.dictConfig(log_config)
-class colors:
- HEADER= '\033[95m'
- INFO= '\033[92m'
- DEBUG= '\033[94m'
- WARNING= '\033[93m'
- ERROR= '\033[95m'
- CRITICAL= '\033[91m'
- ENDC= '\033[0m'
+class colors:
+ HEADER = "\033[95m"
+ INFO = "\033[92m"
+ DEBUG = "\033[94m"
+ WARNING = "\033[93m"
+ ERROR = "\033[95m"
+ CRITICAL = "\033[91m"
+ ENDC = "\033[0m"
COLORS = {
- 'INFO': colors.INFO,
- 'INFOM': colors.INFO,
- 'DEBUG': colors.DEBUG,
- 'DEBUGM': colors.DEBUG,
- 'WARNING': colors.WARNING,
- 'WARNINGM': colors.WARNING,
- 'CRITICAL': colors.CRITICAL,
- 'CRITICALM': colors.CRITICAL,
- 'ERROR': colors.ERROR,
- 'ERRORM': colors.ERROR,
- 'ENDC': colors.ENDC,
+ "INFO": colors.INFO,
+ "INFOM": colors.INFO,
+ "DEBUG": colors.DEBUG,
+ "DEBUGM": colors.DEBUG,
+ "WARNING": colors.WARNING,
+ "WARNINGM": colors.WARNING,
+ "CRITICAL": colors.CRITICAL,
+ "CRITICALM": colors.CRITICAL,
+ "ERROR": colors.ERROR,
+ "ERRORM": colors.ERROR,
+ "ENDC": colors.ENDC,
}
class ColorFulFormatColMixin:
- def format_col(self, message_str, level_name):
- if level_name in COLORS.keys():
- message_str = COLORS[level_name] + message_str + COLORS['ENDC']
- return message_str
-
- def formatTime(self, record, datefmt=None):
- ret = super().formatTime(record, datefmt)
- return ret
+ def format_col(self, message: str, level_name: str):
+ if level_name in COLORS:
+ message = COLORS[level_name] + message + COLORS["ENDC"]
+ return message
class ColorfulLogRecordProxy(logging.LogRecord):
- def __init__(self, record):
+    def __init__(self, record: logging.LogRecord):
self._record = record
- msg_level = record.levelname + 'M'
+ msg_level = record.levelname + "M"
self.msg = f"{COLORS[msg_level]}{record.msg}{COLORS['ENDC']}"
self.filename = record.filename
- self.lineno = f'{record.lineno}'
- self.process = f'{record.process}'
+ self.lineno = f"{record.lineno}"
+ self.process = f"{record.process}"
self.levelname = f"{COLORS[record.levelname]}{record.levelname}{COLORS['ENDC']}"
- def __getattr__(self, attr):
+    def __getattr__(self, attr: str):
if attr not in self.__dict__:
return getattr(self._record, attr)
return getattr(self, attr)
class ColorfulFormatter(ColorFulFormatColMixin, logging.Formatter):
- def format(self, record):
+    def format(self, record: logging.LogRecord):
proxy = ColorfulLogRecordProxy(record)
- message_str = super().format(proxy)
-
- return message_str
+ return super().format(proxy)
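
The `"()"` key used by the `colorful_console` formatter is the standard `dictConfig` hook for supplying a custom factory instead of `logging.Formatter`. A minimal standalone config using the same mechanism (names here are illustrative, not from this repo):

```python
import logging
from logging import config


class UpperFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        # Post-process the fully formatted line; same idea as ColorfulFormatter above.
        return super().format(record).upper()


config.dictConfig({
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        # "()" tells dictConfig to call this factory with the remaining keys as kwargs.
        "upper": {"()": UpperFormatter, "format": "%(levelname)s | %(message)s"},
    },
    "handlers": {"console": {"class": "logging.StreamHandler", "formatter": "upper"}},
    "loggers": {"demo": {"handlers": ["console"], "level": "INFO", "propagate": False}},
})

logging.getLogger("demo").info("hello")  # -> INFO | HELLO
```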
diff --git a/vectordb_bench/metric.py b/vectordb_bench/metric.py
index 9f083a5c6..e0b6cff0e 100644
--- a/vectordb_bench/metric.py
+++ b/vectordb_bench/metric.py
@@ -1,8 +1,7 @@
import logging
-import numpy as np
-
from dataclasses import dataclass, field
+import numpy as np
log = logging.getLogger(__name__)
@@ -33,19 +32,19 @@ class Metric:
QPS_METRIC = "qps"
RECALL_METRIC = "recall"
-metricUnitMap = {
+metric_unit_map = {
LOAD_DURATION_METRIC: "s",
SERIAL_LATENCY_P99_METRIC: "ms",
MAX_LOAD_COUNT_METRIC: "K",
QURIES_PER_DOLLAR_METRIC: "K",
}
-lowerIsBetterMetricList = [
+lower_is_better_metrics = [
LOAD_DURATION_METRIC,
SERIAL_LATENCY_P99_METRIC,
]
-metricOrder = [
+metric_order = [
QPS_METRIC,
RECALL_METRIC,
LOAD_DURATION_METRIC,
@@ -55,7 +54,7 @@ class Metric:
def isLowerIsBetterMetric(metric: str) -> bool:
- return metric in lowerIsBetterMetricList
+ return metric in lower_is_better_metrics
def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float:
@@ -70,7 +69,7 @@ def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float:
def get_ideal_dcg(k: int):
ideal_dcg = 0
for i in range(k):
- ideal_dcg += 1 / np.log2(i+2)
+ ideal_dcg += 1 / np.log2(i + 2)
return ideal_dcg
@@ -78,8 +77,8 @@ def get_ideal_dcg(k: int):
def calc_ndcg(ground_truth: list[int], got: list[int], ideal_dcg: float) -> float:
dcg = 0
ground_truth = list(ground_truth)
- for id in set(got):
- if id in ground_truth:
- idx = ground_truth.index(id)
- dcg += 1 / np.log2(idx+2)
+ for got_id in set(got):
+ if got_id in ground_truth:
+ idx = ground_truth.index(got_id)
+ dcg += 1 / np.log2(idx + 2)
return dcg / ideal_dcg
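
Besides the formatting, this hunk renames the loop variable `id` to `got_id` so it no longer shadows the builtin. A quick worked example of the NDCG helpers whose bodies are fully visible above (`calc_recall` is not exercised here because its body is outside this hunk); the expected value is computed by hand in the comments:

```python
from vectordb_bench.metric import calc_ndcg, get_ideal_dcg

ground_truth = [7, 3, 9, 1]  # ids ranked by true similarity
got = [3, 5, 7, 2]           # ids returned by the database

# hits: id 7 at true rank 0, id 3 at true rank 1
# dcg  = 1/log2(2) + 1/log2(3)                         ~= 1.6309
# idcg = 1/log2(2) + 1/log2(3) + 1/log2(4) + 1/log2(5) ~= 2.5616
ndcg = calc_ndcg(ground_truth, got, get_ideal_dcg(4))
assert abs(ndcg - 1.6309 / 2.5616) < 1e-3
```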
diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py
index 648fb1727..49bb04ae0 100644
--- a/vectordb_bench/models.py
+++ b/vectordb_bench/models.py
@@ -2,29 +2,31 @@
import pathlib
from datetime import date, datetime
from enum import Enum, StrEnum, auto
-from typing import List, Self
+from typing import Self
import ujson
+from . import config
+from .backend.cases import CaseType
from .backend.clients import (
DB,
- DBConfig,
DBCaseConfig,
+ DBConfig,
)
-from .backend.cases import CaseType
from .base import BaseModel
-from . import config
from .metric import Metric
log = logging.getLogger(__name__)
class LoadTimeoutError(TimeoutError):
- pass
+ def __init__(self, duration: int):
+ super().__init__(f"capacity case load timeout in {duration}s")
class PerformanceTimeoutError(TimeoutError):
- pass
+ def __init__(self):
+ super().__init__("Performance case optimize timeout")
class CaseConfigParamType(Enum):
@@ -92,7 +94,7 @@ class CustomizedCase(BaseModel):
class ConcurrencySearchConfig(BaseModel):
- num_concurrency: List[int] = config.NUM_CONCURRENCY
+ num_concurrency: list[int] = config.NUM_CONCURRENCY
concurrency_duration: int = config.CONCURRENCY_DURATION
@@ -146,7 +148,7 @@ class TaskConfig(BaseModel):
db_config: DBConfig
db_case_config: DBCaseConfig
case_config: CaseConfig
- stages: List[TaskStage] = ALL_TASK_STAGES
+ stages: list[TaskStage] = ALL_TASK_STAGES
@property
def db_name(self):
@@ -210,26 +212,23 @@ def write_db_file(self, result_dir: pathlib.Path, partial: Self, db: str):
log.info(f"local result directory not exist, creating it: {result_dir}")
result_dir.mkdir(parents=True)
- file_name = self.file_fmt.format(
- date.today().strftime("%Y%m%d"), partial.task_label, db
- )
+ file_name = self.file_fmt.format(date.today().strftime("%Y%m%d"), partial.task_label, db)
result_file = result_dir.joinpath(file_name)
if result_file.exists():
- log.warning(
- f"Replacing existing result with the same file_name: {result_file}"
- )
+ log.warning(f"Replacing existing result with the same file_name: {result_file}")
log.info(f"write results to disk {result_file}")
- with open(result_file, "w") as f:
+ with pathlib.Path(result_file).open("w") as f:
b = partial.json(exclude={"db_config": {"password", "api_key"}})
f.write(b)
@classmethod
def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self:
if not full_path.exists():
- raise ValueError(f"No such file: {full_path}")
+ msg = f"No such file: {full_path}"
+ raise ValueError(msg)
- with open(full_path) as f:
+ with pathlib.Path(full_path).open("r") as f:
test_result = ujson.loads(f.read())
if "task_label" not in test_result:
test_result["task_label"] = test_result["run_id"]
@@ -248,19 +247,16 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self:
if trans_unit:
cur_max_count = case_result["metrics"]["max_load_count"]
case_result["metrics"]["max_load_count"] = (
- cur_max_count / 1000
- if int(cur_max_count) > 0
- else cur_max_count
+ cur_max_count / 1000 if int(cur_max_count) > 0 else cur_max_count
)
cur_latency = case_result["metrics"]["serial_latency_p99"]
case_result["metrics"]["serial_latency_p99"] = (
cur_latency * 1000 if cur_latency > 0 else cur_latency
)
- c = TestResult.validate(test_result)
-
- return c
+ return TestResult.validate(test_result)
+ # ruff: noqa
def display(self, dbs: list[DB] | None = None):
filter_list = dbs if dbs and isinstance(dbs, list) else None
sorted_results = sorted(
@@ -273,31 +269,18 @@ def display(self, dbs: list[DB] | None = None):
reverse=True,
)
- filtered_results = [
- r
- for r in sorted_results
- if not filter_list or r.task_config.db not in filter_list
- ]
+ filtered_results = [r for r in sorted_results if not filter_list or r.task_config.db not in filter_list]
- def append_return(x, y):
+ def append_return(x: any, y: any):
x.append(y)
return x
max_db = max(map(len, [f.task_config.db.name for f in filtered_results]))
- max_db_labels = (
- max(map(len, [f.task_config.db_config.db_label for f in filtered_results]))
- + 3
- )
- max_case = max(
- map(len, [f.task_config.case_config.case_id.name for f in filtered_results])
- )
- max_load_dur = (
- max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
- )
+ max_db_labels = max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) + 3
+ max_case = max(map(len, [f.task_config.case_config.case_id.name for f in filtered_results]))
+ max_load_dur = max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3
max_qps = max(map(len, [str(f.metrics.qps) for f in filtered_results])) + 3
- max_recall = (
- max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3
- )
+ max_recall = max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3
max_db_labels = 8 if max_db_labels < 8 else max_db_labels
max_load_dur = 11 if max_load_dur < 11 else max_load_dur
@@ -356,7 +339,7 @@ def append_return(x, y):
f.metrics.recall,
f.metrics.max_load_count,
f.label.value,
- )
+ ),
)
tmp_logger = logging.getLogger("no_color")