From b2394122d51fd2ccb6923855d5921c39c5ce4a8d Mon Sep 17 00:00:00 2001
From: yangxuan
Date: Wed, 8 Jan 2025 18:39:35 +0800
Subject: [PATCH] enhance: Refine the coding style and enable lint-action

Signed-off-by: yangxuan
---
 .github/workflows/pull_request.yml | 4 +
 .ruff.toml | 49 -----
 Makefile | 8 +
 README.md | 6 +-
 pyproject.toml | 114 ++++++++++
 vectordb_bench/__init__.py | 73 ++++---
 vectordb_bench/__main__.py | 7 +-
 vectordb_bench/backend/assembler.py | 25 ++-
 vectordb_bench/backend/cases.py | 100 +++++----
 vectordb_bench/backend/clients/__init__.py | 89 ++++++--
 .../aliyun_elasticsearch.py | 3 +-
 .../clients/aliyun_elasticsearch/config.py | 7 +-
 .../aliyun_opensearch/aliyun_opensearch.py | 181 +++++++++-------
 .../clients/aliyun_opensearch/config.py | 13 +-
 .../backend/clients/alloydb/alloydb.py | 138 ++++++-------
 vectordb_bench/backend/clients/alloydb/cli.py | 85 +++++---
 .../backend/clients/alloydb/config.py | 60 +++---
 vectordb_bench/backend/clients/api.py | 14 +-
 .../clients/aws_opensearch/aws_opensearch.py | 90 ++++----
 .../backend/clients/aws_opensearch/cli.py | 11 +-
 .../backend/clients/aws_opensearch/config.py | 22 +-
 .../backend/clients/aws_opensearch/run.py | 128 ++++++------
 .../backend/clients/chroma/chroma.py | 74 +++----
 .../backend/clients/chroma/config.py | 6 +-
 .../backend/clients/elastic_cloud/config.py | 10 +-
 .../clients/elastic_cloud/elastic_cloud.py | 45 ++--
 .../backend/clients/memorydb/cli.py | 16 +-
 .../backend/clients/memorydb/config.py | 4 +-
 .../backend/clients/memorydb/memorydb.py | 118 ++++++-----
 vectordb_bench/backend/clients/milvus/cli.py | 124 ++++-------
 .../backend/clients/milvus/config.py | 26 ++-
 .../backend/clients/milvus/milvus.py | 37 ++--
 .../backend/clients/pgdiskann/cli.py | 51 +++--
 .../backend/clients/pgdiskann/config.py | 55 ++---
 .../backend/clients/pgdiskann/pgdiskann.py | 128 +++++-------
 .../backend/clients/pgvecto_rs/cli.py | 20 +-
 .../backend/clients/pgvecto_rs/config.py | 22 +-
 .../backend/clients/pgvecto_rs/pgvecto_rs.py | 67 +++---
 .../backend/clients/pgvector/cli.py | 77 ++++---
 .../backend/clients/pgvector/config.py | 136 ++++++------
 .../backend/clients/pgvector/pgvector.py | 195 +++++++++---------
 .../backend/clients/pgvectorscale/cli.py | 62 +++---
 .../backend/clients/pgvectorscale/config.py | 29 ++-
 .../clients/pgvectorscale/pgvectorscale.py | 81 ++++----
 .../backend/clients/pinecone/config.py | 1 +
 .../backend/clients/pinecone/pinecone.py | 35 ++--
 .../backend/clients/qdrant_cloud/config.py | 21 +-
 .../clients/qdrant_cloud/qdrant_cloud.py | 71 ++++---
 vectordb_bench/backend/clients/redis/cli.py | 18 +-
 .../backend/clients/redis/config.py | 12 +-
 vectordb_bench/backend/clients/redis/redis.py | 143 ++++++++-----
 vectordb_bench/backend/clients/test/cli.py | 3 +-
 vectordb_bench/backend/clients/test/config.py | 4 +-
 vectordb_bench/backend/clients/test/test.py | 9 +-
 .../backend/clients/weaviate_cloud/cli.py | 7 +-
 .../backend/clients/weaviate_cloud/config.py | 4 +-
 .../clients/weaviate_cloud/weaviate_cloud.py | 58 ++++--
 .../backend/clients/zilliz_cloud/cli.py | 25 ++-
 .../backend/clients/zilliz_cloud/config.py | 6 +-
 .../clients/zilliz_cloud/zilliz_cloud.py | 2 +-
 vectordb_bench/backend/data_source.py | 48 +++--
 vectordb_bench/backend/dataset.py | 74 ++++---
 vectordb_bench/backend/result_collector.py | 5 +-
 vectordb_bench/backend/runner/__init__.py | 10 +-
 vectordb_bench/backend/runner/mp_runner.py | 119 ++++++++---
 vectordb_bench/backend/runner/rate_runner.py | 49 +++--
 .../backend/runner/read_write_runner.py | 74 ++++---
.../backend/runner/serial_runner.py | 139 ++++++++----- vectordb_bench/backend/runner/util.py | 7 +- vectordb_bench/backend/task_runner.py | 164 ++++++++------- vectordb_bench/backend/utils.py | 27 ++- vectordb_bench/base.py | 1 - vectordb_bench/cli/cli.py | 125 +++++------ vectordb_bench/cli/vectordbbench.py | 13 +- .../components/check_results/charts.py | 27 +-- .../frontend/components/check_results/data.py | 20 +- .../components/check_results/filters.py | 24 +-- .../frontend/components/check_results/nav.py | 8 +- .../components/check_results/priceTable.py | 4 +- .../components/check_results/stPageConfig.py | 3 +- .../frontend/components/concurrent/charts.py | 24 +-- .../components/custom/displayCustomCase.py | 28 ++- .../components/custom/displaypPrams.py | 6 +- .../components/custom/getCustomConfig.py | 3 +- .../frontend/components/custom/initStyle.py | 2 +- .../components/get_results/saveAsImage.py | 2 + .../components/run_test/caseSelector.py | 12 +- .../components/run_test/dbConfigSetting.py | 5 +- .../components/run_test/dbSelector.py | 2 +- .../components/run_test/generateTasks.py | 16 +- .../components/run_test/submitTask.py | 32 ++- .../frontend/components/tables/data.py | 9 +- .../frontend/config/dbCaseConfigs.py | 135 +++++------- vectordb_bench/frontend/pages/concurrent.py | 8 +- vectordb_bench/frontend/pages/custom.py | 39 +++- .../frontend/pages/quries_per_dollar.py | 6 +- vectordb_bench/frontend/pages/run_test.py | 10 +- vectordb_bench/frontend/utils.py | 2 +- vectordb_bench/frontend/vdb_benchmark.py | 10 +- vectordb_bench/interface.py | 82 +++++--- vectordb_bench/log_util.py | 123 ++++++----- vectordb_bench/metric.py | 21 +- vectordb_bench/models.py | 69 +++---- 103 files changed, 2490 insertions(+), 2126 deletions(-) delete mode 100644 .ruff.toml diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml index 7c133cd5f..ea346dcd0 100644 --- a/.github/workflows/pull_request.yml +++ b/.github/workflows/pull_request.yml @@ -31,6 +31,10 @@ jobs: python -m pip install --upgrade pip pip install -e ".[test]" + - name: Run coding checks + run: | + make lint + - name: Test with pytest run: | make unittest diff --git a/.ruff.toml b/.ruff.toml deleted file mode 100644 index fb6b19c3d..000000000 --- a/.ruff.toml +++ /dev/null @@ -1,49 +0,0 @@ -# Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -# Enable flake8-bugbear (`B`) rules. -select = ["E", "F", "B"] -ignore = [ - "E501", # (line length violations) -] - -# Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", - "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", - "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT", -] -unfixable = [] - -# Exclude a variety of commonly ignored directories. -exclude = [ - ".bzr", - ".direnv", - ".eggs", - ".git", - ".hg", - ".mypy_cache", - ".nox", - ".pants.d", - ".pytype", - ".ruff_cache", - ".svn", - ".tox", - ".venv", - "__pypackages__", - "_build", - "buck-out", - "build", - "dist", - "node_modules", - "venv", - "__pycache__", - "__init__.py", -] - -# Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -# Assume Python 3.11. -target-version = "py311" - -[mccabe] -# Unlike Flake8, default to a complexity level of 10. 
-max-complexity = 10 diff --git a/Makefile b/Makefile index 562615f6d..ef8207c55 100644 --- a/Makefile +++ b/Makefile @@ -1,2 +1,10 @@ unittest: PYTHONPATH=`pwd` python3 -m pytest tests/test_dataset.py::TestDataSet::test_download_small -svv + +format: + PYTHONPATH=`pwd` python3 -m black vectordb_bench + PYTHONPATH=`pwd` python3 -m ruff check vectordb_bench --fix + +lint: + PYTHONPATH=`pwd` python3 -m black vectordb_bench --check + PYTHONPATH=`pwd` python3 -m ruff check vectordb_bench diff --git a/README.md b/README.md index 56d54c49d..737fc6064 100644 --- a/README.md +++ b/README.md @@ -240,13 +240,13 @@ After reopen the repository in container, run `python -m vectordb_bench` in the ### Check coding styles ```shell -$ ruff check vectordb_bench +$ make lint ``` -Add `--fix` if you want to fix the coding styles automatically +To fix the coding styles automatically ```shell -$ ruff check vectordb_bench --fix +$ make format ``` ## How does it work? diff --git a/pyproject.toml b/pyproject.toml index 8363aa4fa..312940634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dynamic = ["version"] [project.optional-dependencies] test = [ + "black", "ruff", "pytest", ] @@ -93,3 +94,116 @@ init_bench = "vectordb_bench.__main__:main" vectordbbench = "vectordb_bench.cli.vectordbbench:cli" [tool.setuptools_scm] + +[tool.black] +line-length = 120 +target-version = ['py311'] +include = '\.pyi?$' + +[tool.ruff] +lint.select = [ + "E", + "F", + "C90", + "I", + "N", + "B", "C", "G", + "A", + "ANN001", + "S", "T", "W", "ARG", "BLE", "COM", "DJ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT" +] +lint.ignore = [ + "BLE001", # blind-except (BLE001) + "SLF001", # SLF001 Private member accessed [E] + "TRY003", # [ruff] TRY003 Avoid specifying long messages outside the exception class [E] + "FBT001", "FBT002", "FBT003", + "G004", # [ruff] G004 Logging statement uses f-string [E] + "UP031", + "RUF012", + "EM101", + "N805", + "ARG002", + "ARG003", + "PIE796", # https://github.com/zilliztech/VectorDBBench/issues/438 + "INP001", # TODO + "TID252", # TODO + "N801", "N802", "N815", + "S101", "S108", "S603", "S311", + "PLR2004", + "RUF017", + "C416", + "PLW0603", +] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +lint.fixable = [ + "A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", + "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", + "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", + "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", + "YTT", +] +lint.unfixable = [] + +show-fixes = true + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "grpc_gen", + "__pycache__", + "frontend", # TODO + "tests", +] + +# Same as Black. +line-length = 120 + +# Allow unused variables when underscore-prefixed. +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +# Assume Python 3.11 +target-version = "py311" + +[tool.ruff.lint.mccabe] +# Unlike Flake8, default to a complexity level of 10. 
+max-complexity = 18 + +[tool.ruff.lint.pycodestyle] +max-line-length = 120 +max-doc-length = 120 + +[tool.ruff.lint.pylint] +max-args = 20 +max-branches = 15 + +[tool.ruff.lint.flake8-builtins] +builtins-ignorelist = [ + # "format", + # "next", + # "object", # TODO + # "id", + # "dict", # TODO + # "filter", +] + diff --git a/vectordb_bench/__init__.py b/vectordb_bench/__init__.py index 568e97705..c07fc855d 100644 --- a/vectordb_bench/__init__.py +++ b/vectordb_bench/__init__.py @@ -22,46 +22,71 @@ class config: DROP_OLD = env.bool("DROP_OLD", True) USE_SHUFFLED_DATA = env.bool("USE_SHUFFLED_DATA", True) - NUM_CONCURRENCY = env.list("NUM_CONCURRENCY", [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90, 95, 100], subcast=int ) + NUM_CONCURRENCY = env.list( + "NUM_CONCURRENCY", + [ + 1, + 5, + 10, + 15, + 20, + 25, + 30, + 35, + 40, + 45, + 50, + 55, + 60, + 65, + 70, + 75, + 80, + 85, + 90, + 95, + 100, + ], + subcast=int, + ) CONCURRENCY_DURATION = 30 RESULTS_LOCAL_DIR = env.path( - "RESULTS_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("results") + "RESULTS_LOCAL_DIR", + pathlib.Path(__file__).parent.joinpath("results"), ) CONFIG_LOCAL_DIR = env.path( - "CONFIG_LOCAL_DIR", pathlib.Path(__file__).parent.joinpath("config-files") + "CONFIG_LOCAL_DIR", + pathlib.Path(__file__).parent.joinpath("config-files"), ) - K_DEFAULT = 100 # default return top k nearest neighbors during search CUSTOM_CONFIG_DIR = pathlib.Path(__file__).parent.joinpath("custom/custom_case.json") - CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h - LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h - LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h - LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d - LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d + CAPACITY_TIMEOUT_IN_SECONDS = 24 * 3600 # 24h + LOAD_TIMEOUT_DEFAULT = 24 * 3600 # 24h + LOAD_TIMEOUT_768D_1M = 24 * 3600 # 24h + LOAD_TIMEOUT_768D_10M = 240 * 3600 # 10d + LOAD_TIMEOUT_768D_100M = 2400 * 3600 # 100d - LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h - LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d + LOAD_TIMEOUT_1536D_500K = 24 * 3600 # 24h + LOAD_TIMEOUT_1536D_5M = 240 * 3600 # 10d - OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h - OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h - OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d - OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d + OPTIMIZE_TIMEOUT_DEFAULT = 24 * 3600 # 24h + OPTIMIZE_TIMEOUT_768D_1M = 24 * 3600 # 24h + OPTIMIZE_TIMEOUT_768D_10M = 240 * 3600 # 10d + OPTIMIZE_TIMEOUT_768D_100M = 2400 * 3600 # 100d + OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h + OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d - OPTIMIZE_TIMEOUT_1536D_500K = 24 * 3600 # 24h - OPTIMIZE_TIMEOUT_1536D_5M = 240 * 3600 # 10d - def display(self) -> str: - tmp = [ - i for i in inspect.getmembers(self) - if not inspect.ismethod(i[1]) - and not i[0].startswith('_') - and "TIMEOUT" not in i[0] + return [ + i + for i in inspect.getmembers(self) + if not inspect.ismethod(i[1]) and not i[0].startswith("_") and "TIMEOUT" not in i[0] ] - return tmp + log_util.init(config.LOG_LEVEL) diff --git a/vectordb_bench/__main__.py b/vectordb_bench/__main__.py index a8b4436bc..6663731f5 100644 --- a/vectordb_bench/__main__.py +++ b/vectordb_bench/__main__.py @@ -1,7 +1,8 @@ -import traceback import logging +import pathlib import subprocess -import os +import traceback + from . 
import config log = logging.getLogger("vectordb_bench") @@ -16,7 +17,7 @@ def run_streamlit(): cmd = [ "streamlit", "run", - f"{os.path.dirname(__file__)}/frontend/vdb_benchmark.py", + f"{pathlib.Path(__file__).parent}/frontend/vdb_benchmark.py", "--logger.level", "info", "--theme.base", diff --git a/vectordb_bench/backend/assembler.py b/vectordb_bench/backend/assembler.py index e7da4d49f..b81d315c2 100644 --- a/vectordb_bench/backend/assembler.py +++ b/vectordb_bench/backend/assembler.py @@ -1,24 +1,25 @@ -from .cases import CaseLabel -from .task_runner import CaseRunner, RunningStatus, TaskRunner -from ..models import TaskConfig -from ..backend.clients import EmptyDBCaseConfig -from ..backend.data_source import DatasetSource import logging +from vectordb_bench.backend.clients import EmptyDBCaseConfig +from vectordb_bench.backend.data_source import DatasetSource +from vectordb_bench.models import TaskConfig + +from .cases import CaseLabel +from .task_runner import CaseRunner, RunningStatus, TaskRunner log = logging.getLogger(__name__) class Assembler: @classmethod - def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunner: + def assemble(cls, run_id: str, task: TaskConfig, source: DatasetSource) -> CaseRunner: c_cls = task.case_config.case_id.case_cls c = c_cls(task.case_config.custom_case) - if type(task.db_case_config) != EmptyDBCaseConfig: + if type(task.db_case_config) is not EmptyDBCaseConfig: task.db_case_config.metric_type = c.dataset.data.metric_type - runner = CaseRunner( + return CaseRunner( run_id=run_id, config=task, ca=c, @@ -26,8 +27,6 @@ def assemble(cls, run_id , task: TaskConfig, source: DatasetSource) -> CaseRunne dataset_source=source, ) - return runner - @classmethod def assemble_all( cls, @@ -50,12 +49,12 @@ def assemble_all( db2runner[db].append(r) # check dbclient installed - for k in db2runner.keys(): + for k in db2runner: _ = k.init_cls # sort by dataset size - for k in db2runner.keys(): - db2runner[k].sort(key=lambda x:x.ca.dataset.data.size) + for k, _ in db2runner: + db2runner[k].sort(key=lambda x: x.ca.dataset.data.size) all_runners = [] all_runners.extend(load_runners) diff --git a/vectordb_bench/backend/cases.py b/vectordb_bench/backend/cases.py index 8b643b66b..15fc069cc 100644 --- a/vectordb_bench/backend/cases.py +++ b/vectordb_bench/backend/cases.py @@ -1,7 +1,5 @@ -import typing import logging from enum import Enum, auto -from typing import Type from vectordb_bench import config from vectordb_bench.backend.clients.api import MetricType @@ -12,7 +10,6 @@ from .dataset import CustomDataset, Dataset, DatasetManager - log = logging.getLogger(__name__) @@ -50,11 +47,10 @@ class CaseType(Enum): Custom = 100 PerformanceCustomDataset = 101 - def case_cls(self, custom_configs: dict | None = None) -> Type["Case"]: + def case_cls(self, custom_configs: dict | None = None) -> type["Case"]: if custom_configs is None: return type2case.get(self)() - else: - return type2case.get(self)(**custom_configs) + return type2case.get(self)(**custom_configs) def case_name(self, custom_configs: dict | None = None) -> str: c = self.case_cls(custom_configs) @@ -99,10 +95,10 @@ class Case(BaseModel): @property def filters(self) -> dict | None: if self.filter_rate is not None: - ID = round(self.filter_rate * self.dataset.data.size) + target_id = round(self.filter_rate * self.dataset.data.size) return { - "metadata": f">={ID}", - "id": ID, + "metadata": f">={target_id}", + "id": target_id, } return None @@ -126,8 +122,8 @@ class CapacityDim960(CapacityCase): 
case_id: CaseType = CaseType.CapacityDim960 dataset: DatasetManager = Dataset.GIST.manager(100_000) name: str = "Capacity Test (960 Dim Repeated)" - description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension - vectors (GIST 100K vectors, 960 dimensions) until it is fully loaded. Number of inserted vectors will be + description: str = """This case tests the vector database's loading capacity by repeatedly inserting large-dimension + vectors (GIST 100K vectors, 960 dimensions) until it is fully loaded. Number of inserted vectors will be reported.""" @@ -136,7 +132,7 @@ class CapacityDim128(CapacityCase): dataset: DatasetManager = Dataset.SIFT.manager(500_000) name: str = "Capacity Test (128 Dim Repeated)" description: str = """This case tests the vector database's loading capacity by repeatedly inserting small-dimension - vectors (SIFT 100K vectors, 128 dimensions) until it is fully loaded. Number of inserted vectors will be + vectors (SIFT 100K vectors, 128 dimensions) until it is fully loaded. Number of inserted vectors will be reported.""" @@ -144,8 +140,9 @@ class Performance768D10M(PerformanceCase): case_id: CaseType = CaseType.Performance768D10M dataset: DatasetManager = Dataset.COHERE.manager(10_000_000) name: str = "Search Performance Test (10M Dataset, 768 Dim)" - description: str = """This case tests the search performance of a vector database with a large dataset (Cohere 10M vectors, 768 dimensions) at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a large dataset + (Cohere 10M vectors, 768 dimensions) at varying parallel levels. + Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_10M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_10M @@ -154,8 +151,9 @@ class Performance768D1M(PerformanceCase): case_id: CaseType = CaseType.Performance768D1M dataset: DatasetManager = Dataset.COHERE.manager(1_000_000) name: str = "Search Performance Test (1M Dataset, 768 Dim)" - description: str = """This case tests the search performance of a vector database with a medium dataset (Cohere 1M vectors, 768 dimensions) at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium dataset + (Cohere 1M vectors, 768 dimensions) at varying parallel levels. + Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_1M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_1M @@ -165,8 +163,9 @@ class Performance768D10M1P(PerformanceCase): filter_rate: float | int | None = 0.01 dataset: DatasetManager = Dataset.COHERE.manager(10_000_000) name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 1%)" - description: str = """This case tests the search performance of a vector database with a large dataset (Cohere 10M vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a large dataset + (Cohere 10M vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel + levels. 
Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_10M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_10M @@ -176,8 +175,9 @@ class Performance768D1M1P(PerformanceCase): filter_rate: float | int | None = 0.01 dataset: DatasetManager = Dataset.COHERE.manager(1_000_000) name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 1%)" - description: str = """This case tests the search performance of a vector database with a medium dataset (Cohere 1M vectors, 768 dimensions) under a low filtering rate (1% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium dataset + (Cohere 1M vectors, 768 dimensions) under a low filtering rate (1% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_1M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_1M @@ -187,8 +187,9 @@ class Performance768D10M99P(PerformanceCase): filter_rate: float | int | None = 0.99 dataset: DatasetManager = Dataset.COHERE.manager(10_000_000) name: str = "Filtering Search Performance Test (10M Dataset, 768 Dim, Filter 99%)" - description: str = """This case tests the search performance of a vector database with a large dataset (Cohere 10M vectors, 768 dimensions) under a high filtering rate (99% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a large dataset + (Cohere 10M vectors, 768 dimensions) under a high filtering rate (99% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_10M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_10M @@ -198,8 +199,9 @@ class Performance768D1M99P(PerformanceCase): filter_rate: float | int | None = 0.99 dataset: DatasetManager = Dataset.COHERE.manager(1_000_000) name: str = "Filtering Search Performance Test (1M Dataset, 768 Dim, Filter 99%)" - description: str = """This case tests the search performance of a vector database with a medium dataset (Cohere 1M vectors, 768 dimensions) under a high filtering rate (99% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium dataset + (Cohere 1M vectors, 768 dimensions) under a high filtering rate (99% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_1M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_1M @@ -209,8 +211,9 @@ class Performance768D100M(PerformanceCase): filter_rate: float | int | None = None dataset: DatasetManager = Dataset.LAION.manager(100_000_000) name: str = "Search Performance Test (100M Dataset, 768 Dim)" - description: str = """This case tests the search performance of a vector database with a large 100M dataset (LAION 100M vectors, 768 dimensions), at varying parallel levels. 
-Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a large 100M dataset + (LAION 100M vectors, 768 dimensions), at varying parallel levels. Results will show index building time, + recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_768D_100M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_768D_100M @@ -220,8 +223,9 @@ class Performance1536D500K(PerformanceCase): filter_rate: float | int | None = None dataset: DatasetManager = Dataset.OPENAI.manager(500_000) name: str = "Search Performance Test (500K Dataset, 1536 Dim)" - description: str = """This case tests the search performance of a vector database with a medium 500K dataset (OpenAI 500K vectors, 1536 dimensions), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium 500K dataset + (OpenAI 500K vectors, 1536 dimensions), at varying parallel levels. Results will show index building time, + recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K @@ -231,8 +235,9 @@ class Performance1536D5M(PerformanceCase): filter_rate: float | int | None = None dataset: DatasetManager = Dataset.OPENAI.manager(5_000_000) name: str = "Search Performance Test (5M Dataset, 1536 Dim)" - description: str = """This case tests the search performance of a vector database with a medium 5M dataset (OpenAI 5M vectors, 1536 dimensions), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium 5M dataset + (OpenAI 5M vectors, 1536 dimensions), at varying parallel levels. Results will show index building time, + recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M @@ -242,8 +247,9 @@ class Performance1536D500K1P(PerformanceCase): filter_rate: float | int | None = 0.01 dataset: DatasetManager = Dataset.OPENAI.manager(500_000) name: str = "Filtering Search Performance Test (500K Dataset, 1536 Dim, Filter 1%)" - description: str = """This case tests the search performance of a vector database with a large dataset (OpenAI 500K vectors, 1536 dimensions) under a low filtering rate (1% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a large dataset + (OpenAI 500K vectors, 1536 dimensions) under a low filtering rate (1% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K @@ -253,8 +259,9 @@ class Performance1536D5M1P(PerformanceCase): filter_rate: float | int | None = 0.01 dataset: DatasetManager = Dataset.OPENAI.manager(5_000_000) name: str = "Filtering Search Performance Test (5M Dataset, 1536 Dim, Filter 1%)" - description: str = """This case tests the search performance of a vector database with a large dataset (OpenAI 5M vectors, 1536 dimensions) under a low filtering rate (1% vectors), at varying parallel levels. 
-Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a large dataset + (OpenAI 5M vectors, 1536 dimensions) under a low filtering rate (1% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M @@ -264,8 +271,9 @@ class Performance1536D500K99P(PerformanceCase): filter_rate: float | int | None = 0.99 dataset: DatasetManager = Dataset.OPENAI.manager(500_000) name: str = "Filtering Search Performance Test (500K Dataset, 1536 Dim, Filter 99%)" - description: str = """This case tests the search performance of a vector database with a medium dataset (OpenAI 500K vectors, 1536 dimensions) under a high filtering rate (99% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium dataset + (OpenAI 500K vectors, 1536 dimensions) under a high filtering rate (99% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_1536D_500K optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_500K @@ -275,8 +283,9 @@ class Performance1536D5M99P(PerformanceCase): filter_rate: float | int | None = 0.99 dataset: DatasetManager = Dataset.OPENAI.manager(5_000_000) name: str = "Filtering Search Performance Test (5M Dataset, 1536 Dim, Filter 99%)" - description: str = """This case tests the search performance of a vector database with a medium dataset (OpenAI 5M vectors, 1536 dimensions) under a high filtering rate (99% vectors), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium dataset + (OpenAI 5M vectors, 1536 dimensions) under a high filtering rate (99% vectors), + at varying parallel levels. Results will show index building time, recall, and maximum QPS.""" load_timeout: float | int = config.LOAD_TIMEOUT_1536D_5M optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_1536D_5M @@ -286,8 +295,9 @@ class Performance1536D50K(PerformanceCase): filter_rate: float | int | None = None dataset: DatasetManager = Dataset.OPENAI.manager(50_000) name: str = "Search Performance Test (50K Dataset, 1536 Dim)" - description: str = """This case tests the search performance of a vector database with a medium 50K dataset (OpenAI 50K vectors, 1536 dimensions), at varying parallel levels. -Results will show index building time, recall, and maximum QPS.""" + description: str = """This case tests the search performance of a vector database with a medium 50K dataset + (OpenAI 50K vectors, 1536 dimensions), at varying parallel levels. 
Results will show index building time, + recall, and maximum QPS.""" load_timeout: float | int = 3600 optimize_timeout: float | int | None = config.OPTIMIZE_TIMEOUT_DEFAULT @@ -312,11 +322,11 @@ class PerformanceCustomDataset(PerformanceCase): def __init__( self, - name, - description, - load_timeout, - optimize_timeout, - dataset_config, + name: str, + description: str, + load_timeout: float, + optimize_timeout: float, + dataset_config: dict, **kwargs, ): dataset_config = CustomDatasetConfig(**dataset_config) diff --git a/vectordb_bench/backend/clients/__init__.py b/vectordb_bench/backend/clients/__init__.py index e1b66a81d..773cd4948 100644 --- a/vectordb_bench/backend/clients/__init__.py +++ b/vectordb_bench/backend/clients/__init__.py @@ -1,12 +1,12 @@ from enum import Enum -from typing import Type + from .api import ( - VectorDB, - DBConfig, DBCaseConfig, + DBConfig, EmptyDBCaseConfig, IndexType, MetricType, + VectorDB, ) @@ -41,200 +41,255 @@ class DB(Enum): Test = "test" AliyunOpenSearch = "AliyunOpenSearch" - @property - def init_cls(self) -> Type[VectorDB]: + def init_cls(self) -> type[VectorDB]: # noqa: PLR0911, PLR0912 """Import while in use""" if self == DB.Milvus: from .milvus.milvus import Milvus + return Milvus if self == DB.ZillizCloud: from .zilliz_cloud.zilliz_cloud import ZillizCloud + return ZillizCloud if self == DB.Pinecone: from .pinecone.pinecone import Pinecone + return Pinecone if self == DB.ElasticCloud: from .elastic_cloud.elastic_cloud import ElasticCloud + return ElasticCloud if self == DB.QdrantCloud: from .qdrant_cloud.qdrant_cloud import QdrantCloud + return QdrantCloud if self == DB.WeaviateCloud: from .weaviate_cloud.weaviate_cloud import WeaviateCloud + return WeaviateCloud if self == DB.PgVector: from .pgvector.pgvector import PgVector + return PgVector if self == DB.PgVectoRS: from .pgvecto_rs.pgvecto_rs import PgVectoRS + return PgVectoRS - + if self == DB.PgVectorScale: from .pgvectorscale.pgvectorscale import PgVectorScale + return PgVectorScale if self == DB.PgDiskANN: from .pgdiskann.pgdiskann import PgDiskANN + return PgDiskANN if self == DB.Redis: from .redis.redis import Redis + return Redis - + if self == DB.MemoryDB: from .memorydb.memorydb import MemoryDB + return MemoryDB if self == DB.Chroma: from .chroma.chroma import ChromaClient + return ChromaClient if self == DB.AWSOpenSearch: from .aws_opensearch.aws_opensearch import AWSOpenSearch + return AWSOpenSearch - + if self == DB.AlloyDB: from .alloydb.alloydb import AlloyDB + return AlloyDB if self == DB.AliyunElasticsearch: from .aliyun_elasticsearch.aliyun_elasticsearch import AliyunElasticsearch + return AliyunElasticsearch if self == DB.AliyunOpenSearch: from .aliyun_opensearch.aliyun_opensearch import AliyunOpenSearch + return AliyunOpenSearch + msg = f"Unknown DB: {self.name}" + raise ValueError(msg) + @property - def config_cls(self) -> Type[DBConfig]: + def config_cls(self) -> type[DBConfig]: # noqa: PLR0911, PLR0912 """Import while in use""" if self == DB.Milvus: from .milvus.config import MilvusConfig + return MilvusConfig if self == DB.ZillizCloud: from .zilliz_cloud.config import ZillizCloudConfig + return ZillizCloudConfig if self == DB.Pinecone: from .pinecone.config import PineconeConfig + return PineconeConfig if self == DB.ElasticCloud: from .elastic_cloud.config import ElasticCloudConfig + return ElasticCloudConfig if self == DB.QdrantCloud: from .qdrant_cloud.config import QdrantConfig + return QdrantConfig if self == DB.WeaviateCloud: from .weaviate_cloud.config import 
WeaviateConfig + return WeaviateConfig if self == DB.PgVector: from .pgvector.config import PgVectorConfig + return PgVectorConfig if self == DB.PgVectoRS: from .pgvecto_rs.config import PgVectoRSConfig + return PgVectoRSConfig if self == DB.PgVectorScale: from .pgvectorscale.config import PgVectorScaleConfig + return PgVectorScaleConfig if self == DB.PgDiskANN: from .pgdiskann.config import PgDiskANNConfig + return PgDiskANNConfig if self == DB.Redis: from .redis.config import RedisConfig + return RedisConfig - + if self == DB.MemoryDB: from .memorydb.config import MemoryDBConfig + return MemoryDBConfig if self == DB.Chroma: from .chroma.config import ChromaConfig + return ChromaConfig if self == DB.AWSOpenSearch: from .aws_opensearch.config import AWSOpenSearchConfig + return AWSOpenSearchConfig - + if self == DB.AlloyDB: from .alloydb.config import AlloyDBConfig + return AlloyDBConfig if self == DB.AliyunElasticsearch: from .aliyun_elasticsearch.config import AliyunElasticsearchConfig + return AliyunElasticsearchConfig if self == DB.AliyunOpenSearch: from .aliyun_opensearch.config import AliyunOpenSearchConfig + return AliyunOpenSearchConfig - def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseConfig]: + msg = f"Unknown DB: {self.name}" + raise ValueError(msg) + + def case_config_cls( # noqa: PLR0911 + self, + index_type: IndexType | None = None, + ) -> type[DBCaseConfig]: if self == DB.Milvus: from .milvus.config import _milvus_case_config + return _milvus_case_config.get(index_type) if self == DB.ZillizCloud: from .zilliz_cloud.config import AutoIndexConfig + return AutoIndexConfig if self == DB.ElasticCloud: from .elastic_cloud.config import ElasticCloudIndexConfig + return ElasticCloudIndexConfig if self == DB.QdrantCloud: from .qdrant_cloud.config import QdrantIndexConfig + return QdrantIndexConfig if self == DB.WeaviateCloud: from .weaviate_cloud.config import WeaviateIndexConfig + return WeaviateIndexConfig if self == DB.PgVector: from .pgvector.config import _pgvector_case_config + return _pgvector_case_config.get(index_type) if self == DB.PgVectoRS: from .pgvecto_rs.config import _pgvecto_rs_case_config + return _pgvecto_rs_case_config.get(index_type) if self == DB.AWSOpenSearch: from .aws_opensearch.config import AWSOpenSearchIndexConfig + return AWSOpenSearchIndexConfig if self == DB.PgVectorScale: from .pgvectorscale.config import _pgvectorscale_case_config + return _pgvectorscale_case_config.get(index_type) if self == DB.PgDiskANN: from .pgdiskann.config import _pgdiskann_case_config + return _pgdiskann_case_config.get(index_type) - + if self == DB.AlloyDB: from .alloydb.config import _alloydb_case_config + return _alloydb_case_config.get(index_type) if self == DB.AliyunElasticsearch: from .elastic_cloud.config import ElasticCloudIndexConfig + return ElasticCloudIndexConfig if self == DB.AliyunOpenSearch: from .aliyun_opensearch.config import AliyunOpenSearchIndexConfig + return AliyunOpenSearchIndexConfig # DB.Pinecone, DB.Chroma, DB.Redis @@ -242,5 +297,11 @@ def case_config_cls(self, index_type: IndexType | None = None) -> Type[DBCaseCon __all__ = [ - "DB", "VectorDB", "DBConfig", "DBCaseConfig", "IndexType", "MetricType", "EmptyDBCaseConfig", + "DB", + "DBCaseConfig", + "DBConfig", + "EmptyDBCaseConfig", + "IndexType", + "MetricType", + "VectorDB", ] diff --git a/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py b/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py index 41253ca1e..92ec905e5 
100644 --- a/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py +++ b/vectordb_bench/backend/clients/aliyun_elasticsearch/aliyun_elasticsearch.py @@ -1,5 +1,5 @@ -from ..elastic_cloud.elastic_cloud import ElasticCloud from ..elastic_cloud.config import ElasticCloudIndexConfig +from ..elastic_cloud.elastic_cloud import ElasticCloud class AliyunElasticsearch(ElasticCloud): @@ -24,4 +24,3 @@ def __init__( drop_old=drop_old, **kwargs, ) - diff --git a/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py b/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py index a2de4dc75..a8f5100bf 100644 --- a/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py +++ b/vectordb_bench/backend/clients/aliyun_elasticsearch/config.py @@ -1,7 +1,6 @@ -from enum import Enum -from pydantic import SecretStr, BaseModel +from pydantic import BaseModel, SecretStr -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType +from ..api import DBConfig class AliyunElasticsearchConfig(DBConfig, BaseModel): @@ -14,6 +13,6 @@ class AliyunElasticsearchConfig(DBConfig, BaseModel): def to_dict(self) -> dict: return { - "hosts": [{'scheme': self.scheme, 'host': self.host, 'port': self.port}], + "hosts": [{"scheme": self.scheme, "host": self.host, "port": self.port}], "basic_auth": (self.user, self.password.get_secret_value()), } diff --git a/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py b/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py index 5d4dcbfa6..00227cfff 100644 --- a/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py +++ b/vectordb_bench/backend/clients/aliyun_opensearch/aliyun_opensearch.py @@ -1,32 +1,32 @@ import json import logging -from contextlib import contextmanager import time +from contextlib import contextmanager +from alibabacloud_ha3engine_vector import client, models from alibabacloud_ha3engine_vector.models import QueryRequest - -from ..api import VectorDB, MetricType -from .config import AliyunOpenSearchIndexConfig - -from alibabacloud_searchengine20211025.client import Client as searchengineClient from alibabacloud_searchengine20211025 import models as searchengine_models +from alibabacloud_searchengine20211025.client import Client as searchengineClient from alibabacloud_tea_openapi import models as open_api_models -from alibabacloud_ha3engine_vector import models, client + +from ..api import MetricType, VectorDB +from .config import AliyunOpenSearchIndexConfig log = logging.getLogger(__name__) ALIYUN_OPENSEARCH_MAX_SIZE_PER_BATCH = 2 * 1024 * 1024 # 2MB ALIYUN_OPENSEARCH_MAX_NUM_PER_BATCH = 100 + class AliyunOpenSearch(VectorDB): def __init__( - self, - dim: int, - db_config: dict, - db_case_config: AliyunOpenSearchIndexConfig, - collection_name: str = "VectorDBBenchCollection", - drop_old: bool = False, - **kwargs, + self, + dim: int, + db_config: dict, + db_case_config: AliyunOpenSearchIndexConfig, + collection_name: str = "VectorDBBenchCollection", + drop_old: bool = False, + **kwargs, ): self.control_client = None self.dim = dim @@ -41,14 +41,17 @@ def __init__( self._index_name = "vector_idx" self.batch_size = int( - min(ALIYUN_OPENSEARCH_MAX_SIZE_PER_BATCH / (dim * 25), ALIYUN_OPENSEARCH_MAX_NUM_PER_BATCH) + min( + ALIYUN_OPENSEARCH_MAX_SIZE_PER_BATCH / (dim * 25), + ALIYUN_OPENSEARCH_MAX_NUM_PER_BATCH, + ), ) log.info(f"Aliyun_OpenSearch client config: {self.db_config}") control_config = open_api_models.Config( access_key_id=self.db_config["ak"], 
access_key_secret=self.db_config["sk"], - endpoint=self.db_config["control_host"] + endpoint=self.db_config["control_host"], ) self.control_client = searchengineClient(control_config) @@ -67,7 +70,7 @@ def _create_index(self, client: searchengineClient): create_table_request.field_schema = { self._primary_field: "INT64", self._vector_field: "MULTI_FLOAT", - self._scalar_field: "INT64" + self._scalar_field: "INT64", } vector_index = searchengine_models.ModifyTableRequestVectorIndex() vector_index.index_name = self._index_name @@ -77,8 +80,25 @@ def _create_index(self, client: searchengineClient): vector_index.vector_index_type = "HNSW" advance_params = searchengine_models.ModifyTableRequestVectorIndexAdvanceParams() - advance_params.build_index_params = "{\"proxima.hnsw.builder.max_neighbor_count\":" + str(self.case_config.M) + ",\"proxima.hnsw.builder.efconstruction\":" + str(self.case_config.efConstruction) + ",\"proxima.hnsw.builder.enable_adsampling\":true,\"proxima.hnsw.builder.slack_pruning_factor\":1.1,\"proxima.hnsw.builder.thread_count\":16}" - advance_params.search_index_params = "{\"proxima.hnsw.searcher.ef\":400,\"proxima.hnsw.searcher.dynamic_termination.prob_threshold\":0.7}" + str_max_neighbor_count = f'"proxima.hnsw.builder.max_neighbor_count":{self.case_config.M}' + str_efc = f'"proxima.hnsw.builder.efconstruction":{self.case_config.ef_construction}' + str_enable_adsampling = '"proxima.hnsw.builder.enable_adsampling":true' + str_slack_pruning_factor = '"proxima.hnsw.builder.slack_pruning_factor":1.1' + str_thread_count = '"proxima.hnsw.builder.thread_count":16' + + params = ",".join( + [ + str_max_neighbor_count, + str_efc, + str_enable_adsampling, + str_slack_pruning_factor, + str_thread_count, + ], + ) + advance_params.build_index_params = params + advance_params.search_index_params = ( + '{"proxima.hnsw.searcher.ef":400,"proxima.hnsw.searcher.dynamic_termination.prob_threshold":0.7}' + ) vector_index.advance_params = advance_params create_table_request.vector_index = [vector_index] @@ -88,7 +108,7 @@ def _create_index(self, client: searchengineClient): except Exception as error: log.info(error.message) log.info(error.data.get("Recommend")) - log.info(f"Failed to create index: error: {str(error)}") + log.info(f"Failed to create index: error: {error!s}") raise error from None # check if index create success @@ -102,22 +122,22 @@ def _active_index(self, client: searchengineClient) -> None: log.info(f"begin to {retry_times} times get table") retry_times += 1 response = client.get_table(self.instance_id, self.collection_name) - if response.body.result.status == 'IN_USE': + if response.body.result.status == "IN_USE": log.info(f"{self.collection_name} table begin to use.") return def _index_exists(self, client: searchengineClient) -> bool: try: client.get_table(self.instance_id, self.collection_name) - return True - except Exception as error: - log.info(f'get table from searchengine error') - log.info(error.message) + except Exception as err: + log.warning(f"get table from searchengine error, err={err}") return False + else: + return True # check if index build success, Insert the embeddings to the vector database after index build success def _index_build_success(self, client: searchengineClient) -> None: - log.info(f"begin to check if table build success.") + log.info("begin to check if table build success.") time.sleep(50) retry_times = 0 @@ -139,9 +159,9 @@ def _index_build_success(self, client: searchengineClient) -> None: cur_fsm = fsm break if cur_fsm is None: - 
print("no build index fsm") + log.warning("no build index fsm") return - if "success" == cur_fsm["status"]: + if cur_fsm["status"] == "success": return def _modify_index(self, client: searchengineClient) -> None: @@ -154,7 +174,7 @@ def _modify_index(self, client: searchengineClient) -> None: modify_table_request.field_schema = { self._primary_field: "INT64", self._vector_field: "MULTI_FLOAT", - self._scalar_field: "INT64" + self._scalar_field: "INT64", } vector_index = searchengine_models.ModifyTableRequestVectorIndex() vector_index.index_name = self._index_name @@ -163,19 +183,41 @@ def _modify_index(self, client: searchengineClient) -> None: vector_index.vector_field = self._vector_field vector_index.vector_index_type = "HNSW" advance_params = searchengine_models.ModifyTableRequestVectorIndexAdvanceParams() - advance_params.build_index_params = "{\"proxima.hnsw.builder.max_neighbor_count\":" + str(self.case_config.M) + ",\"proxima.hnsw.builder.efconstruction\":" + str(self.case_config.efConstruction) + ",\"proxima.hnsw.builder.enable_adsampling\":true,\"proxima.hnsw.builder.slack_pruning_factor\":1.1,\"proxima.hnsw.builder.thread_count\":16}" - advance_params.search_index_params = "{\"proxima.hnsw.searcher.ef\":400,\"proxima.hnsw.searcher.dynamic_termination.prob_threshold\":0.7}" + + str_max_neighbor_count = f'"proxima.hnsw.builder.max_neighbor_count":{self.case_config.M}' + str_efc = f'"proxima.hnsw.builder.efconstruction":{self.case_config.ef_construction}' + str_enable_adsampling = '"proxima.hnsw.builder.enable_adsampling":true' + str_slack_pruning_factor = '"proxima.hnsw.builder.slack_pruning_factor":1.1' + str_thread_count = '"proxima.hnsw.builder.thread_count":16' + + params = ",".join( + [ + str_max_neighbor_count, + str_efc, + str_enable_adsampling, + str_slack_pruning_factor, + str_thread_count, + ], + ) + advance_params.build_index_params = params + advance_params.search_index_params = ( + '{"proxima.hnsw.searcher.ef":400,"proxima.hnsw.searcher.dynamic_termination.prob_threshold":0.7}' + ) vector_index.advance_params = advance_params modify_table_request.vector_index = [vector_index] try: - response = client.modify_table(self.instance_id, self.collection_name, modify_table_request) + response = client.modify_table( + self.instance_id, + self.collection_name, + modify_table_request, + ) log.info(f"modify table success: {response.body}") except Exception as error: log.info(error.message) log.info(error.data.get("Recommend")) - log.info(f"Failed to modify index: error: {str(error)}") + log.info(f"Failed to modify index: error: {error!s}") raise error from None # check if modify index & delete data fsm success @@ -185,15 +227,14 @@ def _modify_index(self, client: searchengineClient) -> None: def _get_total_count(self): try: response = self.client.stats(self.collection_name) + except Exception as e: + log.warning(f"Error querying index: {e}") + else: body = json.loads(response.body) log.info(f"stats info: {response.body}") if "result" in body and "totalDocCount" in body.get("result"): return body.get("result").get("totalDocCount") - else: - return 0 - except Exception as e: - print(f"Error querying index: {e}") return 0 @contextmanager @@ -203,21 +244,20 @@ def init(self) -> None: endpoint=self.db_config["host"], protocol="http", access_user_name=self.db_config["user"], - access_pass_word=self.db_config["password"] + access_pass_word=self.db_config["password"], ) self.client = client.Client(config) yield - # self.client.transport.close() self.client = None del self.client def 
insert_embeddings( - self, - embeddings: list[list[float]], - metadata: list[int], - **kwargs, + self, + embeddings: list[list[float]], + metadata: list[int], + **kwargs, ) -> tuple[int, Exception]: """Insert the embeddings to the opensearch.""" assert self.client is not None, "should self.init() first" @@ -226,25 +266,24 @@ def insert_embeddings( try: for batch_start_offset in range(0, len(embeddings), self.batch_size): - batch_end_offset = min( - batch_start_offset + self.batch_size, len(embeddings) - ) + batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) documents = [] for i in range(batch_start_offset, batch_end_offset): - documentFields = { + document_fields = { self._primary_field: metadata[i], self._vector_field: embeddings[i], self._scalar_field: metadata[i], - "ops_build_channel": "inc" - } - document = { - "fields": documentFields, - "cmd": "add" + "ops_build_channel": "inc", } + document = {"fields": document_fields, "cmd": "add"} documents.append(document) - pushDocumentsRequest = models.PushDocumentsRequest({}, documents) - self.client.push_documents(self.collection_name, self._primary_field, pushDocumentsRequest) + push_doc_req = models.PushDocumentsRequest({}, documents) + self.client.push_documents( + self.collection_name, + self._primary_field, + push_doc_req, + ) insert_count += batch_end_offset - batch_start_offset except Exception as e: log.info(f"Failed to insert data: {e}") @@ -252,33 +291,36 @@ def insert_embeddings( return (insert_count, None) def search_embedding( - self, - query: list[float], - k: int = 100, - filters: dict | None = None, + self, + query: list[float], + k: int = 100, + filters: dict | None = None, ) -> list[int]: assert self.client is not None, "should self.init() first" - search_params = "{\"proxima.hnsw.searcher.ef\":"+ str(self.case_config.ef_search) +"}" + search_params = '{"proxima.hnsw.searcher.ef":' + str(self.case_config.ef_search) + "}" os_filter = f"{self._scalar_field} {filters.get('metadata')}" if filters else "" try: - request = QueryRequest(table_name=self.collection_name, - vector=query, - top_k=k, - search_params=search_params, filter=os_filter) + request = QueryRequest( + table_name=self.collection_name, + vector=query, + top_k=k, + search_params=search_params, + filter=os_filter, + ) result = self.client.query(request) except Exception as e: log.info(f"Error querying index: {e}") - raise e - res = json.loads(result.body) - id_res = [one_res["id"] for one_res in res["result"]] - return id_res + raise e from e + else: + res = json.loads(result.body) + return [one_res["id"] for one_res in res["result"]] def need_normalize_cosine(self) -> bool: """Wheather this database need to normalize dataset to support COSINE""" if self.case_config.metric_type == MetricType.COSINE: - log.info(f"cosine dataset need normalize.") + log.info("cosine dataset need normalize.") return True return False @@ -296,9 +338,8 @@ def optimize_with_size(self, data_size: int): total_count = self._get_total_count() # check if the data is inserted if total_count == data_size: - log.info(f"optimize table finish.") + log.info("optimize table finish.") return def ready_to_load(self): """ready_to_load will be called before load in load cases.""" - pass diff --git a/vectordb_bench/backend/clients/aliyun_opensearch/config.py b/vectordb_bench/backend/clients/aliyun_opensearch/config.py index 7b2b9ad13..e215b7d68 100644 --- a/vectordb_bench/backend/clients/aliyun_opensearch/config.py +++ 
b/vectordb_bench/backend/clients/aliyun_opensearch/config.py @@ -1,8 +1,8 @@ import logging -from enum import Enum -from pydantic import SecretStr, BaseModel -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType log = logging.getLogger(__name__) @@ -26,18 +26,17 @@ def to_dict(self) -> dict: "control_host": self.control_host, } + class AliyunOpenSearchIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType = MetricType.L2 - efConstruction: int = 500 + ef_construction: int = 500 M: int = 100 ef_search: int = 40 def distance_type(self) -> str: if self.metric_type == MetricType.L2: return "SquaredEuclidean" - elif self.metric_type == MetricType.IP: - return "InnerProduct" - elif self.metric_type == MetricType.COSINE: + if self.metric_type in (MetricType.IP, MetricType.COSINE): return "InnerProduct" return "SquaredEuclidean" diff --git a/vectordb_bench/backend/clients/alloydb/alloydb.py b/vectordb_bench/backend/clients/alloydb/alloydb.py index 5b275b30f..c81f77675 100644 --- a/vectordb_bench/backend/clients/alloydb/alloydb.py +++ b/vectordb_bench/backend/clients/alloydb/alloydb.py @@ -1,9 +1,9 @@ """Wrapper around the alloydb vector database over VectorDB""" import logging -import pprint +from collections.abc import Generator, Sequence from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple, Sequence +from typing import Any import numpy as np import psycopg @@ -11,7 +11,7 @@ from psycopg import Connection, Cursor, sql from ..api import VectorDB -from .config import AlloyDBConfigDict, AlloyDBIndexConfig, AlloyDBScaNNConfig +from .config import AlloyDBConfigDict, AlloyDBIndexConfig log = logging.getLogger(__name__) @@ -56,13 +56,14 @@ def __init__( ( self.case_config.create_index_before_load, self.case_config.create_index_after_load, - ) + ), ): - err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" - log.error(err) - raise RuntimeError( - f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + msg = ( + f"{self.name} config must create an index using create_index_before_load or create_index_after_load" + "\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" ) + log.warning(msg) + raise RuntimeError(msg) if drop_old: self._drop_index() @@ -77,7 +78,7 @@ def __init__( self.conn = None @staticmethod - def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + def _create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) register_vector(conn) conn.autocommit = False @@ -86,21 +87,20 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: assert conn is not None, "Connection is not initialized" assert cursor is not None, "Cursor is not initialized" return conn, cursor - - def _generate_search_query(self, filtered: bool=False) -> sql.Composed: - search_query = sql.Composed( + + def _generate_search_query(self, filtered: bool = False) -> sql.Composed: + return sql.Composed( [ sql.SQL( - "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " + "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ", ).format( table_name=sql.Identifier(self.table_name), where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) - return search_query @contextmanager def 
init(self) -> Generator[None, None, None]: @@ -119,8 +119,8 @@ def init(self) -> Generator[None, None, None]: if len(session_options) > 0: for setting in session_options: command = sql.SQL("SET {setting_name} " + "= {val};").format( - setting_name=sql.Identifier(setting['parameter']['setting_name']), - val=sql.Identifier(str(setting['parameter']['val'])), + setting_name=sql.Identifier(setting["parameter"]["setting_name"]), + val=sql.Identifier(str(setting["parameter"]["val"])), ) log.debug(command.as_string(self.cursor)) self.cursor.execute(command) @@ -144,8 +144,8 @@ def _drop_table(self): self.cursor.execute( sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) self.conn.commit() @@ -167,7 +167,7 @@ def _drop_index(self): log.info(f"{self.name} client drop index : {self._index_name}") drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( - index_name=sql.Identifier(self._index_name) + index_name=sql.Identifier(self._index_name), ) log.debug(drop_index_sql.as_string(self.cursor)) self.cursor.execute(drop_index_sql) @@ -181,78 +181,64 @@ def _set_parallel_index_build_param(self): if index_param["enable_pca"] is not None: self.cursor.execute( - sql.SQL("SET scann.enable_pca TO {};").format( - index_param["enable_pca"] - ) + sql.SQL("SET scann.enable_pca TO {};").format(index_param["enable_pca"]), ) self.cursor.execute( sql.SQL("ALTER USER {} SET scann.enable_pca TO {};").format( sql.Identifier(self.db_config["user"]), index_param["enable_pca"], - ) + ), ) self.conn.commit() if index_param["maintenance_work_mem"] is not None: self.cursor.execute( sql.SQL("SET maintenance_work_mem TO {};").format( - index_param["maintenance_work_mem"] - ) + index_param["maintenance_work_mem"], + ), ) self.cursor.execute( sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format( sql.Identifier(self.db_config["user"]), index_param["maintenance_work_mem"], - ) + ), ) self.conn.commit() if index_param["max_parallel_workers"] is not None: self.cursor.execute( sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format( - index_param["max_parallel_workers"] - ) + index_param["max_parallel_workers"], + ), ) self.cursor.execute( - sql.SQL( - "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';" - ).format( + sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format( sql.Identifier(self.db_config["user"]), index_param["max_parallel_workers"], - ) + ), ) self.cursor.execute( sql.SQL("SET max_parallel_workers TO '{}';").format( - index_param["max_parallel_workers"] - ) + index_param["max_parallel_workers"], + ), ) self.cursor.execute( - sql.SQL( - "ALTER USER {} SET max_parallel_workers TO '{}';" - ).format( + sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format( sql.Identifier(self.db_config["user"]), index_param["max_parallel_workers"], - ) + ), ) self.cursor.execute( - sql.SQL( - "ALTER TABLE {} SET (parallel_workers = {});" - ).format( + sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format( sql.Identifier(self.table_name), index_param["max_parallel_workers"], - ) + ), ) self.conn.commit() - results = self.cursor.execute( - sql.SQL("SHOW max_parallel_maintenance_workers;") - ).fetchall() - results.extend( - self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall() - ) - results.extend( - self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall() - ) + results = self.cursor.execute(sql.SQL("SHOW 
max_parallel_maintenance_workers;")).fetchall() + results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()) + results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()) log.info(f"{self.name} parallel index creation parameters: {results}") def _create_index(self): @@ -264,23 +250,20 @@ def _create_index(self): self._set_parallel_index_build_param() options = [] for option in index_param["index_creation_with_options"]: - if option['val'] is not None: + if option["val"] is not None: options.append( sql.SQL("{option_name} = {val}").format( - option_name=sql.Identifier(option['option_name']), - val=sql.Identifier(str(option['val'])), - ) + option_name=sql.Identifier(option["option_name"]), + val=sql.Identifier(str(option["val"])), + ), ) - if any(options): - with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) - else: - with_clause = sql.Composed(()) + with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) index_create_sql = sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric}) - """ + """, ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), @@ -288,9 +271,7 @@ def _create_index(self): embedding_metric=sql.Identifier(index_param["metric"]), ) - index_create_sql_with_with_clause = ( - index_create_sql + with_clause - ).join(" ") + index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ") log.debug(index_create_sql_with_with_clause.as_string(self.cursor)) self.cursor.execute(index_create_sql_with_with_clause) self.conn.commit() @@ -305,14 +286,12 @@ def _create_table(self, dim: int): # create table self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" - ).format(table_name=sql.Identifier(self.table_name), dim=dim) + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", + ).format(table_name=sql.Identifier(self.table_name), dim=dim), ) self.conn.commit() except Exception as e: - log.warning( - f"Failed to create alloydb table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create alloydb table: {self.table_name} error: {e}") raise e from None def insert_embeddings( @@ -320,7 +299,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -330,8 +309,8 @@ def insert_embeddings( with self.cursor.copy( sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) as copy: copy.set_types(["bigint", "vector"]) for i, row in enumerate(metadata_arr): @@ -343,9 +322,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into alloydb table ({self.table_name}), error: {e}" - ) + log.warning(f"Failed to insert data into alloydb table ({self.table_name}), error: {e}") return 0, e def search_embedding( @@ -362,11 +339,12 @@ def search_embedding( if filters: gt = filters.get("id") result = self.cursor.execute( - self._filtered_search, (gt, q, 
k), prepare=True, binary=True + self._filtered_search, + (gt, q, k), + prepare=True, + binary=True, ) else: - result = self.cursor.execute( - self._unfiltered_search, (q, k), prepare=True, binary=True - ) + result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True) return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/backend/clients/alloydb/cli.py b/vectordb_bench/backend/clients/alloydb/cli.py index 54d9b9fa3..e207c9c21 100644 --- a/vectordb_bench/backend/clients/alloydb/cli.py +++ b/vectordb_bench/backend/clients/alloydb/cli.py @@ -1,10 +1,10 @@ -from typing import Annotated, Optional, TypedDict, Unpack +import os +from typing import Annotated, Unpack import click -import os from pydantic import SecretStr -from vectordb_bench.backend.clients.api import MetricType +from vectordb_bench.backend.clients import DB from ....cli.cli import ( CommonTypedDict, @@ -13,31 +13,28 @@ get_custom_case_config, run, ) -from vectordb_bench.backend.clients import DB class AlloyDBTypedDict(CommonTypedDict): user_name: Annotated[ - str, click.option("--user-name", type=str, help="Db username", required=True) + str, + click.option("--user-name", type=str, help="Db username", required=True), ] password: Annotated[ str, - click.option("--password", - type=str, - help="Postgres database password", - default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), - show_default="$POSTGRES_PASSWORD", - ), + click.option( + "--password", + type=str, + help="Postgres database password", + default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), + show_default="$POSTGRES_PASSWORD", + ), ] - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] - db_name: Annotated[ - str, click.option("--db-name", type=str, help="Db name", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] + db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] maintenance_work_mem: Annotated[ - Optional[str], + str | None, click.option( "--maintenance-work-mem", type=str, @@ -49,7 +46,7 @@ class AlloyDBTypedDict(CommonTypedDict): ), ] max_parallel_workers: Annotated[ - Optional[int], + int | None, click.option( "--max-parallel-workers", type=int, @@ -58,32 +55,51 @@ class AlloyDBTypedDict(CommonTypedDict): ), ] - class AlloyDBScaNNTypedDict(AlloyDBTypedDict): num_leaves: Annotated[ int, - click.option("--num-leaves", type=int, help="Number of leaves", required=True) + click.option("--num-leaves", type=int, help="Number of leaves", required=True), ] num_leaves_to_search: Annotated[ int, - click.option("--num-leaves-to-search", type=int, help="Number of leaves to search", required=True) + click.option( + "--num-leaves-to-search", + type=int, + help="Number of leaves to search", + required=True, + ), ] pre_reordering_num_neighbors: Annotated[ int, - click.option("--pre-reordering-num-neighbors", type=int, help="Pre-reordering number of neighbors", default=200) + click.option( + "--pre-reordering-num-neighbors", + type=int, + help="Pre-reordering number of neighbors", + default=200, + ), ] max_top_neighbors_buffer_size: Annotated[ int, - click.option("--max-top-neighbors-buffer-size", type=int, help="Maximum top neighbors buffer size", default=20_000) + click.option( + "--max-top-neighbors-buffer-size", + type=int, + help="Maximum top neighbors buffer size", + default=20_000, + ), ] num_search_threads: Annotated[ int, - click.option("--num-search-threads", type=int, help="Number 
of search threads", default=2) + click.option("--num-search-threads", type=int, help="Number of search threads", default=2), ] max_num_prefetch_datasets: Annotated[ int, - click.option("--max-num-prefetch-datasets", type=int, help="Maximum number of prefetch datasets", default=100) + click.option( + "--max-num-prefetch-datasets", + type=int, + help="Maximum number of prefetch datasets", + default=100, + ), ] quantizer: Annotated[ str, @@ -91,16 +107,17 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict): "--quantizer", type=click.Choice(["SQ8", "FLAT"]), help="Quantizer type", - default="SQ8" - ) + default="SQ8", + ), ] enable_pca: Annotated[ - bool, click.option( + bool, + click.option( "--enable-pca", type=click.Choice(["on", "off"]), help="Enable PCA", - default="on" - ) + default="on", + ), ] max_num_levels: Annotated[ int, @@ -108,8 +125,8 @@ class AlloyDBScaNNTypedDict(AlloyDBTypedDict): "--max-num-levels", type=click.Choice(["1", "2"]), help="Maximum number of levels", - default=1 - ) + default=1, + ), ] @@ -144,4 +161,4 @@ def AlloyDBScaNN( maintenance_work_mem=parameters["maintenance_work_mem"], ), **parameters, - ) \ No newline at end of file + ) diff --git a/vectordb_bench/backend/clients/alloydb/config.py b/vectordb_bench/backend/clients/alloydb/config.py index 1d5dde519..d6e54e487 100644 --- a/vectordb_bench/backend/clients/alloydb/config.py +++ b/vectordb_bench/backend/clients/alloydb/config.py @@ -1,7 +1,9 @@ from abc import abstractmethod -from typing import Any, Mapping, Optional, Sequence, TypedDict +from collections.abc import Mapping, Sequence +from typing import Any, LiteralString, TypedDict + from pydantic import BaseModel, SecretStr -from typing_extensions import LiteralString + from ..api import DBCaseConfig, DBConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" @@ -9,7 +11,7 @@ class AlloyDBConfigDict(TypedDict): """These keys will be directly used as kwargs in psycopg connection string, - so the names must match exactly psycopg API""" + so the names must match exactly psycopg API""" user: str password: str @@ -41,8 +43,8 @@ class AlloyDBIndexParam(TypedDict): metric: str index_type: str index_creation_with_options: Sequence[dict[str, Any]] - maintenance_work_mem: Optional[str] - max_parallel_workers: Optional[int] + maintenance_work_mem: str | None + max_parallel_workers: int | None class AlloyDBSearchParam(TypedDict): @@ -61,31 +63,30 @@ class AlloyDBIndexConfig(BaseModel, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "l2" - elif self.metric_type == MetricType.DP: + if self.metric_type == MetricType.DP: return "dot_product" return "cosine" def parse_metric_fun_op(self) -> LiteralString: if self.metric_type == MetricType.L2: return "<->" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "<#>" return "<=>" @abstractmethod - def index_param(self) -> AlloyDBIndexParam: - ... + def index_param(self) -> AlloyDBIndexParam: ... @abstractmethod - def search_param(self) -> AlloyDBSearchParam: - ... + def search_param(self) -> AlloyDBSearchParam: ... @abstractmethod - def session_param(self) -> AlloyDBSessionCommands: - ... + def session_param(self) -> AlloyDBSessionCommands: ... @staticmethod - def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]: + def _optionally_build_with_options( + with_options: Mapping[str, Any], + ) -> Sequence[dict[str, Any]]: """Walk through mappings, creating a List of {key1 = value} pairs. 
That will be used to build a where clause""" options = [] for option_name, value in with_options.items(): @@ -94,24 +95,25 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[ { "option_name": option_name, "val": str(value), - } + }, ) return options @staticmethod def _optionally_build_set_options( - set_mapping: Mapping[str, Any] + set_mapping: Mapping[str, Any], ) -> Sequence[dict[str, Any]]: """Walk through options, creating 'SET 'key1 = "value1";' list""" session_options = [] for setting_name, value in set_mapping.items(): if value: session_options.append( - {"parameter": { + { + "parameter": { "setting_name": setting_name, "val": str(value), }, - } + }, ) return session_options @@ -124,22 +126,22 @@ class AlloyDBScaNNConfig(AlloyDBIndexConfig): max_num_levels: int | None num_leaves_to_search: int | None max_top_neighbors_buffer_size: int | None - pre_reordering_num_neighbors: int | None - num_search_threads: int | None + pre_reordering_num_neighbors: int | None + num_search_threads: int | None max_num_prefetch_datasets: int | None - maintenance_work_mem: Optional[str] = None - max_parallel_workers: Optional[int] = None + maintenance_work_mem: str | None = None + max_parallel_workers: int | None = None def index_param(self) -> AlloyDBIndexParam: index_parameters = { - "num_leaves": self.num_leaves, "max_num_levels": self.max_num_levels, "quantizer": self.quantizer, + "num_leaves": self.num_leaves, + "max_num_levels": self.max_num_levels, + "quantizer": self.quantizer, } return { "metric": self.parse_metric(), "index_type": self.index.value, - "index_creation_with_options": self._optionally_build_with_options( - index_parameters - ), + "index_creation_with_options": self._optionally_build_with_options(index_parameters), "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "enable_pca": self.enable_pca, @@ -158,11 +160,9 @@ def session_param(self) -> AlloyDBSessionCommands: "scann.num_search_threads": self.num_search_threads, "scann.max_num_prefetch_datasets": self.max_num_prefetch_datasets, } - return { - "session_options": self._optionally_build_set_options(session_parameters) - } + return {"session_options": self._optionally_build_set_options(session_parameters)} _alloydb_case_config = { - IndexType.SCANN: AlloyDBScaNNConfig, + IndexType.SCANN: AlloyDBScaNNConfig, } diff --git a/vectordb_bench/backend/clients/api.py b/vectordb_bench/backend/clients/api.py index fe2e554f3..aa93abc12 100644 --- a/vectordb_bench/backend/clients/api.py +++ b/vectordb_bench/backend/clients/api.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod -from enum import Enum -from typing import Any, Type from contextlib import contextmanager +from enum import Enum -from pydantic import BaseModel, validator, SecretStr +from pydantic import BaseModel, SecretStr, validator class MetricType(str, Enum): @@ -65,13 +64,10 @@ def to_dict(self) -> dict: raise NotImplementedError @validator("*") - def not_empty_field(cls, v, field): - if ( - field.name in cls.common_short_configs() - or field.name in cls.common_long_configs() - ): + def not_empty_field(cls, v: any, field: any): + if field.name in cls.common_short_configs() or field.name in cls.common_long_configs(): return v - if not v and isinstance(v, (str, SecretStr)): + if not v and isinstance(v, str | SecretStr): raise ValueError("Empty string!") return v diff --git a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py 
b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py index 60e519d7a..487ec67cc 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py +++ b/vectordb_bench/backend/clients/aws_opensearch/aws_opensearch.py @@ -1,14 +1,18 @@ import logging -from contextlib import contextmanager import time -from typing import Iterable, Type -from ..api import VectorDB, DBCaseConfig, DBConfig, IndexType -from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine +from collections.abc import Iterable +from contextlib import contextmanager + from opensearchpy import OpenSearch -from opensearchpy.helpers import bulk + +from ..api import IndexType, VectorDB +from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig, AWSOS_Engine log = logging.getLogger(__name__) +WAITING_FOR_REFRESH_SEC = 30 +WAITING_FOR_FORCE_MERGE_SEC = 30 + class AWSOpenSearch(VectorDB): def __init__( @@ -27,9 +31,7 @@ def __init__( self.case_config = db_case_config self.index_name = index_name self.id_col_name = id_col_name - self.category_col_names = [ - f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000] - ] + self.category_col_names = [f"scalar-{categoryCount}" for categoryCount in [2, 5, 10, 100, 1000]] self.vector_col_name = vector_col_name log.info(f"AWS_OpenSearch client config: {self.db_config}") @@ -46,38 +48,32 @@ def config_cls(cls) -> AWSOpenSearchConfig: return AWSOpenSearchConfig @classmethod - def case_config_cls( - cls, index_type: IndexType | None = None - ) -> AWSOpenSearchIndexConfig: + def case_config_cls(cls, index_type: IndexType | None = None) -> AWSOpenSearchIndexConfig: return AWSOpenSearchIndexConfig def _create_index(self, client: OpenSearch): settings = { "index": { "knn": True, - # "number_of_shards": 5, - # "refresh_interval": "600s", - } + }, } mappings = { "properties": { - **{ - categoryCol: {"type": "keyword"} - for categoryCol in self.category_col_names - }, + **{categoryCol: {"type": "keyword"} for categoryCol in self.category_col_names}, self.vector_col_name: { "type": "knn_vector", "dimension": self.dim, "method": self.case_config.index_param(), }, - } + }, } try: client.indices.create( - index=self.index_name, body=dict(settings=settings, mappings=mappings) + index=self.index_name, + body={"settings": settings, "mappings": mappings}, ) except Exception as e: - log.warning(f"Failed to create index: {self.index_name} error: {str(e)}") + log.warning(f"Failed to create index: {self.index_name} error: {e!s}") raise e from None @contextmanager @@ -86,7 +82,6 @@ def init(self) -> None: self.client = OpenSearch(**self.db_config) yield - # self.client.transport.close() self.client = None del self.client @@ -101,16 +96,20 @@ def insert_embeddings( insert_data = [] for i in range(len(embeddings)): - insert_data.append({"index": {"_index": self.index_name, self.id_col_name: metadata[i]}}) + insert_data.append( + {"index": {"_index": self.index_name, self.id_col_name: metadata[i]}}, + ) insert_data.append({self.vector_col_name: embeddings[i]}) try: resp = self.client.bulk(insert_data) log.info(f"AWS_OpenSearch adding documents: {len(resp['items'])}") resp = self.client.indices.stats(self.index_name) - log.info(f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}") + log.info( + f"Total document count in index: {resp['_all']['primaries']['indexing']['index_total']}", + ) return (len(embeddings), None) except Exception as e: - log.warning(f"Failed to insert data: {self.index_name} error: {str(e)}") + 
log.warning(f"Failed to insert data: {self.index_name} error: {e!s}") time.sleep(10) return self.insert_embeddings(embeddings, metadata) @@ -135,20 +134,23 @@ def search_embedding( body = { "size": k, "query": {"knn": {self.vector_col_name: {"vector": query, "k": k}}}, - **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}) + **({"filter": {"range": {self.id_col_name: {"gt": filters["id"]}}}} if filters else {}), } try: - resp = self.client.search(index=self.index_name, body=body,size=k,_source=False,docvalue_fields=[self.id_col_name],stored_fields="_none_",) + resp = self.client.search( + index=self.index_name, + body=body, + size=k, + _source=False, + docvalue_fields=[self.id_col_name], + stored_fields="_none_", + ) log.info(f'Search took: {resp["took"]}') log.info(f'Search shards: {resp["_shards"]}') log.info(f'Search hits total: {resp["hits"]["total"]}') - result = [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]] - #result = [int(d["_id"]) for d in resp["hits"]["hits"]] - # log.info(f'success! length={len(res)}') - - return result + return [int(h["fields"][self.id_col_name][0]) for h in resp["hits"]["hits"]] except Exception as e: - log.warning(f"Failed to search: {self.index_name} error: {str(e)}") + log.warning(f"Failed to search: {self.index_name} error: {e!s}") raise e from None def optimize(self): @@ -163,37 +165,35 @@ def optimize(self): def _refresh_index(self): log.debug(f"Starting refresh for index {self.index_name}") - SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30 while True: try: - log.info(f"Starting the Refresh Index..") + log.info("Starting the Refresh Index..") self.client.indices.refresh(index=self.index_name) break except Exception as e: log.info( - f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}") - time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC) + f"Refresh errored out. 
Sleeping for {WAITING_FOR_REFRESH_SEC} sec and then Retrying : {e}", + ) + time.sleep(WAITING_FOR_REFRESH_SEC) continue log.debug(f"Completed refresh for index {self.index_name}") def _do_force_merge(self): log.debug(f"Starting force merge for index {self.index_name}") - force_merge_endpoint = f'/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false' - force_merge_task_id = self.client.transport.perform_request('POST', force_merge_endpoint)['task'] - SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30 + force_merge_endpoint = f"/{self.index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false" + force_merge_task_id = self.client.transport.perform_request("POST", force_merge_endpoint)["task"] while True: - time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC) + time.sleep(WAITING_FOR_FORCE_MERGE_SEC) task_status = self.client.tasks.get(task_id=force_merge_task_id) - if task_status['completed']: + if task_status["completed"]: break log.debug(f"Completed force merge for index {self.index_name}") def _load_graphs_to_memory(self): if self.case_config.engine != AWSOS_Engine.lucene: log.info("Calling warmup API to load graphs into memory") - warmup_endpoint = f'/_plugins/_knn/warmup/{self.index_name}' - self.client.transport.perform_request('GET', warmup_endpoint) + warmup_endpoint = f"/_plugins/_knn/warmup/{self.index_name}" + self.client.transport.perform_request("GET", warmup_endpoint) def ready_to_load(self): """ready_to_load will be called before load in load cases.""" - pass diff --git a/vectordb_bench/backend/clients/aws_opensearch/cli.py b/vectordb_bench/backend/clients/aws_opensearch/cli.py index 5cb4ebbe1..bb0c2450d 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/cli.py +++ b/vectordb_bench/backend/clients/aws_opensearch/cli.py @@ -14,22 +14,20 @@ class AWSOpenSearchTypedDict(TypedDict): - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] port: Annotated[int, click.option("--port", type=int, default=443, help="Db Port")] user: Annotated[str, click.option("--user", type=str, default="admin", help="Db User")] password: Annotated[str, click.option("--password", type=str, help="Db password")] -class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): - ... +class AWSOpenSearchHNSWTypedDict(CommonTypedDict, AWSOpenSearchTypedDict, HNSWFlavor2): ... 
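For reference, a minimal sketch of the async force-merge plus task-polling pattern that `_do_force_merge` above standardizes on; the client construction is omitted and the index name is a placeholder, so this is illustrative rather than part of the patch.

    import time

    from opensearchpy import OpenSearch

    WAITING_FOR_FORCE_MERGE_SEC = 30  # same poll interval the patch uses


    def force_merge_and_wait(client: OpenSearch, index_name: str) -> None:
        # Kick off the merge asynchronously, then poll the tasks API until it reports completion.
        endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false"
        task_id = client.transport.perform_request("POST", endpoint)["task"]
        while True:
            time.sleep(WAITING_FOR_FORCE_MERGE_SEC)
            if client.tasks.get(task_id=task_id)["completed"]:
                break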
@cli.command() @click_parameter_decorators_from_typed_dict(AWSOpenSearchHNSWTypedDict) def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): from .config import AWSOpenSearchConfig, AWSOpenSearchIndexConfig + run( db=DB.AWSOpenSearch, db_config=AWSOpenSearchConfig( @@ -38,7 +36,6 @@ def AWSOpenSearch(**parameters: Unpack[AWSOpenSearchHNSWTypedDict]): user=parameters["user"], password=SecretStr(parameters["password"]), ), - db_case_config=AWSOpenSearchIndexConfig( - ), + db_case_config=AWSOpenSearchIndexConfig(), **parameters, ) diff --git a/vectordb_bench/backend/clients/aws_opensearch/config.py b/vectordb_bench/backend/clients/aws_opensearch/config.py index 15cd4ead8..e9ccc7277 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/config.py +++ b/vectordb_bench/backend/clients/aws_opensearch/config.py @@ -1,10 +1,13 @@ import logging from enum import Enum -from pydantic import SecretStr, BaseModel -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, MetricType log = logging.getLogger(__name__) + + class AWSOpenSearchConfig(DBConfig, BaseModel): host: str = "" port: int = 443 @@ -13,7 +16,7 @@ class AWSOpenSearchConfig(DBConfig, BaseModel): def to_dict(self) -> dict: return { - "hosts": [{'host': self.host, 'port': self.port}], + "hosts": [{"host": self.host, "port": self.port}], "http_auth": (self.user, self.password.get_secret_value()), "use_ssl": True, "http_compress": True, @@ -40,25 +43,26 @@ class AWSOpenSearchIndexConfig(BaseModel, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.IP: return "innerproduct" - elif self.metric_type == MetricType.COSINE: + if self.metric_type == MetricType.COSINE: if self.engine == AWSOS_Engine.faiss: - log.info(f"Using metric type as innerproduct because faiss doesn't support cosine as metric type for Opensearch") + log.info( + "Using innerproduct because faiss doesn't support cosine as metric type for Opensearch", + ) return "innerproduct" return "cosinesimil" return "l2" def index_param(self) -> dict: - params = { + return { "name": "hnsw", "space_type": self.parse_metric(), "engine": self.engine.value, "parameters": { "ef_construction": self.efConstruction, "m": self.M, - "ef_search": self.efSearch - } + "ef_search": self.efSearch, + }, } - return params def search_param(self) -> dict: return {} diff --git a/vectordb_bench/backend/clients/aws_opensearch/run.py b/vectordb_bench/backend/clients/aws_opensearch/run.py index d2698d139..68aa200a5 100644 --- a/vectordb_bench/backend/clients/aws_opensearch/run.py +++ b/vectordb_bench/backend/clients/aws_opensearch/run.py @@ -1,12 +1,16 @@ -import time, random +import logging +import random +import time + from opensearchpy import OpenSearch -from opensearch_dsl import Search, Document, Text, Keyword -_HOST = 'xxxxxx.us-west-2.es.amazonaws.com' +log = logging.getLogger(__name__) + +_HOST = "xxxxxx.us-west-2.es.amazonaws.com" _PORT = 443 -_AUTH = ('admin', 'xxxxxx') # For testing only. Don't store credentials in code. +_AUTH = ("admin", "xxxxxx") # For testing only. Don't store credentials in code. 
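For context, a minimal sketch of how the `method` dict returned by `AWSOpenSearchIndexConfig.index_param()` feeds the `knn_vector` mapping at index creation time; the index name, field name, and dimension here are placeholders.

    from opensearchpy import OpenSearch


    def create_knn_index(client: OpenSearch, index_name: str, dim: int, method: dict) -> None:
        # "method" is the dict produced by index_param(), e.g. {"name": "hnsw", "space_type": ..., ...}.
        settings = {"index": {"knn": True}}
        mappings = {
            "properties": {
                "embedding": {
                    "type": "knn_vector",
                    "dimension": dim,
                    "method": method,
                },
            },
        }
        client.indices.create(index=index_name, body={"settings": settings, "mappings": mappings})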
-_INDEX_NAME = 'my-dsl-index' +_INDEX_NAME = "my-dsl-index" _BATCH = 100 _ROWS = 100 _DIM = 128 @@ -14,25 +18,24 @@ def create_client(): - client = OpenSearch( - hosts=[{'host': _HOST, 'port': _PORT}], - http_compress=True, # enables gzip compression for request bodies + return OpenSearch( + hosts=[{"host": _HOST, "port": _PORT}], + http_compress=True, # enables gzip compression for request bodies http_auth=_AUTH, use_ssl=True, verify_certs=True, ssl_assert_hostname=False, ssl_show_warn=False, ) - return client -def create_index(client, index_name): +def create_index(client: OpenSearch, index_name: str): settings = { "index": { "knn": True, "number_of_shards": 1, "refresh_interval": "5s", - } + }, } mappings = { "properties": { @@ -46,41 +49,46 @@ def create_index(client, index_name): "parameters": { "ef_construction": 256, "m": 16, - } - } - } - } + }, + }, + }, + }, } - response = client.indices.create(index=index_name, body=dict(settings=settings, mappings=mappings)) - print('\nCreating index:') - print(response) + response = client.indices.create( + index=index_name, + body={"settings": settings, "mappings": mappings}, + ) + log.info("\nCreating index:") + log.info(response) -def delete_index(client, index_name): +def delete_index(client: OpenSearch, index_name: str): response = client.indices.delete(index=index_name) - print('\nDeleting index:') - print(response) + log.info("\nDeleting index:") + log.info(response) -def bulk_insert(client, index_name): +def bulk_insert(client: OpenSearch, index_name: str): # Perform bulk operations - ids = [i for i in range(_ROWS)] + ids = list(range(_ROWS)) vec = [[random.random() for _ in range(_DIM)] for _ in range(_ROWS)] docs = [] for i in range(0, _ROWS, _BATCH): docs.clear() - for j in range(0, _BATCH): - docs.append({"index": {"_index": index_name, "_id": ids[i+j]}}) - docs.append({"embedding": vec[i+j]}) + for j in range(_BATCH): + docs.append({"index": {"_index": index_name, "_id": ids[i + j]}}) + docs.append({"embedding": vec[i + j]}) response = client.bulk(docs) - print('\nAdding documents:', len(response['items']), response['errors']) + log.info(f"Adding documents: {len(response['items'])}, {response['errors']}") response = client.indices.stats(index_name) - print('\nTotal document count in index:', response['_all']['primaries']['indexing']['index_total']) + log.info( + f'Total document count in index: { response["_all"]["primaries"]["indexing"]["index_total"] }', + ) -def search(client, index_name): +def search(client: OpenSearch, index_name: str): # Search for the document. 
search_body = { "size": _TOPK, @@ -89,53 +97,55 @@ def search(client, index_name): "embedding": { "vector": [random.random() for _ in range(_DIM)], "k": _TOPK, - } - } - } + }, + }, + }, } while True: response = client.search(index=index_name, body=search_body) - print(f'\nSearch took: {response["took"]}') - print(f'\nSearch shards: {response["_shards"]}') - print(f'\nSearch hits total: {response["hits"]["total"]}') + log.info(f'\nSearch took: {response["took"]}') + log.info(f'\nSearch shards: {response["_shards"]}') + log.info(f'\nSearch hits total: {response["hits"]["total"]}') result = response["hits"]["hits"] if len(result) != 0: - print('\nSearch results:') + log.info("\nSearch results:") for hit in response["hits"]["hits"]: - print(hit["_id"], hit["_score"]) + log.info(hit["_id"], hit["_score"]) break - else: - print('\nSearch not ready, sleep 1s') - time.sleep(1) - -def optimize_index(client, index_name): - print(f"Starting force merge for index {index_name}") - force_merge_endpoint = f'/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false' - force_merge_task_id = client.transport.perform_request('POST', force_merge_endpoint)['task'] - SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30 + log.info("\nSearch not ready, sleep 1s") + time.sleep(1) + + +SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30 +WAITINT_FOR_REFRESH_SEC = 30 + + +def optimize_index(client: OpenSearch, index_name: str): + log.info(f"Starting force merge for index {index_name}") + force_merge_endpoint = f"/{index_name}/_forcemerge?max_num_segments=1&wait_for_completion=false" + force_merge_task_id = client.transport.perform_request("POST", force_merge_endpoint)["task"] while True: time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC) task_status = client.tasks.get(task_id=force_merge_task_id) - if task_status['completed']: + if task_status["completed"]: break - print(f"Completed force merge for index {index_name}") + log.info(f"Completed force merge for index {index_name}") -def refresh_index(client, index_name): - print(f"Starting refresh for index {index_name}") - SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC = 30 +def refresh_index(client: OpenSearch, index_name: str): + log.info(f"Starting refresh for index {index_name}") while True: try: - print(f"Starting the Refresh Index..") + log.info("Starting the Refresh Index..") client.indices.refresh(index=index_name) break except Exception as e: - print( - f"Refresh errored out. Sleeping for {SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC} sec and then Retrying : {e}") - time.sleep(SECONDS_WAITING_FOR_REFRESH_API_CALL_SEC) + log.info( + f"Refresh errored out. 
Sleeping for {WAITINT_FOR_REFRESH_SEC} sec and then Retrying : {e}", + ) + time.sleep(WAITINT_FOR_REFRESH_SEC) continue - print(f"Completed refresh for index {index_name}") - + log.info(f"Completed refresh for index {index_name}") def main(): @@ -148,9 +158,9 @@ def main(): search(client, _INDEX_NAME) delete_index(client, _INDEX_NAME) except Exception as e: - print(e) + log.info(e) delete_index(client, _INDEX_NAME) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/vectordb_bench/backend/clients/chroma/chroma.py b/vectordb_bench/backend/clients/chroma/chroma.py index 235cb595b..a148fa141 100644 --- a/vectordb_bench/backend/clients/chroma/chroma.py +++ b/vectordb_bench/backend/clients/chroma/chroma.py @@ -1,55 +1,55 @@ -import chromadb -import logging +import logging from contextlib import contextmanager from typing import Any -from ..api import VectorDB, DBCaseConfig + +import chromadb + +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) + + class ChromaClient(VectorDB): - """Chroma client for VectorDB. + """Chroma client for VectorDB. To set up Chroma in docker, see https://docs.trychroma.com/usage-guide or the instructions in tests/test_chroma.py To change to running in process, modify the HttpClient() in __init__() and init(). - """ + """ def __init__( - self, - dim: int, - db_config: dict, - db_case_config: DBCaseConfig, - drop_old: bool = False, - - **kwargs - ): - + self, + dim: int, + db_config: dict, + db_case_config: DBCaseConfig, + drop_old: bool = False, + **kwargs, + ): self.db_config = db_config self.case_config = db_case_config - self.collection_name = 'example2' + self.collection_name = "example2" - client = chromadb.HttpClient(host=self.db_config["host"], - port=self.db_config["port"]) + client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) assert client.heartbeat() is not None if drop_old: try: - client.reset() # Reset the database - except: + client.reset() # Reset the database + except Exception: drop_old = False log.info(f"Chroma client drop_old collection: {self.collection_name}") @contextmanager def init(self) -> None: - """ create and destory connections to database. + """create and destory connections to database. 
Examples: >>> with self.init(): >>> self.insert_embeddings() """ - #create connection - self.client = chromadb.HttpClient(host=self.db_config["host"], - port=self.db_config["port"]) - - self.collection = self.client.get_or_create_collection('example2') + # create connection + self.client = chromadb.HttpClient(host=self.db_config["host"], port=self.db_config["port"]) + + self.collection = self.client.get_or_create_collection("example2") yield self.client = None self.collection = None @@ -79,12 +79,12 @@ def insert_embeddings( Returns: (int, Exception): number of embeddings inserted and exception if any """ - ids=[str(i) for i in metadata] - metadata = [{"id": int(i)} for i in metadata] + ids = [str(i) for i in metadata] + metadata = [{"id": int(i)} for i in metadata] if len(embeddings) > 0: self.collection.add(embeddings=embeddings, ids=ids, metadatas=metadata) return len(embeddings), None - + def search_embedding( self, query: list[float], @@ -100,17 +100,19 @@ def search_embedding( kwargs: other arguments Returns: - Dict {ids: list[list[int]], - embedding: list[list[float]] + Dict {ids: list[list[int]], + embedding: list[list[float]] distance: list[list[float]]} """ if filters: # assumes benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} id_value = filters.get("id") - results = self.collection.query(query_embeddings=query, n_results=k, - where={"id": {"$gt": id_value}}) - #return list of id's in results - return [int(i) for i in results.get('ids')[0]] + results = self.collection.query( + query_embeddings=query, + n_results=k, + where={"id": {"$gt": id_value}}, + ) + # return list of id's in results + return [int(i) for i in results.get("ids")[0]] results = self.collection.query(query_embeddings=query, n_results=k) - return [int(i) for i in results.get('ids')[0]] - + return [int(i) for i in results.get("ids")[0]] diff --git a/vectordb_bench/backend/clients/chroma/config.py b/vectordb_bench/backend/clients/chroma/config.py index 85c59c973..af34cf513 100644 --- a/vectordb_bench/backend/clients/chroma/config.py +++ b/vectordb_bench/backend/clients/chroma/config.py @@ -1,14 +1,16 @@ from pydantic import SecretStr + from ..api import DBConfig + class ChromaConfig(DBConfig): password: SecretStr host: SecretStr - port: int + port: int def to_dict(self) -> dict: return { "host": self.host.get_secret_value(), "port": self.port, "password": self.password.get_secret_value(), - } \ No newline at end of file + } diff --git a/vectordb_bench/backend/clients/elastic_cloud/config.py b/vectordb_bench/backend/clients/elastic_cloud/config.py index 204a4fc9d..d35ee68ec 100644 --- a/vectordb_bench/backend/clients/elastic_cloud/config.py +++ b/vectordb_bench/backend/clients/elastic_cloud/config.py @@ -1,7 +1,8 @@ from enum import Enum -from pydantic import SecretStr, BaseModel -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType class ElasticCloudConfig(DBConfig, BaseModel): @@ -32,12 +33,12 @@ class ElasticCloudIndexConfig(BaseModel, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "l2_norm" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "dot_product" return "cosine" def index_param(self) -> dict: - params = { + return { "type": "dense_vector", "index": True, "element_type": self.element_type.value, @@ -48,7 +49,6 @@ def index_param(self) -> dict: "ef_construction": self.efConstruction, }, 
} - return params def search_param(self) -> dict: return { diff --git a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py index 64f27e490..a3183bcb7 100644 --- a/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py +++ b/vectordb_bench/backend/clients/elastic_cloud/elastic_cloud.py @@ -1,17 +1,22 @@ import logging import time +from collections.abc import Iterable from contextlib import contextmanager -from typing import Iterable -from ..api import VectorDB -from .config import ElasticCloudIndexConfig + from elasticsearch.helpers import bulk +from ..api import VectorDB +from .config import ElasticCloudIndexConfig for logger in ("elasticsearch", "elastic_transport"): logging.getLogger(logger).setLevel(logging.WARNING) log = logging.getLogger(__name__) + +SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30 + + class ElasticCloud(VectorDB): def __init__( self, @@ -46,14 +51,14 @@ def __init__( def init(self) -> None: """connect to elasticsearch""" from elasticsearch import Elasticsearch + self.client = Elasticsearch(**self.db_config, request_timeout=180) yield - # self.client.transport.close() self.client = None - del(self.client) + del self.client - def _create_indice(self, client) -> None: + def _create_indice(self, client: any) -> None: mappings = { "_source": {"excludes": [self.vector_col_name]}, "properties": { @@ -62,13 +67,13 @@ def _create_indice(self, client) -> None: "dims": self.dim, **self.case_config.index_param(), }, - } + }, } try: client.indices.create(index=self.indice, mappings=mappings) except Exception as e: - log.warning(f"Failed to create indice: {self.indice} error: {str(e)}") + log.warning(f"Failed to create indice: {self.indice} error: {e!s}") raise e from None def insert_embeddings( @@ -94,7 +99,7 @@ def insert_embeddings( bulk_insert_res = bulk(self.client, insert_data) return (bulk_insert_res[0], None) except Exception as e: - log.warning(f"Failed to insert data: {self.indice} error: {str(e)}") + log.warning(f"Failed to insert data: {self.indice} error: {e!s}") return (0, e) def search_embedding( @@ -114,16 +119,12 @@ def search_embedding( list[tuple[int, float]]: list of k most similar embeddings in (id, score) tuple to the query embedding. 
""" assert self.client is not None, "should self.init() first" - # is_existed_res = self.client.indices.exists(index=self.indice) - # assert is_existed_res.raw == True, "should self.init() first" knn = { "field": self.vector_col_name, "k": k, "num_candidates": self.case_config.num_candidates, - "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] - if filters - else [], + "filter": [{"range": {self.id_col_name: {"gt": filters["id"]}}}] if filters else [], "query_vector": query, } size = k @@ -137,26 +138,26 @@ def search_embedding( stored_fields="_none_", filter_path=[f"hits.hits.fields.{self.id_col_name}"], ) - res = [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]] - - return res + return [h["fields"][self.id_col_name][0] for h in res["hits"]["hits"]] except Exception as e: - log.warning(f"Failed to search: {self.indice} error: {str(e)}") + log.warning(f"Failed to search: {self.indice} error: {e!s}") raise e from None def optimize(self): """optimize will be called between insertion and search in performance cases.""" assert self.client is not None, "should self.init() first" self.client.indices.refresh(index=self.indice) - force_merge_task_id = self.client.indices.forcemerge(index=self.indice, max_num_segments=1, wait_for_completion=False)['task'] + force_merge_task_id = self.client.indices.forcemerge( + index=self.indice, + max_num_segments=1, + wait_for_completion=False, + )["task"] log.info(f"Elasticsearch force merge task id: {force_merge_task_id}") - SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC = 30 while True: time.sleep(SECONDS_WAITING_FOR_FORCE_MERGE_API_CALL_SEC) task_status = self.client.tasks.get(task_id=force_merge_task_id) - if task_status['completed']: + if task_status["completed"]: return def ready_to_load(self): """ready_to_load will be called before load in load cases.""" - pass diff --git a/vectordb_bench/backend/clients/memorydb/cli.py b/vectordb_bench/backend/clients/memorydb/cli.py index 50b5f89ba..ae00bfd17 100644 --- a/vectordb_bench/backend/clients/memorydb/cli.py +++ b/vectordb_bench/backend/clients/memorydb/cli.py @@ -14,9 +14,7 @@ class MemoryDBTypedDict(TypedDict): - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] password: Annotated[str, click.option("--password", type=str, help="Db password")] port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")] ssl: Annotated[ @@ -44,7 +42,10 @@ class MemoryDBTypedDict(TypedDict): is_flag=True, show_default=True, default=False, - help="Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance. In production, MemoryDB only supports cluster mode (CME)", + help=( + "Cluster Mode Disabled (CMD), use this flag when testing locally on a single node instance.", + " In production, MemoryDB only supports cluster mode (CME)", + ), ), ] insert_batch_size: Annotated[ @@ -58,8 +59,7 @@ class MemoryDBTypedDict(TypedDict): ] -class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): - ... +class MemoryDBHNSWTypedDict(CommonTypedDict, MemoryDBTypedDict, HNSWFlavor2): ... 
@cli.command() @@ -82,7 +82,7 @@ def MemoryDB(**parameters: Unpack[MemoryDBHNSWTypedDict]): M=parameters["m"], ef_construction=parameters["ef_construction"], ef_runtime=parameters["ef_runtime"], - insert_batch_size=parameters["insert_batch_size"] + insert_batch_size=parameters["insert_batch_size"], ), **parameters, - ) \ No newline at end of file + ) diff --git a/vectordb_bench/backend/clients/memorydb/config.py b/vectordb_bench/backend/clients/memorydb/config.py index 1284d3449..2c40ff546 100644 --- a/vectordb_bench/backend/clients/memorydb/config.py +++ b/vectordb_bench/backend/clients/memorydb/config.py @@ -29,7 +29,7 @@ class MemoryDBIndexConfig(BaseModel, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "l2" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "ip" return "cosine" @@ -51,4 +51,4 @@ def index_param(self) -> dict: def search_param(self) -> dict: return { "ef_runtime": self.ef_runtime, - } \ No newline at end of file + } diff --git a/vectordb_bench/backend/clients/memorydb/memorydb.py b/vectordb_bench/backend/clients/memorydb/memorydb.py index c5f80eb2a..d05e30be1 100644 --- a/vectordb_bench/backend/clients/memorydb/memorydb.py +++ b/vectordb_bench/backend/clients/memorydb/memorydb.py @@ -1,30 +1,33 @@ -import logging, time +import logging +import time +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple, Type -from ..api import VectorDB, DBCaseConfig, IndexType -from .config import MemoryDBIndexConfig +from typing import Any + +import numpy as np import redis from redis import Redis from redis.cluster import RedisCluster -from redis.commands.search.field import TagField, VectorField, NumericField -from redis.commands.search.indexDefinition import IndexDefinition, IndexType +from redis.commands.search.field import NumericField, TagField, VectorField +from redis.commands.search.indexDefinition import IndexDefinition from redis.commands.search.query import Query -import numpy as np +from ..api import IndexType, VectorDB +from .config import MemoryDBIndexConfig log = logging.getLogger(__name__) -INDEX_NAME = "index" # Vector Index Name +INDEX_NAME = "index" # Vector Index Name + class MemoryDB(VectorDB): def __init__( - self, - dim: int, - db_config: dict, - db_case_config: MemoryDBIndexConfig, - drop_old: bool = False, - **kwargs - ): - + self, + dim: int, + db_config: dict, + db_case_config: MemoryDBIndexConfig, + drop_old: bool = False, + **kwargs, + ): self.db_config = db_config self.case_config = db_case_config self.collection_name = INDEX_NAME @@ -44,10 +47,10 @@ def __init__( info = conn.ft(INDEX_NAME).info() log.info(f"Index info: {info}") except redis.exceptions.ResponseError as e: - log.error(e) + log.warning(e) drop_old = False log.info(f"MemoryDB client drop_old collection: {self.collection_name}") - + log.info("Executing FLUSHALL") conn.flushall() @@ -59,7 +62,7 @@ def __init__( self.wait_until(self.wait_for_empty_db, 3, "", rc) log.debug(f"Flushall done in the host: {host}") rc.close() - + self.make_index(dim, conn) conn.close() conn = None @@ -69,7 +72,7 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): # check to see if index exists conn.ft(INDEX_NAME).info() except Exception as e: - log.warn(f"Error getting info for index '{INDEX_NAME}': {e}") + log.warning(f"Error getting info for index '{INDEX_NAME}': {e}") index_param = self.case_config.index_param() search_param = 
self.case_config.search_param() vector_parameters = { # Vector Index Type: FLAT or HNSW @@ -85,17 +88,19 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): vector_parameters["EF_RUNTIME"] = search_param["ef_runtime"] schema = ( - TagField("id"), - NumericField("metadata"), - VectorField("vector", # Vector Field Name - "HNSW", vector_parameters + TagField("id"), + NumericField("metadata"), + VectorField( + "vector", # Vector Field Name + "HNSW", + vector_parameters, ), ) definition = IndexDefinition(index_type=IndexType.HASH) rs = conn.ft(INDEX_NAME) rs.create_index(schema, definition=definition) - + def get_client(self, **kwargs): """ Gets either cluster connection or normal connection based on `cmd` flag. @@ -143,7 +148,7 @@ def get_client(self, **kwargs): @contextmanager def init(self) -> Generator[None, None, None]: - """ create and destory connections to database. + """create and destory connections to database. Examples: >>> with self.init(): @@ -170,7 +175,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: """Insert embeddings into the database. Should call self.init() first. """ @@ -178,12 +183,15 @@ def insert_embeddings( try: with self.conn.pipeline(transaction=False) as pipe: for i, embedding in enumerate(embeddings): - embedding = np.array(embedding).astype(np.float32) - pipe.hset(metadata[i], mapping = { - "id": str(metadata[i]), - "metadata": metadata[i], - "vector": embedding.tobytes(), - }) + ndarr_emb = np.array(embedding).astype(np.float32) + pipe.hset( + metadata[i], + mapping={ + "id": str(metadata[i]), + "metadata": metadata[i], + "vector": ndarr_emb.tobytes(), + }, + ) # Execute the pipe so we don't keep too much in memory at once if (i + 1) % self.insert_batch_size == 0: pipe.execute() @@ -192,9 +200,9 @@ def insert_embeddings( result_len = i + 1 except Exception as e: return 0, e - + return result_len, None - + def _post_insert(self): """Wait for indexing to finish""" client = self.get_client(primary=True) @@ -208,21 +216,17 @@ def _post_insert(self): self.wait_until(*args) log.debug(f"Background indexing completed in the host: {host_name}") rc.close() - - def wait_until( - self, condition, interval=5, message="Operation took too long", *args - ): + + def wait_until(self, condition: any, interval: int = 5, message: str = "Operation took too long", *args): while not condition(*args): time.sleep(interval) - + def wait_for_no_activity(self, client: redis.RedisCluster | redis.Redis): - return ( - client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY" - ) - + return client.info("search")["search_background_indexing_status"] == "NO_ACTIVITY" + def wait_for_empty_db(self, client: redis.RedisCluster | redis.Redis): return client.execute_command("DBSIZE") == 0 - + def search_embedding( self, query: list[float], @@ -230,13 +234,13 @@ def search_embedding( filters: dict | None = None, timeout: int | None = None, **kwargs: Any, - ) -> (list[int]): + ) -> list[int]: assert self.conn is not None - + query_vector = np.array(query).astype(np.float32).tobytes() query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) query_params = {"vec": query_vector} - + if filters: # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} # gets exact match for id, and range for metadata if they exist in filters @@ -244,11 +248,19 @@ def search_embedding( # Removing '>=' from the id_value: '>=10000' 
metadata_value = filters.get("metadata")[2:] if id_value and metadata_value: - query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) + query_obj = ( + Query( + f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec]", + ) + .return_fields("id") + .paging(0, k) + ) elif id_value: - #gets exact match for id + # gets exact match for id query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) - else: #metadata only case, greater than or equal to metadata value - query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) + else: # metadata only case, greater than or equal to metadata value + query_obj = ( + Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k) + ) res = self.conn.ft(INDEX_NAME).search(query_obj, query_params) - return [int(doc["id"]) for doc in res.docs] \ No newline at end of file + return [int(doc["id"]) for doc in res.docs] diff --git a/vectordb_bench/backend/clients/milvus/cli.py b/vectordb_bench/backend/clients/milvus/cli.py index 885995def..51ea82eff 100644 --- a/vectordb_bench/backend/clients/milvus/cli.py +++ b/vectordb_bench/backend/clients/milvus/cli.py @@ -1,8 +1,9 @@ -from typing import Annotated, TypedDict, Unpack, Optional +from typing import Annotated, TypedDict, Unpack import click from pydantic import SecretStr +from vectordb_bench.backend.clients import DB from vectordb_bench.cli.cli import ( CommonTypedDict, HNSWFlavor3, @@ -10,33 +11,33 @@ cli, click_parameter_decorators_from_typed_dict, run, - ) -from vectordb_bench.backend.clients import DB DBTYPE = DB.Milvus class MilvusTypedDict(TypedDict): uri: Annotated[ - str, click.option("--uri", type=str, help="uri connection string", required=True) + str, + click.option("--uri", type=str, help="uri connection string", required=True), ] user_name: Annotated[ - Optional[str], click.option("--user-name", type=str, help="Db username", required=False) + str | None, + click.option("--user-name", type=str, help="Db username", required=False), ] password: Annotated[ - Optional[str], click.option("--password", type=str, help="Db password", required=False) + str | None, + click.option("--password", type=str, help="Db password", required=False), ] -class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): - ... +class MilvusAutoIndexTypedDict(CommonTypedDict, MilvusTypedDict): ... @cli.command() @click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict) def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]): - from .config import MilvusConfig, AutoIndexConfig + from .config import AutoIndexConfig, MilvusConfig run( db=DBTYPE, @@ -54,7 +55,7 @@ def MilvusAutoIndex(**parameters: Unpack[MilvusAutoIndexTypedDict]): @cli.command() @click_parameter_decorators_from_typed_dict(MilvusAutoIndexTypedDict) def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]): - from .config import MilvusConfig, FLATConfig + from .config import FLATConfig, MilvusConfig run( db=DBTYPE, @@ -69,14 +70,13 @@ def MilvusFlat(**parameters: Unpack[MilvusAutoIndexTypedDict]): ) -class MilvusHNSWTypedDict(CommonTypedDict, MilvusTypedDict, HNSWFlavor3): - ... +class MilvusHNSWTypedDict(CommonTypedDict, MilvusTypedDict, HNSWFlavor3): ... 
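For reference, a minimal sketch of the unfiltered redis-py KNN query built in search_embedding above; it assumes an existing RediSearch index named "index" with a FLOAT32 "vector" field, as created by make_index.

    import numpy as np
    from redis import Redis
    from redis.commands.search.query import Query


    def knn_ids(conn: Redis, query_vec: list[float], k: int) -> list[int]:
        # Vector bytes must match the index dtype (FLOAT32 here).
        query_obj = Query(f"*=>[KNN {k} @vector $vec]").return_fields("id").paging(0, k)
        params = {"vec": np.array(query_vec, dtype=np.float32).tobytes()}
        res = conn.ft("index").search(query_obj, params)
        # Document keys are the integer ids in this schema.
        return [int(doc.id) for doc in res.docs]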
@cli.command() @click_parameter_decorators_from_typed_dict(MilvusHNSWTypedDict) def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]): - from .config import MilvusConfig, HNSWConfig + from .config import HNSWConfig, MilvusConfig run( db=DBTYPE, @@ -95,14 +95,13 @@ def MilvusHNSW(**parameters: Unpack[MilvusHNSWTypedDict]): ) -class MilvusIVFFlatTypedDict(CommonTypedDict, MilvusTypedDict, IVFFlatTypedDictN): - ... +class MilvusIVFFlatTypedDict(CommonTypedDict, MilvusTypedDict, IVFFlatTypedDictN): ... @cli.command() @click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict) def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]): - from .config import MilvusConfig, IVFFlatConfig + from .config import IVFFlatConfig, MilvusConfig run( db=DBTYPE, @@ -123,7 +122,7 @@ def MilvusIVFFlat(**parameters: Unpack[MilvusIVFFlatTypedDict]): @cli.command() @click_parameter_decorators_from_typed_dict(MilvusIVFFlatTypedDict) def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]): - from .config import MilvusConfig, IVFSQ8Config + from .config import IVFSQ8Config, MilvusConfig run( db=DBTYPE, @@ -142,17 +141,13 @@ def MilvusIVFSQ8(**parameters: Unpack[MilvusIVFFlatTypedDict]): class MilvusDISKANNTypedDict(CommonTypedDict, MilvusTypedDict): - search_list: Annotated[ - str, click.option("--search-list", - type=int, - required=True) - ] + search_list: Annotated[str, click.option("--search-list", type=int, required=True)] @cli.command() @click_parameter_decorators_from_typed_dict(MilvusDISKANNTypedDict) def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]): - from .config import MilvusConfig, DISKANNConfig + from .config import DISKANNConfig, MilvusConfig run( db=DBTYPE, @@ -171,21 +166,16 @@ def MilvusDISKANN(**parameters: Unpack[MilvusDISKANNTypedDict]): class MilvusGPUIVFTypedDict(CommonTypedDict, MilvusTypedDict, MilvusIVFFlatTypedDict): cache_dataset_on_device: Annotated[ - str, click.option("--cache-dataset-on-device", - type=str, - required=True) - ] - refine_ratio: Annotated[ - str, click.option("--refine-ratio", - type=float, - required=True) + str, + click.option("--cache-dataset-on-device", type=str, required=True), ] + refine_ratio: Annotated[str, click.option("--refine-ratio", type=float, required=True)] @cli.command() @click_parameter_decorators_from_typed_dict(MilvusGPUIVFTypedDict) def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): - from .config import MilvusConfig, GPUIVFFlatConfig + from .config import GPUIVFFlatConfig, MilvusConfig run( db=DBTYPE, @@ -205,23 +195,20 @@ def MilvusGPUIVFFlat(**parameters: Unpack[MilvusGPUIVFTypedDict]): ) -class MilvusGPUIVFPQTypedDict(CommonTypedDict, MilvusTypedDict, MilvusIVFFlatTypedDict, MilvusGPUIVFTypedDict): - m: Annotated[ - str, click.option("--m", - type=int, help="hnsw m", - required=True) - ] - nbits: Annotated[ - str, click.option("--nbits", - type=int, - required=True) - ] +class MilvusGPUIVFPQTypedDict( + CommonTypedDict, + MilvusTypedDict, + MilvusIVFFlatTypedDict, + MilvusGPUIVFTypedDict, +): + m: Annotated[str, click.option("--m", type=int, help="hnsw m", required=True)] + nbits: Annotated[str, click.option("--nbits", type=int, required=True)] @cli.command() @click_parameter_decorators_from_typed_dict(MilvusGPUIVFPQTypedDict) def MilvusGPUIVFPQ(**parameters: Unpack[MilvusGPUIVFPQTypedDict]): - from .config import MilvusConfig, GPUIVFPQConfig + from .config import GPUIVFPQConfig, MilvusConfig run( db=DBTYPE, @@ -245,51 +232,22 @@ def MilvusGPUIVFPQ(**parameters: 
Unpack[MilvusGPUIVFPQTypedDict]): class MilvusGPUCAGRATypedDict(CommonTypedDict, MilvusTypedDict, MilvusGPUIVFTypedDict): intermediate_graph_degree: Annotated[ - str, click.option("--intermediate-graph-degree", - type=int, - required=True) - ] - graph_degree: Annotated[ - str, click.option("--graph-degree", - type=int, - required=True) - ] - build_algo: Annotated[ - str, click.option("--build_algo", - type=str, - required=True) - ] - team_size: Annotated[ - str, click.option("--team-size", - type=int, - required=True) - ] - search_width: Annotated[ - str, click.option("--search-width", - type=int, - required=True) - ] - itopk_size: Annotated[ - str, click.option("--itopk-size", - type=int, - required=True) - ] - min_iterations: Annotated[ - str, click.option("--min-iterations", - type=int, - required=True) - ] - max_iterations: Annotated[ - str, click.option("--max-iterations", - type=int, - required=True) + str, + click.option("--intermediate-graph-degree", type=int, required=True), ] + graph_degree: Annotated[str, click.option("--graph-degree", type=int, required=True)] + build_algo: Annotated[str, click.option("--build_algo", type=str, required=True)] + team_size: Annotated[str, click.option("--team-size", type=int, required=True)] + search_width: Annotated[str, click.option("--search-width", type=int, required=True)] + itopk_size: Annotated[str, click.option("--itopk-size", type=int, required=True)] + min_iterations: Annotated[str, click.option("--min-iterations", type=int, required=True)] + max_iterations: Annotated[str, click.option("--max-iterations", type=int, required=True)] @cli.command() @click_parameter_decorators_from_typed_dict(MilvusGPUCAGRATypedDict) def MilvusGPUCAGRA(**parameters: Unpack[MilvusGPUCAGRATypedDict]): - from .config import MilvusConfig, GPUCAGRAConfig + from .config import GPUCAGRAConfig, MilvusConfig run( db=DBTYPE, diff --git a/vectordb_bench/backend/clients/milvus/config.py b/vectordb_bench/backend/clients/milvus/config.py index 059ef0461..7d0df803a 100644 --- a/vectordb_bench/backend/clients/milvus/config.py +++ b/vectordb_bench/backend/clients/milvus/config.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, SecretStr, validator -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType class MilvusConfig(DBConfig): @@ -15,10 +16,14 @@ def to_dict(self) -> dict: } @validator("*") - def not_empty_field(cls, v, field): - if field.name in cls.common_short_configs() or field.name in cls.common_long_configs() or field.name in ["user", "password"]: + def not_empty_field(cls, v: any, field: any): + if ( + field.name in cls.common_short_configs() + or field.name in cls.common_long_configs() + or field.name in ["user", "password"] + ): return v - if isinstance(v, (str, SecretStr)) and len(v) == 0: + if isinstance(v, str | SecretStr) and len(v) == 0: raise ValueError("Empty string!") return v @@ -28,10 +33,14 @@ class MilvusIndexConfig(BaseModel): index: IndexType metric_type: MetricType | None = None - + @property def is_gpu_index(self) -> bool: - return self.index in [IndexType.GPU_CAGRA, IndexType.GPU_IVF_FLAT, IndexType.GPU_IVF_PQ] + return self.index in [ + IndexType.GPU_CAGRA, + IndexType.GPU_IVF_FLAT, + IndexType.GPU_IVF_PQ, + ] def parse_metric(self) -> str: if not self.metric_type: @@ -113,7 +122,8 @@ def search_param(self) -> dict: "metric_type": self.parse_metric(), "params": {"nprobe": self.nprobe}, } - + + class IVFSQ8Config(MilvusIndexConfig, DBCaseConfig): nlist: int nprobe: int 
| None = None @@ -210,7 +220,7 @@ class GPUCAGRAConfig(MilvusIndexConfig, DBCaseConfig): search_width: int = 4 min_iterations: int = 0 max_iterations: int = 0 - build_algo: str = "IVF_PQ" # IVF_PQ; NN_DESCENT; + build_algo: str = "IVF_PQ" # IVF_PQ; NN_DESCENT; cache_dataset_on_device: str refine_ratio: float | None = None index: IndexType = IndexType.GPU_CAGRA diff --git a/vectordb_bench/backend/clients/milvus/milvus.py b/vectordb_bench/backend/clients/milvus/milvus.py index 251dee8ad..45fe7269b 100644 --- a/vectordb_bench/backend/clients/milvus/milvus.py +++ b/vectordb_bench/backend/clients/milvus/milvus.py @@ -2,19 +2,18 @@ import logging import time +from collections.abc import Iterable from contextlib import contextmanager -from typing import Iterable -from pymilvus import Collection, utility -from pymilvus import CollectionSchema, DataType, FieldSchema, MilvusException +from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, MilvusException, utility from ..api import VectorDB from .config import MilvusIndexConfig - log = logging.getLogger(__name__) -MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 *1024 +MILVUS_LOAD_REQS_SIZE = 1.5 * 1024 * 1024 + class Milvus(VectorDB): def __init__( @@ -32,7 +31,7 @@ def __init__( self.db_config = db_config self.case_config = db_case_config self.collection_name = collection_name - self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim *4)) + self.batch_size = int(MILVUS_LOAD_REQS_SIZE / (dim * 4)) self._primary_field = "pk" self._scalar_field = "id" @@ -40,6 +39,7 @@ def __init__( self._index_name = "vector_idx" from pymilvus import connections + connections.connect(**self.db_config, timeout=30) if drop_old and utility.has_collection(self.collection_name): log.info(f"{self.name} client drop_old collection: {self.collection_name}") @@ -49,7 +49,7 @@ def __init__( fields = [ FieldSchema(self._primary_field, DataType.INT64, is_primary=True), FieldSchema(self._scalar_field, DataType.INT64), - FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim) + FieldSchema(self._vector_field, DataType.FLOAT_VECTOR, dim=dim), ] log.info(f"{self.name} create collection: {self.collection_name}") @@ -79,6 +79,7 @@ def init(self) -> None: >>> self.search_embedding() """ from pymilvus import connections + self.col: Collection | None = None connections.connect(**self.db_config, timeout=60) @@ -108,6 +109,7 @@ def _post_insert(self): ) utility.wait_for_index_building_complete(self.collection_name) + def wait_index(): while True: progress = utility.index_building_progress(self.collection_name) @@ -120,18 +122,17 @@ def wait_index(): # Skip compaction if use GPU indexType if self.case_config.is_gpu_index: log.debug("skip compaction for gpu index type.") - else : + else: try: self.col.compact() self.col.wait_for_compaction_completed() except Exception as e: log.warning(f"{self.name} compact error: {e}") - if hasattr(e, 'code'): - if e.code().name == 'PERMISSION_DENIED': + if hasattr(e, "code"): + if e.code().name == "PERMISSION_DENIED": log.warning("Skip compact due to permission denied.") - pass else: - raise e + raise e from e wait_index() except Exception as e: log.warning(f"{self.name} optimize error: {e}") @@ -156,7 +157,6 @@ def _pre_load(self, coll: Collection): log.warning(f"{self.name} pre load error: {e}") raise e from None - def optimize(self): assert self.col, "Please call self.init() before" self._optimize() @@ -164,7 +164,7 @@ def optimize(self): def need_normalize_cosine(self) -> bool: """Wheather this database need to normalize dataset to support 
COSINE""" if self.case_config.is_gpu_index: - log.info(f"current gpu_index only supports IP / L2, cosine dataset need normalize.") + log.info("current gpu_index only supports IP / L2, cosine dataset need normalize.") return True return False @@ -184,9 +184,9 @@ def insert_embeddings( for batch_start_offset in range(0, len(embeddings), self.batch_size): batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) insert_data = [ - metadata[batch_start_offset : batch_end_offset], - metadata[batch_start_offset : batch_end_offset], - embeddings[batch_start_offset : batch_end_offset], + metadata[batch_start_offset:batch_end_offset], + metadata[batch_start_offset:batch_end_offset], + embeddings[batch_start_offset:batch_end_offset], ] res = self.col.insert(insert_data) insert_count += len(res.primary_keys) @@ -217,5 +217,4 @@ def search_embedding( ) # Organize results. - ret = [result.id for result in res[0]] - return ret + return [result.id for result in res[0]] diff --git a/vectordb_bench/backend/clients/pgdiskann/cli.py b/vectordb_bench/backend/clients/pgdiskann/cli.py index 18a9ecbd5..19f47988f 100644 --- a/vectordb_bench/backend/clients/pgdiskann/cli.py +++ b/vectordb_bench/backend/clients/pgdiskann/cli.py @@ -1,57 +1,63 @@ -import click import os +from typing import Annotated, Unpack + +import click from pydantic import SecretStr +from vectordb_bench.backend.clients import DB + from ....cli.cli import ( CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run, ) -from typing import Annotated, Optional, Unpack -from vectordb_bench.backend.clients import DB class PgDiskAnnTypedDict(CommonTypedDict): user_name: Annotated[ - str, click.option("--user-name", type=str, help="Db username", required=True) + str, + click.option("--user-name", type=str, help="Db username", required=True), ] password: Annotated[ str, - click.option("--password", - type=str, - help="Postgres database password", - default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), - show_default="$POSTGRES_PASSWORD", - ), + click.option( + "--password", + type=str, + help="Postgres database password", + default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), + show_default="$POSTGRES_PASSWORD", + ), ] - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] - db_name: Annotated[ - str, click.option("--db-name", type=str, help="Db name", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] + db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] max_neighbors: Annotated[ int, click.option( - "--max-neighbors", type=int, help="PgDiskAnn max neighbors", + "--max-neighbors", + type=int, + help="PgDiskAnn max neighbors", ), ] l_value_ib: Annotated[ int, click.option( - "--l-value-ib", type=int, help="PgDiskAnn l_value_ib", + "--l-value-ib", + type=int, + help="PgDiskAnn l_value_ib", ), ] l_value_is: Annotated[ float, click.option( - "--l-value-is", type=float, help="PgDiskAnn l_value_is", + "--l-value-is", + type=float, + help="PgDiskAnn l_value_is", ), ] maintenance_work_mem: Annotated[ - Optional[str], + str | None, click.option( "--maintenance-work-mem", type=str, @@ -63,7 +69,7 @@ class PgDiskAnnTypedDict(CommonTypedDict): ), ] max_parallel_workers: Annotated[ - Optional[int], + int | None, click.option( "--max-parallel-workers", type=int, @@ -72,6 +78,7 @@ class PgDiskAnnTypedDict(CommonTypedDict): ), ] + @cli.command() 
@click_parameter_decorators_from_typed_dict(PgDiskAnnTypedDict) def PgDiskAnn( @@ -96,4 +103,4 @@ def PgDiskAnn( maintenance_work_mem=parameters["maintenance_work_mem"], ), **parameters, - ) \ No newline at end of file + ) diff --git a/vectordb_bench/backend/clients/pgdiskann/config.py b/vectordb_bench/backend/clients/pgdiskann/config.py index 970720afa..ed478acc2 100644 --- a/vectordb_bench/backend/clients/pgdiskann/config.py +++ b/vectordb_bench/backend/clients/pgdiskann/config.py @@ -1,7 +1,9 @@ from abc import abstractmethod -from typing import Any, Mapping, Optional, Sequence, TypedDict +from collections.abc import Mapping, Sequence +from typing import Any, LiteralString, TypedDict + from pydantic import BaseModel, SecretStr -from typing_extensions import LiteralString + from ..api import DBCaseConfig, DBConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" @@ -9,7 +11,7 @@ class PgDiskANNConfigDict(TypedDict): """These keys will be directly used as kwargs in psycopg connection string, - so the names must match exactly psycopg API""" + so the names must match exactly psycopg API""" user: str password: str @@ -41,44 +43,43 @@ class PgDiskANNIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None create_index_before_load: bool = False create_index_after_load: bool = True - maintenance_work_mem: Optional[str] - max_parallel_workers: Optional[int] + maintenance_work_mem: str | None + max_parallel_workers: int | None def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "vector_l2_ops" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "vector_ip_ops" return "vector_cosine_ops" def parse_metric_fun_op(self) -> LiteralString: if self.metric_type == MetricType.L2: return "<->" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "<#>" return "<=>" def parse_metric_fun_str(self) -> str: if self.metric_type == MetricType.L2: return "l2_distance" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "max_inner_product" return "cosine_distance" - + @abstractmethod - def index_param(self) -> dict: - ... + def index_param(self) -> dict: ... @abstractmethod - def search_param(self) -> dict: - ... + def search_param(self) -> dict: ... @abstractmethod - def session_param(self) -> dict: - ... + def session_param(self) -> dict: ... @staticmethod - def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]: + def _optionally_build_with_options( + with_options: Mapping[str, Any], + ) -> Sequence[dict[str, Any]]: """Walk through mappings, creating a List of {key1 = value} pairs. 
That will be used to build a where clause""" options = [] for option_name, value in with_options.items(): @@ -87,35 +88,36 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[ { "option_name": option_name, "val": str(value), - } + }, ) return options @staticmethod def _optionally_build_set_options( - set_mapping: Mapping[str, Any] + set_mapping: Mapping[str, Any], ) -> Sequence[dict[str, Any]]: """Walk through options, creating 'SET 'key1 = "value1";' list""" session_options = [] for setting_name, value in set_mapping.items(): if value: session_options.append( - {"parameter": { + { + "parameter": { "setting_name": setting_name, "val": str(value), }, - } + }, ) return session_options - + class PgDiskANNImplConfig(PgDiskANNIndexConfig): index: IndexType = IndexType.DISKANN max_neighbors: int | None l_value_ib: int | None l_value_is: float | None - maintenance_work_mem: Optional[str] = None - max_parallel_workers: Optional[int] = None + maintenance_work_mem: str | None = None + max_parallel_workers: int | None = None def index_param(self) -> dict: return { @@ -128,18 +130,19 @@ def index_param(self) -> dict: "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, } - + def search_param(self) -> dict: return { "metric": self.parse_metric(), "metric_fun_op": self.parse_metric_fun_op(), } - + def session_param(self) -> dict: return { "diskann.l_value_is": self.l_value_is, } - + + _pgdiskann_case_config = { IndexType.DISKANN: PgDiskANNImplConfig, } diff --git a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py index c363490f7..c21972902 100644 --- a/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py +++ b/vectordb_bench/backend/clients/pgdiskann/pgdiskann.py @@ -1,9 +1,9 @@ """Wrapper around the pg_diskann vector database over VectorDB""" import logging -import pprint +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple +from typing import Any import numpy as np import psycopg @@ -44,20 +44,21 @@ def __init__( self._primary_field = "id" self._vector_field = "embedding" - self.conn, self.cursor = self._create_connection(**self.db_config) + self.conn, self.cursor = self._create_connection(**self.db_config) log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}") if not any( ( self.case_config.create_index_before_load, self.case_config.create_index_after_load, - ) + ), ): - err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" - log.error(err) - raise RuntimeError( - f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + msg = ( + f"{self.name} config must create an index using create_index_before_load or create_index_after_load" + f"{self.name} config values: {self.db_config}\n{self.case_config}" ) + log.error(msg) + raise RuntimeError(msg) if drop_old: self._drop_index() @@ -72,7 +73,7 @@ def __init__( self.conn = None @staticmethod - def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + def _create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS pg_diskann CASCADE") conn.commit() @@ -101,25 +102,25 @@ def init(self) -> Generator[None, None, None]: log.debug(command.as_string(self.cursor)) self.cursor.execute(command) self.conn.commit() - + self._filtered_search = sql.Composed( [ sql.SQL( 
- "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " - ).format(table_name=sql.Identifier(self.table_name)), + "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ", + ).format(table_name=sql.Identifier(self.table_name)), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) self._unfiltered_search = sql.Composed( [ sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( - sql.Identifier(self.table_name) + sql.Identifier(self.table_name), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) try: @@ -137,8 +138,8 @@ def _drop_table(self): self.cursor.execute( sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) self.conn.commit() @@ -160,7 +161,7 @@ def _drop_index(self): log.info(f"{self.name} client drop index : {self._index_name}") drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( - index_name=sql.Identifier(self._index_name) + index_name=sql.Identifier(self._index_name), ) log.debug(drop_index_sql.as_string(self.cursor)) self.cursor.execute(drop_index_sql) @@ -175,64 +176,53 @@ def _set_parallel_index_build_param(self): if index_param["maintenance_work_mem"] is not None: self.cursor.execute( sql.SQL("SET maintenance_work_mem TO {};").format( - index_param["maintenance_work_mem"] - ) + index_param["maintenance_work_mem"], + ), ) self.cursor.execute( sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format( sql.Identifier(self.db_config["user"]), index_param["maintenance_work_mem"], - ) + ), ) self.conn.commit() if index_param["max_parallel_workers"] is not None: self.cursor.execute( sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format( - index_param["max_parallel_workers"] - ) + index_param["max_parallel_workers"], + ), ) self.cursor.execute( - sql.SQL( - "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';" - ).format( + sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format( sql.Identifier(self.db_config["user"]), index_param["max_parallel_workers"], - ) + ), ) self.cursor.execute( sql.SQL("SET max_parallel_workers TO '{}';").format( - index_param["max_parallel_workers"] - ) + index_param["max_parallel_workers"], + ), ) self.cursor.execute( - sql.SQL( - "ALTER USER {} SET max_parallel_workers TO '{}';" - ).format( + sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format( sql.Identifier(self.db_config["user"]), index_param["max_parallel_workers"], - ) + ), ) self.cursor.execute( - sql.SQL( - "ALTER TABLE {} SET (parallel_workers = {});" - ).format( + sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format( sql.Identifier(self.table_name), index_param["max_parallel_workers"], - ) + ), ) self.conn.commit() - results = self.cursor.execute( - sql.SQL("SHOW max_parallel_maintenance_workers;") - ).fetchall() - results.extend( - self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall() - ) - results.extend( - self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall() - ) + results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall() + results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()) + results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()) log.info(f"{self.name} parallel index creation parameters: {results}") + def 
_create_index(self): assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -248,28 +238,23 @@ def _create_index(self): sql.SQL("{option_name} = {val}").format( option_name=sql.Identifier(option_name), val=sql.Identifier(str(option_val)), - ) + ), ) - - if any(options): - with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) - else: - with_clause = sql.Composed(()) + + with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) index_create_sql = sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric}) - """ + """, ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), index_type=sql.Identifier(index_param["index_type"].lower()), embedding_metric=sql.Identifier(index_param["metric"]), ) - index_create_sql_with_with_clause = ( - index_create_sql + with_clause - ).join(" ") + index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ") log.debug(index_create_sql_with_with_clause.as_string(self.cursor)) self.cursor.execute(index_create_sql_with_with_clause) self.conn.commit() @@ -283,14 +268,12 @@ def _create_table(self, dim: int): self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" - ).format(table_name=sql.Identifier(self.table_name), dim=dim) + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", + ).format(table_name=sql.Identifier(self.table_name), dim=dim), ) self.conn.commit() except Exception as e: - log.warning( - f"Failed to create pgdiskann table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create pgdiskann table: {self.table_name} error: {e}") raise e from None def insert_embeddings( @@ -298,7 +281,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -308,8 +291,8 @@ def insert_embeddings( with self.cursor.copy( sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) as copy: copy.set_types(["bigint", "vector"]) for i, row in enumerate(metadata_arr): @@ -321,9 +304,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: - log.warning( - f"Failed to insert data into table ({self.table_name}), error: {e}" - ) + log.warning(f"Failed to insert data into table ({self.table_name}), error: {e}") return 0, e def search_embedding( @@ -340,11 +321,12 @@ def search_embedding( if filters: gt = filters.get("id") result = self.cursor.execute( - self._filtered_search, (gt, q, k), prepare=True, binary=True - ) + self._filtered_search, + (gt, q, k), + prepare=True, + binary=True, + ) else: - result = self.cursor.execute( - self._unfiltered_search, (q, k), prepare=True, binary=True - ) + result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True) return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/backend/clients/pgvecto_rs/cli.py b/vectordb_bench/backend/clients/pgvecto_rs/cli.py index 
10dbff556..24d2cf800 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/cli.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/cli.py @@ -1,9 +1,11 @@ -from typing import Annotated, Optional, Unpack +import os +from typing import Annotated, Unpack import click -import os from pydantic import SecretStr +from vectordb_bench.backend.clients import DB + from ....cli.cli import ( CommonTypedDict, HNSWFlavor1, @@ -12,12 +14,12 @@ click_parameter_decorators_from_typed_dict, run, ) -from vectordb_bench.backend.clients import DB class PgVectoRSTypedDict(CommonTypedDict): user_name: Annotated[ - str, click.option("--user-name", type=str, help="Db username", required=True) + str, + click.option("--user-name", type=str, help="Db username", required=True), ] password: Annotated[ str, @@ -30,14 +32,10 @@ class PgVectoRSTypedDict(CommonTypedDict): ), ] - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] - db_name: Annotated[ - str, click.option("--db-name", type=str, help="Db name", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] + db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] max_parallel_workers: Annotated[ - Optional[int], + int | None, click.option( "--max-parallel-workers", type=int, diff --git a/vectordb_bench/backend/clients/pgvecto_rs/config.py b/vectordb_bench/backend/clients/pgvecto_rs/config.py index c671a236c..fbb7c5d81 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/config.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/config.py @@ -1,11 +1,11 @@ from abc import abstractmethod from typing import TypedDict +from pgvecto_rs.types import Flat, Hnsw, IndexOption, Ivf, Quantization +from pgvecto_rs.types.index import QuantizationRatio, QuantizationType from pydantic import BaseModel, SecretStr -from pgvecto_rs.types import IndexOption, Ivf, Hnsw, Flat, Quantization -from pgvecto_rs.types.index import QuantizationType, QuantizationRatio -from ..api import DBConfig, DBCaseConfig, IndexType, MetricType +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" @@ -52,14 +52,14 @@ class PgVectoRSIndexConfig(BaseModel, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "vector_l2_ops" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "vector_dot_ops" return "vector_cos_ops" def parse_metric_fun_op(self) -> str: if self.metric_type == MetricType.L2: return "<->" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "<#>" return "<=>" @@ -85,9 +85,7 @@ def index_param(self) -> dict[str, str]: if self.quantization_type is None: quantization = None else: - quantization = Quantization( - typ=self.quantization_type, ratio=self.quantization_ratio - ) + quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio) option = IndexOption( index=Hnsw( @@ -115,9 +113,7 @@ def index_param(self) -> dict[str, str]: if self.quantization_type is None: quantization = None else: - quantization = Quantization( - typ=self.quantization_type, ratio=self.quantization_ratio - ) + quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio) option = IndexOption( index=Ivf(nlist=self.lists, quantization=quantization), @@ -139,9 +135,7 @@ def index_param(self) -> dict[str, str]: if self.quantization_type is None: quantization = None else: - 
quantization = Quantization( - typ=self.quantization_type, ratio=self.quantization_ratio - ) + quantization = Quantization(typ=self.quantization_type, ratio=self.quantization_ratio) option = IndexOption( index=Flat( diff --git a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py index bc042cc57..fc4f17807 100644 --- a/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py +++ b/vectordb_bench/backend/clients/pgvecto_rs/pgvecto_rs.py @@ -1,14 +1,14 @@ """Wrapper around the Pgvecto.rs vector database over VectorDB""" import logging -import pprint +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple +from typing import Any import numpy as np import psycopg -from psycopg import Connection, Cursor, sql from pgvecto_rs.psycopg import register_vector +from psycopg import Connection, Cursor, sql from ..api import VectorDB from .config import PgVectoRSConfig, PgVectoRSIndexConfig @@ -33,7 +33,6 @@ def __init__( drop_old: bool = False, **kwargs, ): - self.name = "PgVectorRS" self.db_config = db_config self.case_config = db_case_config @@ -52,13 +51,14 @@ def __init__( ( self.case_config.create_index_before_load, self.case_config.create_index_after_load, - ) + ), ): - err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" - log.error(err) - raise RuntimeError( - f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + msg = ( + f"{self.name} config must create an index using create_index_before_load or create_index_after_load" + f"{self.name} config values: {self.db_config}\n{self.case_config}" ) + log.error(msg) + raise RuntimeError(msg) if drop_old: log.info(f"Pgvecto.rs client drop table : {self.table_name}") @@ -74,7 +74,7 @@ def __init__( self.conn = None @staticmethod - def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + def _create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) # create vector extension @@ -116,21 +116,21 @@ def init(self) -> Generator[None, None, None]: self._filtered_search = sql.Composed( [ sql.SQL( - "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding " + "SELECT id FROM public.{table_name} WHERE id >= %s ORDER BY embedding ", ).format(table_name=sql.Identifier(self.table_name)), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) self._unfiltered_search = sql.Composed( [ - sql.SQL( - "SELECT id FROM public.{table_name} ORDER BY embedding " - ).format(table_name=sql.Identifier(self.table_name)), + sql.SQL("SELECT id FROM public.{table_name} ORDER BY embedding ").format( + table_name=sql.Identifier(self.table_name), + ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) try: @@ -148,8 +148,8 @@ def _drop_table(self): self.cursor.execute( sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) self.conn.commit() @@ -171,7 +171,7 @@ def _drop_index(self): log.info(f"{self.name} client drop index : {self._index_name}") drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( - index_name=sql.Identifier(self._index_name) + index_name=sql.Identifier(self._index_name), ) log.debug(drop_index_sql.as_string(self.cursor)) self.cursor.execute(drop_index_sql) @@ -186,9 +186,9 
@@ def _create_index(self): index_create_sql = sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING vectors (embedding {embedding_metric}) WITH (options = {index_options}) - """ + """, ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), @@ -202,7 +202,7 @@ def _create_index(self): except Exception as e: log.warning( f"Failed to create pgvecto.rs index {self._index_name} \ - at table {self.table_name} error: {e}" + at table {self.table_name} error: {e}", ) raise e from None @@ -214,7 +214,7 @@ def _create_table(self, dim: int): """ CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim})) - """ + """, ).format( table_name=sql.Identifier(self.table_name), dim=dim, @@ -224,9 +224,7 @@ def _create_table(self, dim: int): self.cursor.execute(table_create_sql) self.conn.commit() except Exception as e: - log.warning( - f"Failed to create pgvecto.rs table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create pgvecto.rs table: {self.table_name} error: {e}") raise e from None def insert_embeddings( @@ -234,7 +232,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -247,8 +245,8 @@ def insert_embeddings( with self.cursor.copy( sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) as copy: copy.set_types(["bigint", "vector"]) for i, row in enumerate(metadata_arr): @@ -261,7 +259,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: log.warning( - f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}" + f"Failed to insert data into pgvecto.rs table ({self.table_name}), error: {e}", ) return 0, e @@ -281,12 +279,13 @@ def search_embedding( log.debug(self._filtered_search.as_string(self.cursor)) gt = filters.get("id") result = self.cursor.execute( - self._filtered_search, (gt, q, k), prepare=True, binary=True + self._filtered_search, + (gt, q, k), + prepare=True, + binary=True, ) else: log.debug(self._unfiltered_search.as_string(self.cursor)) - result = self.cursor.execute( - self._unfiltered_search, (q, k), prepare=True, binary=True - ) + result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True) return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/backend/clients/pgvector/cli.py b/vectordb_bench/backend/clients/pgvector/cli.py index ef8914be0..55a462055 100644 --- a/vectordb_bench/backend/clients/pgvector/cli.py +++ b/vectordb_bench/backend/clients/pgvector/cli.py @@ -1,9 +1,10 @@ -from typing import Annotated, Optional, TypedDict, Unpack +import os +from typing import Annotated, Unpack import click -import os from pydantic import SecretStr +from vectordb_bench.backend.clients import DB from vectordb_bench.backend.clients.api import MetricType from ....cli.cli import ( @@ -15,49 +16,48 @@ get_custom_case_config, run, ) -from vectordb_bench.backend.clients import DB -def set_default_quantized_fetch_limit(ctx, param, value): +# ruff: noqa +def set_default_quantized_fetch_limit(ctx: any, param: any, value: any): if ctx.params.get("reranking") and value is None: 
# ef_search is the default value for quantized_fetch_limit as it's bound by ef_search. # 100 is default value for quantized_fetch_limit for IVFFlat. - default_value = ctx.params["ef_search"] if ctx.command.name == "pgvectorhnsw" else 100 - return default_value + return ctx.params["ef_search"] if ctx.command.name == "pgvectorhnsw" else 100 return value + class PgVectorTypedDict(CommonTypedDict): user_name: Annotated[ - str, click.option("--user-name", type=str, help="Db username", required=True) + str, + click.option("--user-name", type=str, help="Db username", required=True), ] password: Annotated[ str, - click.option("--password", - type=str, - help="Postgres database password", - default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), - show_default="$POSTGRES_PASSWORD", - ), + click.option( + "--password", + type=str, + help="Postgres database password", + default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), + show_default="$POSTGRES_PASSWORD", + ), ] - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] - port: Annotated[ + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] + port: Annotated[ int, - click.option("--port", - type=int, - help="Postgres database port", - default=5432, - show_default=True, - required=False - ), - ] - db_name: Annotated[ - str, click.option("--db-name", type=str, help="Db name", required=True) + click.option( + "--port", + type=int, + help="Postgres database port", + default=5432, + show_default=True, + required=False, + ), ] + db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] maintenance_work_mem: Annotated[ - Optional[str], + str | None, click.option( "--maintenance-work-mem", type=str, @@ -69,7 +69,7 @@ class PgVectorTypedDict(CommonTypedDict): ), ] max_parallel_workers: Annotated[ - Optional[int], + int | None, click.option( "--max-parallel-workers", type=int, @@ -78,7 +78,7 @@ class PgVectorTypedDict(CommonTypedDict): ), ] quantization_type: Annotated[ - Optional[str], + str | None, click.option( "--quantization-type", type=click.Choice(["none", "bit", "halfvec"]), @@ -87,7 +87,7 @@ class PgVectorTypedDict(CommonTypedDict): ), ] reranking: Annotated[ - Optional[bool], + bool | None, click.option( "--reranking/--skip-reranking", type=bool, @@ -96,11 +96,11 @@ class PgVectorTypedDict(CommonTypedDict): ), ] reranking_metric: Annotated[ - Optional[str], + str | None, click.option( "--reranking-metric", type=click.Choice( - [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]] + [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]], ), help="Distance metric for reranking", default="COSINE", @@ -108,7 +108,7 @@ class PgVectorTypedDict(CommonTypedDict): ), ] quantized_fetch_limit: Annotated[ - Optional[int], + int | None, click.option( "--quantized-fetch-limit", type=int, @@ -116,13 +116,11 @@ class PgVectorTypedDict(CommonTypedDict): -- bound by ef_search", required=False, callback=set_default_quantized_fetch_limit, - ) + ), ] - -class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): - ... +class PgVectorIVFFlatTypedDict(PgVectorTypedDict, IVFFlatTypedDict): ... @cli.command() @@ -156,8 +154,7 @@ def PgVectorIVFFlat( ) -class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): - ... +class PgVectorHNSWTypedDict(PgVectorTypedDict, HNSWFlavor1): ... 
@cli.command() diff --git a/vectordb_bench/backend/clients/pgvector/config.py b/vectordb_bench/backend/clients/pgvector/config.py index 16d547445..c386d75ef 100644 --- a/vectordb_bench/backend/clients/pgvector/config.py +++ b/vectordb_bench/backend/clients/pgvector/config.py @@ -1,7 +1,9 @@ from abc import abstractmethod -from typing import Any, Mapping, Optional, Sequence, TypedDict +from collections.abc import Mapping, Sequence +from typing import Any, LiteralString, TypedDict + from pydantic import BaseModel, SecretStr -from typing_extensions import LiteralString + from ..api import DBCaseConfig, DBConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" @@ -9,7 +11,7 @@ class PgVectorConfigDict(TypedDict): """These keys will be directly used as kwargs in psycopg connection string, - so the names must match exactly psycopg API""" + so the names must match exactly psycopg API""" user: str password: str @@ -41,8 +43,8 @@ class PgVectorIndexParam(TypedDict): metric: str index_type: str index_creation_with_options: Sequence[dict[str, Any]] - maintenance_work_mem: Optional[str] - max_parallel_workers: Optional[int] + maintenance_work_mem: str | None + max_parallel_workers: int | None class PgVectorSearchParam(TypedDict): @@ -59,61 +61,60 @@ class PgVectorIndexConfig(BaseModel, DBCaseConfig): create_index_after_load: bool = True def parse_metric(self) -> str: - if self.quantization_type == "halfvec": - if self.metric_type == MetricType.L2: - return "halfvec_l2_ops" - elif self.metric_type == MetricType.IP: - return "halfvec_ip_ops" - return "halfvec_cosine_ops" - elif self.quantization_type == "bit": - if self.metric_type == MetricType.JACCARD: - return "bit_jaccard_ops" - return "bit_hamming_ops" - else: - if self.metric_type == MetricType.L2: - return "vector_l2_ops" - elif self.metric_type == MetricType.IP: - return "vector_ip_ops" - return "vector_cosine_ops" + d = { + "halfvec": { + MetricType.L2: "halfvec_l2_ops", + MetricType.IP: "halfvec_ip_ops", + MetricType.COSINE: "halfvec_cosine_ops", + }, + "bit": { + MetricType.JACCARD: "bit_jaccard_ops", + MetricType.HAMMING: "bit_hamming_ops", + }, + "_fallback": { + MetricType.L2: "vector_l2_ops", + MetricType.IP: "vector_ip_ops", + MetricType.COSINE: "vector_cosine_ops", + }, + } + + if d.get(self.quantization_type) is None: + return d.get("_fallback").get(self.metric_type) + return d.get(self.quantization_type).get(self.metric_type) def parse_metric_fun_op(self) -> LiteralString: if self.quantization_type == "bit": if self.metric_type == MetricType.JACCARD: return "<%>" return "<~>" - else: - if self.metric_type == MetricType.L2: - return "<->" - elif self.metric_type == MetricType.IP: - return "<#>" - return "<=>" + if self.metric_type == MetricType.L2: + return "<->" + if self.metric_type == MetricType.IP: + return "<#>" + return "<=>" def parse_metric_fun_str(self) -> str: if self.metric_type == MetricType.L2: return "l2_distance" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "max_inner_product" return "cosine_distance" - + def parse_reranking_metric_fun_op(self) -> LiteralString: if self.reranking_metric == MetricType.L2: return "<->" - elif self.reranking_metric == MetricType.IP: + if self.reranking_metric == MetricType.IP: return "<#>" return "<=>" - @abstractmethod - def index_param(self) -> PgVectorIndexParam: - ... + def index_param(self) -> PgVectorIndexParam: ... @abstractmethod - def search_param(self) -> PgVectorSearchParam: - ... 
+ def search_param(self) -> PgVectorSearchParam: ... @abstractmethod - def session_param(self) -> PgVectorSessionCommands: - ... + def session_param(self) -> PgVectorSessionCommands: ... @staticmethod def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[dict[str, Any]]: @@ -125,24 +126,23 @@ def _optionally_build_with_options(with_options: Mapping[str, Any]) -> Sequence[ { "option_name": option_name, "val": str(value), - } + }, ) return options @staticmethod - def _optionally_build_set_options( - set_mapping: Mapping[str, Any] - ) -> Sequence[dict[str, Any]]: + def _optionally_build_set_options(set_mapping: Mapping[str, Any]) -> Sequence[dict[str, Any]]: """Walk through options, creating 'SET 'key1 = "value1";' list""" session_options = [] for setting_name, value in set_mapping.items(): if value: session_options.append( - {"parameter": { + { + "parameter": { "setting_name": setting_name, "val": str(value), }, - } + }, ) return session_options @@ -165,12 +165,12 @@ class PgVectorIVFFlatConfig(PgVectorIndexConfig): lists: int | None probes: int | None index: IndexType = IndexType.ES_IVFFlat - maintenance_work_mem: Optional[str] = None - max_parallel_workers: Optional[int] = None - quantization_type: Optional[str] = None - reranking: Optional[bool] = None - quantized_fetch_limit: Optional[int] = None - reranking_metric: Optional[str] = None + maintenance_work_mem: str | None = None + max_parallel_workers: int | None = None + quantization_type: str | None = None + reranking: bool | None = None + quantized_fetch_limit: int | None = None + reranking_metric: str | None = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"lists": self.lists} @@ -179,9 +179,7 @@ def index_param(self) -> PgVectorIndexParam: return { "metric": self.parse_metric(), "index_type": self.index.value, - "index_creation_with_options": self._optionally_build_with_options( - index_parameters - ), + "index_creation_with_options": self._optionally_build_with_options(index_parameters), "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, @@ -197,9 +195,7 @@ def search_param(self) -> PgVectorSearchParam: def session_param(self) -> PgVectorSessionCommands: session_parameters = {"ivfflat.probes": self.probes} - return { - "session_options": self._optionally_build_set_options(session_parameters) - } + return {"session_options": self._optionally_build_set_options(session_parameters)} class PgVectorHNSWConfig(PgVectorIndexConfig): @@ -210,17 +206,15 @@ class PgVectorHNSWConfig(PgVectorIndexConfig): """ m: int | None # DETAIL: Valid values are between "2" and "100". 
- ef_construction: ( - int | None - ) # ef_construction must be greater than or equal to 2 * m + ef_construction: int | None # ef_construction must be greater than or equal to 2 * m ef_search: int | None index: IndexType = IndexType.ES_HNSW - maintenance_work_mem: Optional[str] = None - max_parallel_workers: Optional[int] = None - quantization_type: Optional[str] = None - reranking: Optional[bool] = None - quantized_fetch_limit: Optional[int] = None - reranking_metric: Optional[str] = None + maintenance_work_mem: str | None = None + max_parallel_workers: int | None = None + quantization_type: str | None = None + reranking: bool | None = None + quantized_fetch_limit: int | None = None + reranking_metric: str | None = None def index_param(self) -> PgVectorIndexParam: index_parameters = {"m": self.m, "ef_construction": self.ef_construction} @@ -229,9 +223,7 @@ def index_param(self) -> PgVectorIndexParam: return { "metric": self.parse_metric(), "index_type": self.index.value, - "index_creation_with_options": self._optionally_build_with_options( - index_parameters - ), + "index_creation_with_options": self._optionally_build_with_options(index_parameters), "maintenance_work_mem": self.maintenance_work_mem, "max_parallel_workers": self.max_parallel_workers, "quantization_type": self.quantization_type, @@ -247,13 +239,11 @@ def search_param(self) -> PgVectorSearchParam: def session_param(self) -> PgVectorSessionCommands: session_parameters = {"hnsw.ef_search": self.ef_search} - return { - "session_options": self._optionally_build_set_options(session_parameters) - } + return {"session_options": self._optionally_build_set_options(session_parameters)} _pgvector_case_config = { - IndexType.HNSW: PgVectorHNSWConfig, - IndexType.ES_HNSW: PgVectorHNSWConfig, - IndexType.IVFFlat: PgVectorIVFFlatConfig, + IndexType.HNSW: PgVectorHNSWConfig, + IndexType.ES_HNSW: PgVectorHNSWConfig, + IndexType.IVFFlat: PgVectorIVFFlatConfig, } diff --git a/vectordb_bench/backend/clients/pgvector/pgvector.py b/vectordb_bench/backend/clients/pgvector/pgvector.py index 069b89381..62a7971bb 100644 --- a/vectordb_bench/backend/clients/pgvector/pgvector.py +++ b/vectordb_bench/backend/clients/pgvector/pgvector.py @@ -1,9 +1,9 @@ """Wrapper around the Pgvector vector database over VectorDB""" import logging -import pprint +from collections.abc import Generator, Sequence from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple, Sequence +from typing import Any import numpy as np import psycopg @@ -11,7 +11,7 @@ from psycopg import Connection, Cursor, sql from ..api import VectorDB -from .config import PgVectorConfigDict, PgVectorIndexConfig, PgVectorHNSWConfig +from .config import PgVectorConfigDict, PgVectorIndexConfig log = logging.getLogger(__name__) @@ -56,13 +56,14 @@ def __init__( ( self.case_config.create_index_before_load, self.case_config.create_index_after_load, - ) + ), ): - err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" - log.error(err) - raise RuntimeError( - f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + msg = ( + f"{self.name} config must create an index using create_index_before_load or create_index_after_load" + f"{self.name} config values: {self.db_config}\n{self.case_config}" ) + log.error(msg) + raise RuntimeError(msg) if drop_old: self._drop_index() @@ -77,7 +78,7 @@ def __init__( self.conn = None @staticmethod - def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + def 
_create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) register_vector(conn) conn.autocommit = False @@ -87,8 +88,8 @@ def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: assert cursor is not None, "Cursor is not initialized" return conn, cursor - - def _generate_search_query(self, filtered: bool=False) -> sql.Composed: + + def _generate_search_query(self, filtered: bool = False) -> sql.Composed: index_param = self.case_config.index_param() reranking = self.case_config.search_param()["reranking"] column_name = ( @@ -103,23 +104,25 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: ) # The following sections assume that the quantization_type value matches the quantization function name - if index_param["quantization_type"] != None: + if index_param["quantization_type"] is not None: if index_param["quantization_type"] == "bit" and reranking: # Embeddings needs to be passed to binary_quantize function if quantization_type is bit search_query = sql.Composed( [ sql.SQL( """ - SELECT i.id + SELECT i.id FROM ( - SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance + SELECT id, embedding {reranking_metric_fun_op} %s::vector AS distance FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) - """ + """, ).format( table_name=sql.Identifier(self.table_name), column_name=column_name, - reranking_metric_fun_op=sql.SQL(self.case_config.search_param()["reranking_metric_fun_op"]), + reranking_metric_fun_op=sql.SQL( + self.case_config.search_param()["reranking_metric_fun_op"], + ), quantization_type=sql.SQL(index_param["quantization_type"]), dim=sql.Literal(self.dim), where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), @@ -127,25 +130,28 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL( """ - {search_vector} + {search_vector} LIMIT {quantized_fetch_limit} ) i - ORDER BY i.distance + ORDER BY i.distance LIMIT %s::int - """ + """, ).format( search_vector=search_vector, quantized_fetch_limit=sql.Literal( - self.case_config.search_param()["quantized_fetch_limit"] + self.case_config.search_param()["quantized_fetch_limit"], ), ), - ] + ], ) else: search_query = sql.Composed( [ sql.SQL( - "SELECT id FROM public.{table_name} {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) " + """ + SELECT id FROM public.{table_name} + {where_clause} ORDER BY {column_name}::{quantization_type}({dim}) + """, ).format( table_name=sql.Identifier(self.table_name), column_name=column_name, @@ -154,25 +160,26 @@ def _generate_search_query(self, filtered: bool=False) -> sql.Composed: where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" {search_vector} LIMIT %s::int").format(search_vector=search_vector), - ] + sql.SQL(" {search_vector} LIMIT %s::int").format( + search_vector=search_vector, + ), + ], ) else: search_query = sql.Composed( [ sql.SQL( - "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding " + "SELECT id FROM public.{table_name} {where_clause} ORDER BY embedding ", ).format( table_name=sql.Identifier(self.table_name), where_clause=sql.SQL("WHERE id >= %s") if filtered else sql.SQL(""), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) - + return search_query - @contextmanager def init(self) -> Generator[None, None, 
None]: @@ -191,8 +198,8 @@ def init(self) -> Generator[None, None, None]: if len(session_options) > 0: for setting in session_options: command = sql.SQL("SET {setting_name} " + "= {val};").format( - setting_name=sql.Identifier(setting['parameter']['setting_name']), - val=sql.Identifier(str(setting['parameter']['val'])), + setting_name=sql.Identifier(setting["parameter"]["setting_name"]), + val=sql.Identifier(str(setting["parameter"]["val"])), ) log.debug(command.as_string(self.cursor)) self.cursor.execute(command) @@ -216,8 +223,8 @@ def _drop_table(self): self.cursor.execute( sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) self.conn.commit() @@ -239,7 +246,7 @@ def _drop_index(self): log.info(f"{self.name} client drop index : {self._index_name}") drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( - index_name=sql.Identifier(self._index_name) + index_name=sql.Identifier(self._index_name), ) log.debug(drop_index_sql.as_string(self.cursor)) self.cursor.execute(drop_index_sql) @@ -254,63 +261,51 @@ def _set_parallel_index_build_param(self): if index_param["maintenance_work_mem"] is not None: self.cursor.execute( sql.SQL("SET maintenance_work_mem TO {};").format( - index_param["maintenance_work_mem"] - ) + index_param["maintenance_work_mem"], + ), ) self.cursor.execute( sql.SQL("ALTER USER {} SET maintenance_work_mem TO {};").format( sql.Identifier(self.db_config["user"]), index_param["maintenance_work_mem"], - ) + ), ) self.conn.commit() if index_param["max_parallel_workers"] is not None: self.cursor.execute( sql.SQL("SET max_parallel_maintenance_workers TO '{}';").format( - index_param["max_parallel_workers"] - ) + index_param["max_parallel_workers"], + ), ) self.cursor.execute( - sql.SQL( - "ALTER USER {} SET max_parallel_maintenance_workers TO '{}';" - ).format( + sql.SQL("ALTER USER {} SET max_parallel_maintenance_workers TO '{}';").format( sql.Identifier(self.db_config["user"]), index_param["max_parallel_workers"], - ) + ), ) self.cursor.execute( sql.SQL("SET max_parallel_workers TO '{}';").format( - index_param["max_parallel_workers"] - ) + index_param["max_parallel_workers"], + ), ) self.cursor.execute( - sql.SQL( - "ALTER USER {} SET max_parallel_workers TO '{}';" - ).format( + sql.SQL("ALTER USER {} SET max_parallel_workers TO '{}';").format( sql.Identifier(self.db_config["user"]), index_param["max_parallel_workers"], - ) + ), ) self.cursor.execute( - sql.SQL( - "ALTER TABLE {} SET (parallel_workers = {});" - ).format( + sql.SQL("ALTER TABLE {} SET (parallel_workers = {});").format( sql.Identifier(self.table_name), index_param["max_parallel_workers"], - ) + ), ) self.conn.commit() - results = self.cursor.execute( - sql.SQL("SHOW max_parallel_maintenance_workers;") - ).fetchall() - results.extend( - self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall() - ) - results.extend( - self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall() - ) + results = self.cursor.execute(sql.SQL("SHOW max_parallel_maintenance_workers;")).fetchall() + results.extend(self.cursor.execute(sql.SQL("SHOW max_parallel_workers;")).fetchall()) + results.extend(self.cursor.execute(sql.SQL("SHOW maintenance_work_mem;")).fetchall()) log.info(f"{self.name} parallel index creation parameters: {results}") def _create_index(self): @@ -322,24 +317,21 @@ def _create_index(self): self._set_parallel_index_build_param() options = [] for option in 
index_param["index_creation_with_options"]: - if option['val'] is not None: + if option["val"] is not None: options.append( sql.SQL("{option_name} = {val}").format( - option_name=sql.Identifier(option['option_name']), - val=sql.Identifier(str(option['val'])), - ) + option_name=sql.Identifier(option["option_name"]), + val=sql.Identifier(str(option["val"])), + ), ) - if any(options): - with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) - else: - with_clause = sql.Composed(()) + with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) - if index_param["quantization_type"] != None: + if index_param["quantization_type"] is not None: index_create_sql = sql.SQL( """ CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (({column_name}::{quantization_type}({dim})) {embedding_metric}) - """ + """, ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), @@ -357,9 +349,9 @@ def _create_index(self): else: index_create_sql = sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric}) - """ + """, ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), @@ -367,9 +359,7 @@ def _create_index(self): embedding_metric=sql.Identifier(index_param["metric"]), ) - index_create_sql_with_with_clause = ( - index_create_sql + with_clause - ).join(" ") + index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ") log.debug(index_create_sql_with_with_clause.as_string(self.cursor)) self.cursor.execute(index_create_sql_with_with_clause) self.conn.commit() @@ -384,19 +374,17 @@ def _create_table(self, dim: int): # create table self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" - ).format(table_name=sql.Identifier(self.table_name), dim=dim) + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", + ).format(table_name=sql.Identifier(self.table_name), dim=dim), ) self.cursor.execute( sql.SQL( - "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;" - ).format(table_name=sql.Identifier(self.table_name)) + "ALTER TABLE public.{table_name} ALTER COLUMN embedding SET STORAGE PLAIN;", + ).format(table_name=sql.Identifier(self.table_name)), ) self.conn.commit() except Exception as e: - log.warning( - f"Failed to create pgvector table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create pgvector table: {self.table_name} error: {e}") raise e from None def insert_embeddings( @@ -404,7 +392,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -414,8 +402,8 @@ def insert_embeddings( with self.cursor.copy( sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) as copy: copy.set_types(["bigint", "vector"]) for i, row in enumerate(metadata_arr): @@ -428,7 +416,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: log.warning( - f"Failed to insert data into pgvector 
table ({self.table_name}), error: {e}" + f"Failed to insert data into pgvector table ({self.table_name}), error: {e}", ) return 0, e @@ -449,21 +437,32 @@ def search_embedding( gt = filters.get("id") if index_param["quantization_type"] == "bit" and search_param["reranking"]: result = self.cursor.execute( - self._filtered_search, (q, gt, q, k), prepare=True, binary=True + self._filtered_search, + (q, gt, q, k), + prepare=True, + binary=True, ) else: result = self.cursor.execute( - self._filtered_search, (gt, q, k), prepare=True, binary=True + self._filtered_search, + (gt, q, k), + prepare=True, + binary=True, ) - + + elif index_param["quantization_type"] == "bit" and search_param["reranking"]: + result = self.cursor.execute( + self._unfiltered_search, + (q, q, k), + prepare=True, + binary=True, + ) else: - if index_param["quantization_type"] == "bit" and search_param["reranking"]: - result = self.cursor.execute( - self._unfiltered_search, (q, q, k), prepare=True, binary=True - ) - else: - result = self.cursor.execute( - self._unfiltered_search, (q, k), prepare=True, binary=True - ) + result = self.cursor.execute( + self._unfiltered_search, + (q, k), + prepare=True, + binary=True, + ) return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/backend/clients/pgvectorscale/cli.py b/vectordb_bench/backend/clients/pgvectorscale/cli.py index e5a161c6b..fca9d51d8 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/cli.py +++ b/vectordb_bench/backend/clients/pgvectorscale/cli.py @@ -1,80 +1,94 @@ -import click import os +from typing import Annotated, Unpack + +import click from pydantic import SecretStr +from vectordb_bench.backend.clients import DB + from ....cli.cli import ( CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run, ) -from typing import Annotated, Unpack -from vectordb_bench.backend.clients import DB class PgVectorScaleTypedDict(CommonTypedDict): user_name: Annotated[ - str, click.option("--user-name", type=str, help="Db username", required=True) + str, + click.option("--user-name", type=str, help="Db username", required=True), ] password: Annotated[ str, - click.option("--password", - type=str, - help="Postgres database password", - default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), - show_default="$POSTGRES_PASSWORD", - ), + click.option( + "--password", + type=str, + help="Postgres database password", + default=lambda: os.environ.get("POSTGRES_PASSWORD", ""), + show_default="$POSTGRES_PASSWORD", + ), ] - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] - db_name: Annotated[ - str, click.option("--db-name", type=str, help="Db name", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] + db_name: Annotated[str, click.option("--db-name", type=str, help="Db name", required=True)] class PgVectorScaleDiskAnnTypedDict(PgVectorScaleTypedDict): storage_layout: Annotated[ str, click.option( - "--storage-layout", type=str, help="Streaming DiskANN storage layout", + "--storage-layout", + type=str, + help="Streaming DiskANN storage layout", ), ] num_neighbors: Annotated[ int, click.option( - "--num-neighbors", type=int, help="Streaming DiskANN num neighbors", + "--num-neighbors", + type=int, + help="Streaming DiskANN num neighbors", ), ] search_list_size: Annotated[ int, click.option( - "--search-list-size", type=int, help="Streaming DiskANN search list size", + "--search-list-size", + type=int, + help="Streaming DiskANN search list size", ), ] 
max_alpha: Annotated[ float, click.option( - "--max-alpha", type=float, help="Streaming DiskANN max alpha", + "--max-alpha", + type=float, + help="Streaming DiskANN max alpha", ), ] num_dimensions: Annotated[ int, click.option( - "--num-dimensions", type=int, help="Streaming DiskANN num dimensions", + "--num-dimensions", + type=int, + help="Streaming DiskANN num dimensions", ), ] query_search_list_size: Annotated[ int, click.option( - "--query-search-list-size", type=int, help="Streaming DiskANN query search list size", + "--query-search-list-size", + type=int, + help="Streaming DiskANN query search list size", ), ] query_rescore: Annotated[ int, click.option( - "--query-rescore", type=int, help="Streaming DiskANN query rescore", + "--query-rescore", + type=int, + help="Streaming DiskANN query rescore", ), ] @@ -105,4 +119,4 @@ def PgVectorScaleDiskAnn( query_rescore=parameters["query_rescore"], ), **parameters, - ) \ No newline at end of file + ) diff --git a/vectordb_bench/backend/clients/pgvectorscale/config.py b/vectordb_bench/backend/clients/pgvectorscale/config.py index bd9f6106b..e22c45c8d 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/config.py +++ b/vectordb_bench/backend/clients/pgvectorscale/config.py @@ -1,7 +1,8 @@ from abc import abstractmethod -from typing import TypedDict +from typing import LiteralString, TypedDict + from pydantic import BaseModel, SecretStr -from typing_extensions import LiteralString + from ..api import DBCaseConfig, DBConfig, IndexType, MetricType POSTGRE_URL_PLACEHOLDER = "postgresql://%s:%s@%s/%s" @@ -9,7 +10,7 @@ class PgVectorScaleConfigDict(TypedDict): """These keys will be directly used as kwargs in psycopg connection string, - so the names must match exactly psycopg API""" + so the names must match exactly psycopg API""" user: str password: str @@ -46,7 +47,7 @@ def parse_metric(self) -> str: if self.metric_type == MetricType.COSINE: return "vector_cosine_ops" return "" - + def parse_metric_fun_op(self) -> LiteralString: if self.metric_type == MetricType.COSINE: return "<=>" @@ -56,19 +57,16 @@ def parse_metric_fun_str(self) -> str: if self.metric_type == MetricType.COSINE: return "cosine_distance" return "" - + @abstractmethod - def index_param(self) -> dict: - ... + def index_param(self) -> dict: ... @abstractmethod - def search_param(self) -> dict: - ... + def search_param(self) -> dict: ... @abstractmethod - def session_param(self) -> dict: - ... - + def session_param(self) -> dict: ... 
+ class PgVectorScaleStreamingDiskANNConfig(PgVectorScaleIndexConfig): index: IndexType = IndexType.STREAMING_DISKANN @@ -93,19 +91,20 @@ def index_param(self) -> dict: "num_dimensions": self.num_dimensions, }, } - + def search_param(self) -> dict: return { "metric": self.parse_metric(), "metric_fun_op": self.parse_metric_fun_op(), } - + def session_param(self) -> dict: return { "diskann.query_search_list_size": self.query_search_list_size, "diskann.query_rescore": self.query_rescore, } - + + _pgvectorscale_case_config = { IndexType.STREAMING_DISKANN: PgVectorScaleStreamingDiskANNConfig, } diff --git a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py index d8f26394c..981accc2e 100644 --- a/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py +++ b/vectordb_bench/backend/clients/pgvectorscale/pgvectorscale.py @@ -1,9 +1,9 @@ """Wrapper around the Pgvectorscale vector database over VectorDB""" import logging -import pprint +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple +from typing import Any import numpy as np import psycopg @@ -44,20 +44,21 @@ def __init__( self._primary_field = "id" self._vector_field = "embedding" - self.conn, self.cursor = self._create_connection(**self.db_config) + self.conn, self.cursor = self._create_connection(**self.db_config) log.info(f"{self.name} config values: {self.db_config}\n{self.case_config}") if not any( ( self.case_config.create_index_before_load, self.case_config.create_index_after_load, - ) + ), ): - err = f"{self.name} config must create an index using create_index_before_load or create_index_after_load" - log.error(err) - raise RuntimeError( - f"{err}\n{pprint.pformat(self.db_config)}\n{pprint.pformat(self.case_config)}" + msg = ( + f"{self.name} config must create an index using create_index_before_load or create_index_after_load" + f"{self.name} config values: {self.db_config}\n{self.case_config}" ) + log.error(msg) + raise RuntimeError(msg) if drop_old: self._drop_index() @@ -72,7 +73,7 @@ def __init__( self.conn = None @staticmethod - def _create_connection(**kwargs) -> Tuple[Connection, Cursor]: + def _create_connection(**kwargs) -> tuple[Connection, Cursor]: conn = psycopg.connect(**kwargs) conn.cursor().execute("CREATE EXTENSION IF NOT EXISTS vectorscale CASCADE") conn.commit() @@ -101,25 +102,25 @@ def init(self) -> Generator[None, None, None]: log.debug(command.as_string(self.cursor)) self.cursor.execute(command) self.conn.commit() - + self._filtered_search = sql.Composed( [ sql.SQL("SELECT id FROM public.{} WHERE id >= %s ORDER BY embedding ").format( sql.Identifier(self.table_name), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), - sql.SQL(" %s::vector LIMIT %s::int") - ] + sql.SQL(" %s::vector LIMIT %s::int"), + ], ) - + self._unfiltered_search = sql.Composed( [ sql.SQL("SELECT id FROM public.{} ORDER BY embedding ").format( - sql.Identifier(self.table_name) + sql.Identifier(self.table_name), ), sql.SQL(self.case_config.search_param()["metric_fun_op"]), sql.SQL(" %s::vector LIMIT %s::int"), - ] + ], ) try: @@ -137,8 +138,8 @@ def _drop_table(self): self.cursor.execute( sql.SQL("DROP TABLE IF EXISTS public.{table_name}").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) self.conn.commit() @@ -160,7 +161,7 @@ def _drop_index(self): log.info(f"{self.name} client drop index : {self._index_name}") 
drop_index_sql = sql.SQL("DROP INDEX IF EXISTS {index_name}").format( - index_name=sql.Identifier(self._index_name) + index_name=sql.Identifier(self._index_name), ) log.debug(drop_index_sql.as_string(self.cursor)) self.cursor.execute(drop_index_sql) @@ -180,36 +181,31 @@ def _create_index(self): sql.SQL("{option_name} = {val}").format( option_name=sql.Identifier(option_name), val=sql.Identifier(str(option_val)), - ) + ), ) - + num_bits_per_dimension = "2" if self.dim < 900 else "1" options.append( sql.SQL("{option_name} = {val}").format( option_name=sql.Identifier("num_bits_per_dimension"), val=sql.Identifier(num_bits_per_dimension), - ) + ), ) - if any(options): - with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) - else: - with_clause = sql.Composed(()) + with_clause = sql.SQL("WITH ({});").format(sql.SQL(", ").join(options)) if any(options) else sql.Composed(()) index_create_sql = sql.SQL( """ - CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} + CREATE INDEX IF NOT EXISTS {index_name} ON public.{table_name} USING {index_type} (embedding {embedding_metric}) - """ + """, ).format( index_name=sql.Identifier(self._index_name), table_name=sql.Identifier(self.table_name), index_type=sql.Identifier(index_param["index_type"].lower()), embedding_metric=sql.Identifier(index_param["metric"]), ) - index_create_sql_with_with_clause = ( - index_create_sql + with_clause - ).join(" ") + index_create_sql_with_with_clause = (index_create_sql + with_clause).join(" ") log.debug(index_create_sql_with_with_clause.as_string(self.cursor)) self.cursor.execute(index_create_sql_with_with_clause) self.conn.commit() @@ -223,14 +219,12 @@ def _create_table(self, dim: int): self.cursor.execute( sql.SQL( - "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));" - ).format(table_name=sql.Identifier(self.table_name), dim=dim) + "CREATE TABLE IF NOT EXISTS public.{table_name} (id BIGINT PRIMARY KEY, embedding vector({dim}));", + ).format(table_name=sql.Identifier(self.table_name), dim=dim), ) self.conn.commit() except Exception as e: - log.warning( - f"Failed to create pgvectorscale table: {self.table_name} error: {e}" - ) + log.warning(f"Failed to create pgvectorscale table: {self.table_name} error: {e}") raise e from None def insert_embeddings( @@ -238,7 +232,7 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: assert self.conn is not None, "Connection is not initialized" assert self.cursor is not None, "Cursor is not initialized" @@ -248,8 +242,8 @@ def insert_embeddings( with self.cursor.copy( sql.SQL("COPY public.{table_name} FROM STDIN (FORMAT BINARY)").format( - table_name=sql.Identifier(self.table_name) - ) + table_name=sql.Identifier(self.table_name), + ), ) as copy: copy.set_types(["bigint", "vector"]) for i, row in enumerate(metadata_arr): @@ -262,7 +256,7 @@ def insert_embeddings( return len(metadata), None except Exception as e: log.warning( - f"Failed to insert data into pgvector table ({self.table_name}), error: {e}" + f"Failed to insert data into pgvector table ({self.table_name}), error: {e}", ) return 0, e @@ -280,11 +274,12 @@ def search_embedding( if filters: gt = filters.get("id") result = self.cursor.execute( - self._filtered_search, (gt, q, k), prepare=True, binary=True + self._filtered_search, + (gt, q, k), + prepare=True, + binary=True, ) else: - result = self.cursor.execute( - self._unfiltered_search, 
(q, k), prepare=True, binary=True - ) + result = self.cursor.execute(self._unfiltered_search, (q, k), prepare=True, binary=True) return [int(i[0]) for i in result.fetchall()] diff --git a/vectordb_bench/backend/clients/pinecone/config.py b/vectordb_bench/backend/clients/pinecone/config.py index 2bbcbb350..fe1a039ed 100644 --- a/vectordb_bench/backend/clients/pinecone/config.py +++ b/vectordb_bench/backend/clients/pinecone/config.py @@ -1,4 +1,5 @@ from pydantic import SecretStr + from ..api import DBConfig diff --git a/vectordb_bench/backend/clients/pinecone/pinecone.py b/vectordb_bench/backend/clients/pinecone/pinecone.py index c1351f7a9..c59ee8760 100644 --- a/vectordb_bench/backend/clients/pinecone/pinecone.py +++ b/vectordb_bench/backend/clients/pinecone/pinecone.py @@ -2,11 +2,11 @@ import logging from contextlib import contextmanager -from typing import Type + import pinecone -from ..api import VectorDB, DBConfig, DBCaseConfig, EmptyDBCaseConfig, IndexType -from .config import PineconeConfig +from ..api import DBCaseConfig, DBConfig, EmptyDBCaseConfig, IndexType, VectorDB +from .config import PineconeConfig log = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class Pinecone(VectorDB): def __init__( self, - dim, + dim: int, db_config: dict, db_case_config: DBCaseConfig, drop_old: bool = False, @@ -27,7 +27,7 @@ def __init__( self.index_name = db_config.get("index_name", "") self.api_key = db_config.get("api_key", "") self.batch_size = int( - min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH) + min(PINECONE_MAX_SIZE_PER_BATCH / (dim * 5), PINECONE_MAX_NUM_PER_BATCH), ) pc = pinecone.Pinecone(api_key=self.api_key) @@ -37,9 +37,8 @@ def __init__( index_stats = index.describe_index_stats() index_dim = index_stats["dimension"] if index_dim != dim: - raise ValueError( - f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}" - ) + msg = f"Pinecone index {self.index_name} dimension mismatch, expected {index_dim} got {dim}" + raise ValueError(msg) for namespace in index_stats["namespaces"]: log.info(f"Pinecone index delete namespace: {namespace}") index.delete(delete_all=True, namespace=namespace) @@ -47,11 +46,11 @@ def __init__( self._metadata_key = "meta" @classmethod - def config_cls(cls) -> Type[DBConfig]: + def config_cls(cls) -> type[DBConfig]: return PineconeConfig @classmethod - def case_config_cls(cls, index_type: IndexType | None = None) -> Type[DBCaseConfig]: + def case_config_cls(cls, index_type: IndexType | None = None) -> type[DBCaseConfig]: return EmptyDBCaseConfig @contextmanager @@ -76,9 +75,7 @@ def insert_embeddings( insert_count = 0 try: for batch_start_offset in range(0, len(embeddings), self.batch_size): - batch_end_offset = min( - batch_start_offset + self.batch_size, len(embeddings) - ) + batch_end_offset = min(batch_start_offset + self.batch_size, len(embeddings)) insert_datas = [] for i in range(batch_start_offset, batch_end_offset): insert_data = ( @@ -100,10 +97,7 @@ def search_embedding( filters: dict | None = None, timeout: int | None = None, ) -> list[int]: - if filters is None: - pinecone_filters = {} - else: - pinecone_filters = {self._metadata_key: {"$gte": filters["id"]}} + pinecone_filters = {} if filters is None else {self._metadata_key: {"$gte": filters["id"]}} try: res = self.index.query( top_k=k, @@ -111,7 +105,6 @@ def search_embedding( filter=pinecone_filters, )["matches"] except Exception as e: - print(f"Error querying index: {e}") - raise e - id_res = [int(one_res["id"]) for one_res in res] - 
return id_res + log.warning(f"Error querying index: {e}") + raise e from e + return [int(one_res["id"]) for one_res in res] diff --git a/vectordb_bench/backend/clients/qdrant_cloud/config.py b/vectordb_bench/backend/clients/qdrant_cloud/config.py index 5b1dd7f18..c1d6882c0 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/config.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/config.py @@ -1,7 +1,7 @@ -from pydantic import BaseModel, SecretStr +from pydantic import BaseModel, SecretStr, validator + +from ..api import DBCaseConfig, DBConfig, MetricType -from ..api import DBConfig, DBCaseConfig, MetricType -from pydantic import validator # Allowing `api_key` to be left empty, to ensure compatibility with the open-source Qdrant. class QdrantConfig(DBConfig): @@ -16,17 +16,19 @@ def to_dict(self) -> dict: "api_key": self.api_key.get_secret_value(), "prefer_grpc": True, } - else: - return {"url": self.url.get_secret_value(),} - + return { + "url": self.url.get_secret_value(), + } + @validator("*") - def not_empty_field(cls, v, field): + def not_empty_field(cls, v: any, field: any): if field.name in ["api_key", "db_label"]: return v - if isinstance(v, (str, SecretStr)) and len(v) == 0: + if isinstance(v, str | SecretStr) and len(v) == 0: raise ValueError("Empty string!") return v + class QdrantIndexConfig(BaseModel, DBCaseConfig): metric_type: MetricType | None = None @@ -40,8 +42,7 @@ def parse_metric(self) -> str: return "Cosine" def index_param(self) -> dict: - params = {"distance": self.parse_metric()} - return params + return {"distance": self.parse_metric()} def search_param(self) -> dict: return {} diff --git a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py index a51632bc6..0861e8938 100644 --- a/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py +++ b/vectordb_bench/backend/clients/qdrant_cloud/qdrant_cloud.py @@ -4,23 +4,26 @@ import time from contextlib import contextmanager -from ..api import VectorDB, DBCaseConfig +from qdrant_client import QdrantClient from qdrant_client.http.models import ( - CollectionStatus, - VectorParams, - PayloadSchemaType, Batch, - Filter, + CollectionStatus, FieldCondition, + Filter, + PayloadSchemaType, Range, + VectorParams, ) -from qdrant_client import QdrantClient - +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) +SECONDS_WAITING_FOR_INDEXING_API_CALL = 5 +QDRANT_BATCH_SIZE = 500 + + class QdrantCloud(VectorDB): def __init__( self, @@ -57,16 +60,14 @@ def init(self) -> None: self.qdrant_client = QdrantClient(**self.db_config) yield self.qdrant_client = None - del(self.qdrant_client) + del self.qdrant_client def ready_to_load(self): pass - def optimize(self): assert self.qdrant_client, "Please call self.init() before" # wait for vectors to be fully indexed - SECONDS_WAITING_FOR_INDEXING_API_CALL = 5 try: while True: info = self.qdrant_client.get_collection(self.collection_name) @@ -74,19 +75,26 @@ def optimize(self): if info.status != CollectionStatus.GREEN: continue if info.status == CollectionStatus.GREEN: - log.info(f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, Collection status: {info.indexed_vectors_count}") + msg = ( + f"Stored vectors: {info.vectors_count}, Indexed vectors: {info.indexed_vectors_count}, ", + f"Collection status: {info.indexed_vectors_count}", + ) + log.info(msg) return except Exception as e: log.warning(f"QdrantCloud ready to search error: {e}") raise e from None - def 
_create_collection(self, dim, qdrant_client: int): + def _create_collection(self, dim: int, qdrant_client: QdrantClient): log.info(f"Create collection: {self.collection_name}") try: qdrant_client.create_collection( collection_name=self.collection_name, - vectors_config=VectorParams(size=dim, distance=self.case_config.index_param()["distance"]) + vectors_config=VectorParams( + size=dim, + distance=self.case_config.index_param()["distance"], + ), ) qdrant_client.create_payload_index( @@ -109,13 +117,12 @@ def insert_embeddings( ) -> (int, Exception): """Insert embeddings into Milvus. should call self.init() first""" assert self.qdrant_client is not None - QDRANT_BATCH_SIZE = 500 try: # TODO: counts for offset in range(0, len(embeddings), QDRANT_BATCH_SIZE): - vectors = embeddings[offset: offset + QDRANT_BATCH_SIZE] - ids = metadata[offset: offset + QDRANT_BATCH_SIZE] - payloads=[{self._primary_field: v} for v in ids] + vectors = embeddings[offset : offset + QDRANT_BATCH_SIZE] + ids = metadata[offset : offset + QDRANT_BATCH_SIZE] + payloads = [{self._primary_field: v} for v in ids] _ = self.qdrant_client.upsert( collection_name=self.collection_name, wait=True, @@ -142,21 +149,23 @@ def search_embedding( f = None if filters: f = Filter( - must=[FieldCondition( - key = self._primary_field, - range = Range( - gt=filters.get('id'), + must=[ + FieldCondition( + key=self._primary_field, + range=Range( + gt=filters.get("id"), + ), ), - )] + ], ) - res = self.qdrant_client.search( - collection_name=self.collection_name, - query_vector=query, - limit=k, - query_filter=f, - # with_payload=True, - ), + res = ( + self.qdrant_client.search( + collection_name=self.collection_name, + query_vector=query, + limit=k, + query_filter=f, + ), + ) - ret = [result.id for result in res[0]] - return ret + return [result.id for result in res[0]] diff --git a/vectordb_bench/backend/clients/redis/cli.py b/vectordb_bench/backend/clients/redis/cli.py index eb86b4c00..69277c00f 100644 --- a/vectordb_bench/backend/clients/redis/cli.py +++ b/vectordb_bench/backend/clients/redis/cli.py @@ -3,9 +3,6 @@ import click from pydantic import SecretStr -from .config import RedisHNSWConfig - - from ....cli.cli import ( CommonTypedDict, HNSWFlavor2, @@ -14,12 +11,11 @@ run, ) from .. import DB +from .config import RedisHNSWConfig class RedisTypedDict(TypedDict): - host: Annotated[ - str, click.option("--host", type=str, help="Db host", required=True) - ] + host: Annotated[str, click.option("--host", type=str, help="Db host", required=True)] password: Annotated[str, click.option("--password", type=str, help="Db password")] port: Annotated[int, click.option("--port", type=int, default=6379, help="Db Port")] ssl: Annotated[ @@ -52,27 +48,25 @@ class RedisTypedDict(TypedDict): ] -class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): - ... +class RedisHNSWTypedDict(CommonTypedDict, RedisTypedDict, HNSWFlavor2): ... 
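
For context, a minimal sketch (separate from the patch itself) of the redis-py search API that the Redis client reformatted below builds on: create an HNSW vector index, store one hash, and run a KNN query with EF_RUNTIME -- the same TagField/NumericField/VectorField, IndexDefinition, and Query calls that appear in redis/redis.py. The host, port, DIM, and HNSW parameter values here are placeholder assumptions, not values from the patch.

    import numpy as np
    import redis
    from redis.commands.search.field import NumericField, TagField, VectorField
    from redis.commands.search.indexDefinition import IndexDefinition, IndexType
    from redis.commands.search.query import Query

    DIM = 4  # assumed toy dimension
    conn = redis.Redis(host="localhost", port=6379, db=0)  # assumed local instance

    # Schema mirrors the fields used by the benchmark client: a tag id, a numeric
    # metadata field for range filters, and an HNSW vector field.
    schema = (
        TagField("id"),
        NumericField("metadata"),
        VectorField(
            "vector",
            "HNSW",
            {
                "TYPE": "FLOAT32",
                "DIM": DIM,
                "DISTANCE_METRIC": "COSINE",
                "M": 16,               # placeholder HNSW graph degree
                "EF_CONSTRUCTION": 64,  # placeholder build-time ef
            },
        ),
    )
    conn.ft("index").create_index(schema, definition=IndexDefinition(index_type=IndexType.HASH))

    # Insert one hash keyed by its id, with the vector stored as raw float32 bytes.
    emb = np.random.rand(DIM).astype(np.float32)
    conn.hset("1", mapping={"id": "1", "metadata": 1, "vector": emb.tobytes()})

    # KNN query with a runtime ef, sorted by distance score (dialect 2 syntax).
    query_obj = (
        Query("*=>[KNN 10 @vector $vec EF_RUNTIME 128 as score]")
        .sort_by("score")
        .return_fields("id", "score")
        .paging(0, 10)
        .dialect(2)
    )
    res = conn.ft("index").search(query_obj, {"vec": emb.tobytes()})
    print([int(doc["id"]) for doc in res.docs])
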
@cli.command() @click_parameter_decorators_from_typed_dict(RedisHNSWTypedDict) def Redis(**parameters: Unpack[RedisHNSWTypedDict]): from .config import RedisConfig + run( db=DB.Redis, db_config=RedisConfig( db_label=parameters["db_label"], - password=SecretStr(parameters["password"]) - if parameters["password"] - else None, + password=SecretStr(parameters["password"]) if parameters["password"] else None, host=SecretStr(parameters["host"]), port=parameters["port"], ssl=parameters["ssl"], ssl_ca_certs=parameters["ssl_ca_certs"], cmd=parameters["cmd"], - ), + ), db_case_config=RedisHNSWConfig( M=parameters["m"], efConstruction=parameters["ef_construction"], diff --git a/vectordb_bench/backend/clients/redis/config.py b/vectordb_bench/backend/clients/redis/config.py index 55a7a4159..db0127f84 100644 --- a/vectordb_bench/backend/clients/redis/config.py +++ b/vectordb_bench/backend/clients/redis/config.py @@ -1,10 +1,12 @@ -from pydantic import SecretStr, BaseModel -from ..api import DBConfig, DBCaseConfig, MetricType, IndexType +from pydantic import BaseModel, SecretStr + +from ..api import DBCaseConfig, DBConfig, IndexType, MetricType + class RedisConfig(DBConfig): password: SecretStr | None = None host: SecretStr - port: int | None = None + port: int | None = None def to_dict(self) -> dict: return { @@ -12,7 +14,6 @@ def to_dict(self) -> dict: "port": self.port, "password": self.password.get_secret_value() if self.password is not None else None, } - class RedisIndexConfig(BaseModel): @@ -24,7 +25,8 @@ def parse_metric(self) -> str: if not self.metric_type: return "" return self.metric_type.value - + + class RedisHNSWConfig(RedisIndexConfig, DBCaseConfig): M: int efConstruction: int diff --git a/vectordb_bench/backend/clients/redis/redis.py b/vectordb_bench/backend/clients/redis/redis.py index cf51fcc48..139850d2f 100644 --- a/vectordb_bench/backend/clients/redis/redis.py +++ b/vectordb_bench/backend/clients/redis/redis.py @@ -1,36 +1,40 @@ import logging from contextlib import contextmanager from typing import Any -from ..api import VectorDB, DBCaseConfig + +import numpy as np import redis -from redis.commands.search.field import TagField, VectorField, NumericField +from redis.commands.search.field import NumericField, TagField, VectorField from redis.commands.search.indexDefinition import IndexDefinition, IndexType from redis.commands.search.query import Query -import numpy as np +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) -INDEX_NAME = "index" # Vector Index Name +INDEX_NAME = "index" # Vector Index Name + class Redis(VectorDB): def __init__( - self, - dim: int, - db_config: dict, - db_case_config: DBCaseConfig, - drop_old: bool = False, - - **kwargs - ): - + self, + dim: int, + db_config: dict, + db_case_config: DBCaseConfig, + drop_old: bool = False, + **kwargs, + ): self.db_config = db_config self.case_config = db_case_config self.collection_name = INDEX_NAME # Create a redis connection, if db has password configured, add it to the connection here and in init(): - password=self.db_config["password"] - conn = redis.Redis(host=self.db_config["host"], port=self.db_config["port"], password=password, db=0) - + password = self.db_config["password"] + conn = redis.Redis( + host=self.db_config["host"], + port=self.db_config["port"], + password=password, + db=0, + ) if drop_old: try: @@ -39,7 +43,7 @@ def __init__( except redis.exceptions.ResponseError: drop_old = False log.info(f"Redis client drop_old collection: {self.collection_name}") - + self.make_index(dim, 
conn) conn.close() conn = None @@ -48,18 +52,20 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): try: # check to see if index exists conn.ft(INDEX_NAME).info() - except: + except Exception: schema = ( - TagField("id"), - NumericField("metadata"), - VectorField("vector", # Vector Field Name - "HNSW", { # Vector Index Type: FLAT or HNSW - "TYPE": "FLOAT32", # FLOAT32 or FLOAT64 - "DIM": vector_dimensions, # Number of Vector Dimensions - "DISTANCE_METRIC": "COSINE", # Vector Search Distance Metric + TagField("id"), + NumericField("metadata"), + VectorField( + "vector", # Vector Field Name + "HNSW", # Vector Index Type: FLAT or HNSW + { + "TYPE": "FLOAT32", # FLOAT32 or FLOAT64 + "DIM": vector_dimensions, # Number of Vector Dimensions + "DISTANCE_METRIC": "COSINE", # Vector Search Distance Metric "M": self.case_config.index_param()["params"]["M"], "EF_CONSTRUCTION": self.case_config.index_param()["params"]["efConstruction"], - } + }, ), ) @@ -69,23 +75,25 @@ def make_index(self, vector_dimensions: int, conn: redis.Redis): rs.create_index(schema, definition=definition) @contextmanager - def init(self): - """ create and destory connections to database. + def init(self) -> None: + """create and destory connections to database. Examples: >>> with self.init(): >>> self.insert_embeddings() """ - self.conn = redis.Redis(host=self.db_config["host"], port=self.db_config["port"], password=self.db_config["password"], db=0) + self.conn = redis.Redis( + host=self.db_config["host"], + port=self.db_config["port"], + password=self.db_config["password"], + db=0, + ) yield self.conn.close() self.conn = None - def ready_to_search(self) -> bool: """Check if the database is ready to search.""" - pass - def ready_to_load(self) -> bool: pass @@ -93,7 +101,6 @@ def ready_to_load(self) -> bool: def optimize(self) -> None: pass - def insert_embeddings( self, embeddings: list[list[float]], @@ -104,27 +111,30 @@ def insert_embeddings( Should call self.init() first. 
""" - batch_size = 1000 # Adjust this as needed, but don't make too big + batch_size = 1000 # Adjust this as needed, but don't make too big try: with self.conn.pipeline(transaction=False) as pipe: for i, embedding in enumerate(embeddings): - embedding = np.array(embedding).astype(np.float32) - pipe.hset(metadata[i], mapping = { - "id": str(metadata[i]), - "metadata": metadata[i], - "vector": embedding.tobytes(), - }) + ndarr_emb = np.array(embedding).astype(np.float32) + pipe.hset( + metadata[i], + mapping={ + "id": str(metadata[i]), + "metadata": metadata[i], + "vector": ndarr_emb.tobytes(), + }, + ) # Execute the pipe so we don't keep too much in memory at once if i % batch_size == 0: - res = pipe.execute() + _ = pipe.execute() - res = pipe.execute() + _ = pipe.execute() result_len = i + 1 except Exception as e: return 0, e - + return result_len, None - + def search_embedding( self, query: list[float], @@ -132,26 +142,53 @@ def search_embedding( filters: dict | None = None, timeout: int | None = None, **kwargs: Any, - ) -> (list[int]): + ) -> list[int]: assert self.conn is not None - + query_vector = np.array(query).astype(np.float32).tobytes() ef_runtime = self.case_config.search_param()["params"]["ef"] - query_obj = Query(f"*=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + query_obj = ( + Query(f"*=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]") + .sort_by("score") + .return_fields("id", "score") + .paging(0, k) + .dialect(2) + ) query_params = {"vec": query_vector} - + if filters: # benchmark test filters of format: {'metadata': '>=10000', 'id': 10000} # gets exact match for id, and range for metadata if they exist in filters id_value = filters.get("id") metadata_value = filters.get("metadata") if id_value and metadata_value: - query_obj = Query(f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + query_obj = ( + Query( + f"(@metadata:[{metadata_value} +inf] @id:{ {id_value} })=>[KNN {k} ", + f"@vector $vec EF_RUNTIME {ef_runtime} as score]", + ) + .sort_by("score") + .return_fields("id", "score") + .paging(0, k) + .dialect(2) + ) elif id_value: - #gets exact match for id - query_obj = Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) - else: #metadata only case, greater than or equal to metadata value - query_obj = Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]").sort_by("score").return_fields("id", "score").paging(0, k).dialect(2) + # gets exact match for id + query_obj = ( + Query(f"@id:{ {id_value} }=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]") + .sort_by("score") + .return_fields("id", "score") + .paging(0, k) + .dialect(2) + ) + else: # metadata only case, greater than or equal to metadata value + query_obj = ( + Query(f"@metadata:[{metadata_value} +inf]=>[KNN {k} @vector $vec EF_RUNTIME {ef_runtime} as score]") + .sort_by("score") + .return_fields("id", "score") + .paging(0, k) + .dialect(2) + ) res = self.conn.ft(INDEX_NAME).search(query_obj, query_params) # doc in res of format {'id': '9831', 'payload': None, 'score': '1.19209289551e-07'} return [int(doc["id"]) for doc in res.docs] diff --git a/vectordb_bench/backend/clients/test/cli.py b/vectordb_bench/backend/clients/test/cli.py 
index f06f33492..e5cd4c78b 100644 --- a/vectordb_bench/backend/clients/test/cli.py +++ b/vectordb_bench/backend/clients/test/cli.py @@ -10,8 +10,7 @@ from ..test.config import TestConfig, TestIndexConfig -class TestTypedDict(CommonTypedDict): - ... +class TestTypedDict(CommonTypedDict): ... @cli.command() diff --git a/vectordb_bench/backend/clients/test/config.py b/vectordb_bench/backend/clients/test/config.py index 01a77e000..351d7bcac 100644 --- a/vectordb_bench/backend/clients/test/config.py +++ b/vectordb_bench/backend/clients/test/config.py @@ -1,6 +1,6 @@ -from pydantic import BaseModel, SecretStr +from pydantic import BaseModel -from ..api import DBCaseConfig, DBConfig, IndexType, MetricType +from ..api import DBCaseConfig, DBConfig, MetricType class TestConfig(DBConfig): diff --git a/vectordb_bench/backend/clients/test/test.py b/vectordb_bench/backend/clients/test/test.py index 78732eb1e..ee5a523f3 100644 --- a/vectordb_bench/backend/clients/test/test.py +++ b/vectordb_bench/backend/clients/test/test.py @@ -1,6 +1,7 @@ import logging +from collections.abc import Generator from contextlib import contextmanager -from typing import Any, Generator, Optional, Tuple +from typing import Any from ..api import DBCaseConfig, VectorDB @@ -43,11 +44,10 @@ def insert_embeddings( embeddings: list[list[float]], metadata: list[int], **kwargs: Any, - ) -> Tuple[int, Optional[Exception]]: + ) -> tuple[int, Exception | None]: """Insert embeddings into the database. Should call self.init() first. """ - raise RuntimeError("Not implemented") return len(metadata), None def search_embedding( @@ -58,5 +58,4 @@ def search_embedding( timeout: int | None = None, **kwargs: Any, ) -> list[int]: - raise NotImplementedError - return [i for i in range(k)] + return list(range(k)) diff --git a/vectordb_bench/backend/clients/weaviate_cloud/cli.py b/vectordb_bench/backend/clients/weaviate_cloud/cli.py index b6f16011b..181898c74 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/cli.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/cli.py @@ -14,7 +14,8 @@ class WeaviateTypedDict(CommonTypedDict): api_key: Annotated[ - str, click.option("--api-key", type=str, help="Weaviate api key", required=True) + str, + click.option("--api-key", type=str, help="Weaviate api key", required=True), ] url: Annotated[ str, @@ -34,8 +35,6 @@ def Weaviate(**parameters: Unpack[WeaviateTypedDict]): api_key=SecretStr(parameters["api_key"]), url=SecretStr(parameters["url"]), ), - db_case_config=WeaviateIndexConfig( - ef=256, efConstruction=256, maxConnections=16 - ), + db_case_config=WeaviateIndexConfig(ef=256, efConstruction=256, maxConnections=16), **parameters, ) diff --git a/vectordb_bench/backend/clients/weaviate_cloud/config.py b/vectordb_bench/backend/clients/weaviate_cloud/config.py index 8b2f3ab81..4c58167d4 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/config.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/config.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, SecretStr -from ..api import DBConfig, DBCaseConfig, MetricType +from ..api import DBCaseConfig, DBConfig, MetricType class WeaviateConfig(DBConfig): @@ -23,7 +23,7 @@ class WeaviateIndexConfig(BaseModel, DBCaseConfig): def parse_metric(self) -> str: if self.metric_type == MetricType.L2: return "l2-squared" - elif self.metric_type == MetricType.IP: + if self.metric_type == MetricType.IP: return "dot" return "cosine" diff --git a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py 
b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py index d8fde5f09..b42f70af1 100644 --- a/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py +++ b/vectordb_bench/backend/clients/weaviate_cloud/weaviate_cloud.py @@ -1,13 +1,13 @@ """Wrapper around the Weaviate vector database over VectorDB""" import logging -from typing import Iterable +from collections.abc import Iterable from contextlib import contextmanager import weaviate from weaviate.exceptions import WeaviateBaseError -from ..api import VectorDB, DBCaseConfig +from ..api import DBCaseConfig, VectorDB log = logging.getLogger(__name__) @@ -23,7 +23,13 @@ def __init__( **kwargs, ): """Initialize wrapper around the weaviate vector database.""" - db_config.update({"auth_client_secret": weaviate.AuthApiKey(api_key=db_config.get("auth_client_secret"))}) + db_config.update( + { + "auth_client_secret": weaviate.AuthApiKey( + api_key=db_config.get("auth_client_secret"), + ), + }, + ) self.db_config = db_config self.case_config = db_case_config self.collection_name = collection_name @@ -33,6 +39,7 @@ def __init__( self._index_name = "vector_idx" from weaviate import Client + client = Client(**db_config) if drop_old: try: @@ -40,7 +47,7 @@ def __init__( log.info(f"weaviate client drop_old collection: {self.collection_name}") client.schema.delete_class(self.collection_name) except WeaviateBaseError as e: - log.warning(f"Failed to drop collection: {self.collection_name} error: {str(e)}") + log.warning(f"Failed to drop collection: {self.collection_name} error: {e!s}") raise e from None self._create_collection(client) client = None @@ -54,20 +61,23 @@ def init(self) -> None: >>> self.search_embedding() """ from weaviate import Client + self.client = Client(**self.db_config) yield self.client = None - del(self.client) + del self.client def ready_to_load(self): """Should call insert first, do nothing""" - pass def optimize(self): assert self.client.schema.exists(self.collection_name) - self.client.schema.update_config(self.collection_name, {"vectorIndexConfig": self.case_config.search_param() } ) + self.client.schema.update_config( + self.collection_name, + {"vectorIndexConfig": self.case_config.search_param()}, + ) - def _create_collection(self, client): + def _create_collection(self, client: weaviate.Client) -> None: if not client.schema.exists(self.collection_name): log.info(f"Create collection: {self.collection_name}") class_obj = { @@ -78,13 +88,13 @@ def _create_collection(self, client): "dataType": ["int"], "name": self._scalar_field, }, - ] + ], } class_obj["vectorIndexConfig"] = self.case_config.index_param() try: client.schema.create_class(class_obj) except WeaviateBaseError as e: - log.warning(f"Failed to create collection: {self.collection_name} error: {str(e)}") + log.warning(f"Failed to create collection: {self.collection_name} error: {e!s}") raise e from None def insert_embeddings( @@ -102,15 +112,17 @@ def insert_embeddings( batch.dynamic = True res = [] for i in range(len(metadata)): - res.append(batch.add_data_object( - {self._scalar_field: metadata[i]}, - class_name=self.collection_name, - vector=embeddings[i] - )) + res.append( + batch.add_data_object( + {self._scalar_field: metadata[i]}, + class_name=self.collection_name, + vector=embeddings[i], + ), + ) insert_count += 1 return (len(res), None) except WeaviateBaseError as e: - log.warning(f"Failed to insert data, error: {str(e)}") + log.warning(f"Failed to insert data, error: {e!s}") return (insert_count, e) def search_embedding( @@ -125,12 +137,17 @@ 
def search_embedding( """ assert self.client.schema.exists(self.collection_name) - query_obj = self.client.query.get(self.collection_name, [self._scalar_field]).with_additional("distance").with_near_vector({"vector": query}).with_limit(k) + query_obj = ( + self.client.query.get(self.collection_name, [self._scalar_field]) + .with_additional("distance") + .with_near_vector({"vector": query}) + .with_limit(k) + ) if filters: where_filter = { "path": "key", "operator": "GreaterThanEqual", - "valueInt": filters.get('id') + "valueInt": filters.get("id"), } query_obj = query_obj.with_where(where_filter) @@ -138,7 +155,4 @@ def search_embedding( res = query_obj.do() # Organize results. - ret = [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]] - - return ret - + return [result[self._scalar_field] for result in res["data"]["Get"][self.collection_name]] diff --git a/vectordb_bench/backend/clients/zilliz_cloud/cli.py b/vectordb_bench/backend/clients/zilliz_cloud/cli.py index 31618f4ec..810a4175e 100644 --- a/vectordb_bench/backend/clients/zilliz_cloud/cli.py +++ b/vectordb_bench/backend/clients/zilliz_cloud/cli.py @@ -1,33 +1,36 @@ +import os from typing import Annotated, Unpack import click -import os from pydantic import SecretStr +from vectordb_bench.backend.clients import DB from vectordb_bench.cli.cli import ( CommonTypedDict, cli, click_parameter_decorators_from_typed_dict, run, ) -from vectordb_bench.backend.clients import DB class ZillizTypedDict(CommonTypedDict): uri: Annotated[ - str, click.option("--uri", type=str, help="uri connection string", required=True) + str, + click.option("--uri", type=str, help="uri connection string", required=True), ] user_name: Annotated[ - str, click.option("--user-name", type=str, help="Db username", required=True) + str, + click.option("--user-name", type=str, help="Db username", required=True), ] password: Annotated[ str, - click.option("--password", - type=str, - help="Zilliz password", - default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""), - show_default="$ZILLIZ_PASSWORD", - ), + click.option( + "--password", + type=str, + help="Zilliz password", + default=lambda: os.environ.get("ZILLIZ_PASSWORD", ""), + show_default="$ZILLIZ_PASSWORD", + ), ] level: Annotated[ str, @@ -38,7 +41,7 @@ class ZillizTypedDict(CommonTypedDict): @cli.command() @click_parameter_decorators_from_typed_dict(ZillizTypedDict) def ZillizAutoIndex(**parameters: Unpack[ZillizTypedDict]): - from .config import ZillizCloudConfig, AutoIndexConfig + from .config import AutoIndexConfig, ZillizCloudConfig run( db=DB.ZillizCloud, diff --git a/vectordb_bench/backend/clients/zilliz_cloud/config.py b/vectordb_bench/backend/clients/zilliz_cloud/config.py index ee60b397f..9f113dbda 100644 --- a/vectordb_bench/backend/clients/zilliz_cloud/config.py +++ b/vectordb_bench/backend/clients/zilliz_cloud/config.py @@ -1,7 +1,7 @@ from pydantic import SecretStr from ..api import DBCaseConfig, DBConfig -from ..milvus.config import MilvusIndexConfig, IndexType +from ..milvus.config import IndexType, MilvusIndexConfig class ZillizCloudConfig(DBConfig): @@ -33,7 +33,5 @@ def search_param(self) -> dict: "metric_type": self.parse_metric(), "params": { "level": self.level, - } + }, } - - diff --git a/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py b/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py index 36f7fb204..4ce15545a 100644 --- a/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py +++ 
b/vectordb_bench/backend/clients/zilliz_cloud/zilliz_cloud.py @@ -1,7 +1,7 @@ """Wrapper around the ZillizCloud vector database over VectorDB""" -from ..milvus.milvus import Milvus from ..api import DBCaseConfig +from ..milvus.milvus import Milvus class ZillizCloud(Milvus): diff --git a/vectordb_bench/backend/data_source.py b/vectordb_bench/backend/data_source.py index 9e2f172b4..b98dc7d7a 100644 --- a/vectordb_bench/backend/data_source.py +++ b/vectordb_bench/backend/data_source.py @@ -1,12 +1,12 @@ import logging import pathlib import typing +from abc import ABC, abstractmethod from enum import Enum + from tqdm import tqdm -import os -from abc import ABC, abstractmethod -from .. import config +from vectordb_bench import config logging.getLogger("s3fs").setLevel(logging.CRITICAL) @@ -14,6 +14,7 @@ DatasetReader = typing.TypeVar("DatasetReader") + class DatasetSource(Enum): S3 = "S3" AliyunOSS = "AliyunOSS" @@ -25,6 +26,8 @@ def reader(self) -> DatasetReader: if self == DatasetSource.AliyunOSS: return AliyunOSSReader() + return None + class DatasetReader(ABC): source: DatasetSource @@ -39,7 +42,6 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): files(list[str]): all filenames of the dataset local_ds_root(pathlib.Path): whether to write the remote data. """ - pass @abstractmethod def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool: @@ -52,15 +54,18 @@ class AliyunOSSReader(DatasetReader): def __init__(self): import oss2 + self.bucket = oss2.Bucket(oss2.AnonymousAuth(), self.remote_root, "benchmark", True) def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool: info = self.bucket.get_object_meta(remote.as_posix()) # check size equal - remote_size, local_size = info.content_length, os.path.getsize(local) + remote_size, local_size = info.content_length, local.stat().st_size if remote_size != local_size: - log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]") + log.info( + f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]", + ) return False return True @@ -70,7 +75,13 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): if not local_ds_root.exists(): log.info(f"local dataset root path not exist, creating it: {local_ds_root}") local_ds_root.mkdir(parents=True) - downloads = [(pathlib.PurePosixPath("benchmark", dataset, f), local_ds_root.joinpath(f)) for f in files] + downloads = [ + ( + pathlib.PurePosixPath("benchmark", dataset, f), + local_ds_root.joinpath(f), + ) + for f in files + ] else: for file in files: @@ -78,7 +89,9 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): local_file = local_ds_root.joinpath(file) if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)): - log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list") + log.info( + f"local file: {local_file} not match with remote: {remote_file}; add to downloading list", + ) downloads.append((remote_file, local_file)) if len(downloads) == 0: @@ -92,17 +105,14 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}") - class AwsS3Reader(DatasetReader): source: DatasetSource = DatasetSource.S3 remote_root: str = config.AWS_S3_URL def __init__(self): import s3fs - self.fs = s3fs.S3FileSystem( - anon=True, - client_kwargs={'region_name': 'us-west-2'} - ) + + 
self.fs = s3fs.S3FileSystem(anon=True, client_kwargs={"region_name": "us-west-2"}) def ls_all(self, dataset: str): dataset_root_dir = pathlib.Path(self.remote_root, dataset) @@ -112,7 +122,6 @@ def ls_all(self, dataset: str): log.info(n) return names - def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): downloads = [] if not local_ds_root.exists(): @@ -126,7 +135,9 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): local_file = local_ds_root.joinpath(file) if (not local_file.exists()) or (not self.validate_file(remote_file, local_file)): - log.info(f"local file: {local_file} not match with remote: {remote_file}; add to downloading list") + log.info( + f"local file: {local_file} not match with remote: {remote_file}; add to downloading list", + ) downloads.append(remote_file) if len(downloads) == 0: @@ -139,15 +150,16 @@ def read(self, dataset: str, files: list[str], local_ds_root: pathlib.Path): log.info(f"Succeed to download all files, downloaded file count = {len(downloads)}") - def validate_file(self, remote: pathlib.Path, local: pathlib.Path) -> bool: # info() uses ls() inside, maybe we only need to ls once info = self.fs.info(remote) # check size equal - remote_size, local_size = info.get("size"), os.path.getsize(local) + remote_size, local_size = info.get("size"), local.stat().st_size if remote_size != local_size: - log.info(f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]") + log.info( + f"local file: {local} size[{local_size}] not match with remote size[{remote_size}]", + ) return False return True diff --git a/vectordb_bench/backend/dataset.py b/vectordb_bench/backend/dataset.py index d62d96684..62700b0fa 100644 --- a/vectordb_bench/backend/dataset.py +++ b/vectordb_bench/backend/dataset.py @@ -4,25 +4,30 @@ >>> Dataset.Cohere.get(100_000) """ -from collections import namedtuple import logging import pathlib +import typing from enum import Enum + import pandas as pd -from pydantic import validator, PrivateAttr import polars as pl from pyarrow.parquet import ParquetFile +from pydantic import PrivateAttr, validator + +from vectordb_bench import config +from vectordb_bench.base import BaseModel -from ..base import BaseModel -from .. import config -from ..backend.clients import MetricType from . 
import utils -from .data_source import DatasetSource, DatasetReader +from .clients import MetricType +from .data_source import DatasetReader, DatasetSource log = logging.getLogger(__name__) -SizeLabel = namedtuple('SizeLabel', ['size', 'label', 'file_count']) +class SizeLabel(typing.NamedTuple): + size: int + label: str + file_count: int class BaseDataset(BaseModel): @@ -33,12 +38,13 @@ class BaseDataset(BaseModel): use_shuffled: bool with_gt: bool = False _size_label: dict[int, SizeLabel] = PrivateAttr() - isCustom: bool = False + is_custom: bool = False @validator("size") - def verify_size(cls, v): + def verify_size(cls, v: int): if v not in cls._size_label: - raise ValueError(f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}") + msg = f"Size {v} not supported for the dataset, expected: {cls._size_label.keys()}" + raise ValueError(msg) return v @property @@ -53,13 +59,14 @@ def dir_name(self) -> str: def file_count(self) -> int: return self._size_label.get(self.size).file_count + class CustomDataset(BaseDataset): dir: str file_num: int - isCustom: bool = True + is_custom: bool = True @validator("size") - def verify_size(cls, v): + def verify_size(cls, v: int): return v @property @@ -102,7 +109,7 @@ class Cohere(BaseDataset): dim: int = 768 metric_type: MetricType = MetricType.COSINE use_shuffled: bool = config.USE_SHUFFLED_DATA - with_gt: bool = True, + with_gt: bool = (True,) _size_label: dict = { 100_000: SizeLabel(100_000, "SMALL", 1), 1_000_000: SizeLabel(1_000_000, "MEDIUM", 1), @@ -124,7 +131,11 @@ class SIFT(BaseDataset): metric_type: MetricType = MetricType.L2 use_shuffled: bool = False _size_label: dict = { - 500_000: SizeLabel(500_000, "SMALL", 1,), + 500_000: SizeLabel( + 500_000, + "SMALL", + 1, + ), 5_000_000: SizeLabel(5_000_000, "MEDIUM", 1), # 50_000_000: SizeLabel(50_000_000, "LARGE", 50), } @@ -135,7 +146,7 @@ class OpenAI(BaseDataset): dim: int = 1536 metric_type: MetricType = MetricType.COSINE use_shuffled: bool = config.USE_SHUFFLED_DATA - with_gt: bool = True, + with_gt: bool = (True,) _size_label: dict = { 50_000: SizeLabel(50_000, "SMALL", 1), 500_000: SizeLabel(500_000, "MEDIUM", 1), @@ -153,13 +164,14 @@ class DatasetManager(BaseModel): >>> for data in cohere: >>> print(data.columns) """ - data: BaseDataset + + data: BaseDataset test_data: pd.DataFrame | None = None gt_data: pd.DataFrame | None = None - train_files : list[str] = [] + train_files: list[str] = [] reader: DatasetReader | None = None - def __eq__(self, obj): + def __eq__(self, obj: any): if isinstance(obj, DatasetManager): return self.data.name == obj.data.name and self.data.label == obj.data.label return False @@ -169,22 +181,27 @@ def set_reader(self, reader: DatasetReader): @property def data_dir(self) -> pathlib.Path: - """ data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname} + """data local directory: config.DATASET_LOCAL_DIR/{dataset_name}/{dataset_dirname} Examples: >>> sift_s = Dataset.SIFT.manager(500_000) >>> sift_s.relative_path '/tmp/vectordb_bench/dataset/sift/sift_small_500k/' """ - return pathlib.Path(config.DATASET_LOCAL_DIR, self.data.name.lower(), self.data.dir_name.lower()) + return pathlib.Path( + config.DATASET_LOCAL_DIR, + self.data.name.lower(), + self.data.dir_name.lower(), + ) def __iter__(self): return DataSetIterator(self) # TODO passing use_shuffle from outside - def prepare(self, - source: DatasetSource=DatasetSource.S3, - filters: int | float | str | None = None, + def prepare( + self, + source: DatasetSource = 
DatasetSource.S3, + filters: float | str | None = None, ) -> bool: """Download the dataset from DatasetSource url = f"{source}/{self.data.dir_name}" @@ -208,7 +225,7 @@ def prepare(self, gt_file, test_file = utils.compose_gt_file(filters), "test.parquet" all_files.extend([gt_file, test_file]) - if not self.data.isCustom: + if not self.data.is_custom: source.reader().read( dataset=self.data.dir_name.lower(), files=all_files, @@ -220,7 +237,7 @@ def prepare(self, self.gt_data = self._read_file(gt_file) prefix = "shuffle_train" if use_shuffled else "train" - self.train_files = sorted([f.name for f in self.data_dir.glob(f'{prefix}*.parquet')]) + self.train_files = sorted([f.name for f in self.data_dir.glob(f"{prefix}*.parquet")]) log.debug(f"{self.data.name}: available train files {self.train_files}") return True @@ -241,7 +258,7 @@ def __init__(self, dataset: DatasetManager): self._ds = dataset self._idx = 0 # file number self._cur = None - self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file + self._sub_idx = [0 for i in range(len(self._ds.train_files))] # iter num for each file def __iter__(self): return self @@ -250,7 +267,9 @@ def _get_iter(self, file_name: str): p = pathlib.Path(self._ds.data_dir, file_name) log.info(f"Get iterator for {p.name}") if not p.exists(): - raise IndexError(f"No such file {p}") + msg = f"No such file: {p}" + log.warning(msg) + raise IndexError(msg) return ParquetFile(p, memory_map=True, pre_buffer=True).iter_batches(config.NUM_PER_BATCH) def __next__(self) -> pd.DataFrame: @@ -281,6 +300,7 @@ class Dataset(Enum): >>> Dataset.COHERE.manager(100_000) >>> Dataset.COHERE.get(100_000) """ + LAION = LAION GIST = GIST COHERE = Cohere diff --git a/vectordb_bench/backend/result_collector.py b/vectordb_bench/backend/result_collector.py index d4c073c1a..a4119a5fd 100644 --- a/vectordb_bench/backend/result_collector.py +++ b/vectordb_bench/backend/result_collector.py @@ -1,7 +1,7 @@ +import logging import pathlib -from ..models import TestResult -import logging +from vectordb_bench.models import TestResult log = logging.getLogger(__name__) @@ -14,7 +14,6 @@ def collect(cls, result_dir: pathlib.Path) -> list[TestResult]: if not result_dir.exists() or len(list(result_dir.rglob(reg))) == 0: return [] - for json_file in result_dir.rglob(reg): file_result = TestResult.read_file(json_file, trans_unit=True) diff --git a/vectordb_bench/backend/runner/__init__.py b/vectordb_bench/backend/runner/__init__.py index 77bb25d67..b83df6f99 100644 --- a/vectordb_bench/backend/runner/__init__.py +++ b/vectordb_bench/backend/runner/__init__.py @@ -1,12 +1,10 @@ from .mp_runner import ( MultiProcessingSearchRunner, ) - -from .serial_runner import SerialSearchRunner, SerialInsertRunner - +from .serial_runner import SerialInsertRunner, SerialSearchRunner __all__ = [ - 'MultiProcessingSearchRunner', - 'SerialSearchRunner', - 'SerialInsertRunner', + "MultiProcessingSearchRunner", + "SerialInsertRunner", + "SerialSearchRunner", ] diff --git a/vectordb_bench/backend/runner/mp_runner.py b/vectordb_bench/backend/runner/mp_runner.py index 8f35dcd8d..5b69b5481 100644 --- a/vectordb_bench/backend/runner/mp_runner.py +++ b/vectordb_bench/backend/runner/mp_runner.py @@ -1,27 +1,29 @@ -import time -import traceback import concurrent +import logging import multiprocessing as mp import random -import logging -from typing import Iterable +import time +import traceback +from collections.abc import Iterable + import numpy as np -from ..clients import api -from ... 
import config +from ... import config +from ..clients import api NUM_PER_BATCH = config.NUM_PER_BATCH log = logging.getLogger(__name__) class MultiProcessingSearchRunner: - """ multiprocessing search runner + """multiprocessing search runner Args: k(int): search topk, default to 100 concurrency(Iterable): concurrencies, default [1, 5, 10, 15, 20, 25, 30, 35] duration(int): duration for each concurency, default to 30s """ + def __init__( self, db: api.VectorDB, @@ -40,7 +42,12 @@ def __init__( self.test_data = test_data log.debug(f"test dataset columns: {len(test_data)}") - def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> tuple[int, float]: + def search( + self, + test_data: list[list[float]], + q: mp.Queue, + cond: mp.Condition, + ) -> tuple[int, float]: # sync all process q.put(1) with cond: @@ -71,24 +78,27 @@ def search(self, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) idx = idx + 1 if idx < num - 1 else 0 if count % 500 == 0: - log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}") + log.debug( + f"({mp.current_process().name:16}) ", + f"search_count: {count}, latest_latency={time.perf_counter()-s}", + ) total_dur = round(time.perf_counter() - start_time, 4) log.info( f"{mp.current_process().name:16} search {self.duration}s: " - f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}" - ) + f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}", + ) return (count, total_dur, latencies) @staticmethod def get_mp_context(): mp_start_method = "spawn" - log.debug(f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}") + log.debug( + f"MultiProcessingSearchRunner get multiprocessing start method: {mp_start_method}", + ) return mp.get_context(mp_start_method) - - def _run_all_concurrencies_mem_efficient(self): max_qps = 0 conc_num_list = [] @@ -99,8 +109,13 @@ def _run_all_concurrencies_mem_efficient(self): for conc in self.concurrencies: with mp.Manager() as m: q, cond = m.Queue(), m.Condition() - with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor: - log.info(f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}") + with concurrent.futures.ProcessPoolExecutor( + mp_context=self.get_mp_context(), + max_workers=conc, + ) as executor: + log.info( + f"Start search {self.duration}s in concurrency {conc}, filters: {self.filters}", + ) future_iter = [executor.submit(self.search, self.test_data, q, cond) for i in range(conc)] # Sync all processes while q.qsize() < conc: @@ -109,7 +124,9 @@ def _run_all_concurrencies_mem_efficient(self): with cond: cond.notify_all() - log.info(f"Syncing all process and start concurrency search, concurrency={conc}") + log.info( + f"Syncing all process and start concurrency search, concurrency={conc}", + ) start = time.perf_counter() all_count = sum([r.result()[0] for r in future_iter]) @@ -123,13 +140,19 @@ def _run_all_concurrencies_mem_efficient(self): conc_qps_list.append(qps) conc_latency_p99_list.append(latency_p99) conc_latency_avg_list.append(latency_avg) - log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}") + log.info( + f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}", + ) if qps > max_qps: max_qps = qps - log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}") + 
log.info( + f"Update largest qps with concurrency {conc}: current max_qps={max_qps}", + ) except Exception as e: - log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}") + log.warning( + f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}", + ) traceback.print_exc() # No results available, raise exception @@ -139,7 +162,13 @@ def _run_all_concurrencies_mem_efficient(self): finally: self.stop() - return max_qps, conc_num_list, conc_qps_list, conc_latency_p99_list, conc_latency_avg_list + return ( + max_qps, + conc_num_list, + conc_qps_list, + conc_latency_p99_list, + conc_latency_avg_list, + ) def run(self) -> float: """ @@ -160,9 +189,16 @@ def _run_by_dur(self, duration: int) -> float: for conc in self.concurrencies: with mp.Manager() as m: q, cond = m.Queue(), m.Condition() - with concurrent.futures.ProcessPoolExecutor(mp_context=self.get_mp_context(), max_workers=conc) as executor: - log.info(f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}") - future_iter = [executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc)] + with concurrent.futures.ProcessPoolExecutor( + mp_context=self.get_mp_context(), + max_workers=conc, + ) as executor: + log.info( + f"Start search_by_dur {duration}s in concurrency {conc}, filters: {self.filters}", + ) + future_iter = [ + executor.submit(self.search_by_dur, duration, self.test_data, q, cond) for i in range(conc) + ] # Sync all processes while q.qsize() < conc: sleep_t = conc if conc < 10 else 10 @@ -170,20 +206,28 @@ def _run_by_dur(self, duration: int) -> float: with cond: cond.notify_all() - log.info(f"Syncing all process and start concurrency search, concurrency={conc}") + log.info( + f"Syncing all process and start concurrency search, concurrency={conc}", + ) start = time.perf_counter() all_count = sum([r.result() for r in future_iter]) cost = time.perf_counter() - start qps = round(all_count / cost, 4) - log.info(f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}") + log.info( + f"End search in concurrency {conc}: dur={cost}s, total_count={all_count}, qps={qps}", + ) if qps > max_qps: max_qps = qps - log.info(f"Update largest qps with concurrency {conc}: current max_qps={max_qps}") + log.info( + f"Update largest qps with concurrency {conc}: current max_qps={max_qps}", + ) except Exception as e: - log.warning(f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}") + log.warning( + f"Fail to search all concurrencies: {self.concurrencies}, max_qps before failure={max_qps}, reason={e}", + ) traceback.print_exc() # No results available, raise exception @@ -195,8 +239,13 @@ def _run_by_dur(self, duration: int) -> float: return max_qps - - def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, cond: mp.Condition) -> int: + def search_by_dur( + self, + dur: int, + test_data: list[list[float]], + q: mp.Queue, + cond: mp.Condition, + ) -> int: # sync all process q.put(1) with cond: @@ -225,13 +274,15 @@ def search_by_dur(self, dur: int, test_data: list[list[float]], q: mp.Queue, con idx = idx + 1 if idx < num - 1 else 0 if count % 500 == 0: - log.debug(f"({mp.current_process().name:16}) search_count: {count}, latest_latency={time.perf_counter()-s}") + log.debug( + f"({mp.current_process().name:16}) search_count: {count}, ", + f"latest_latency={time.perf_counter()-s}", + ) total_dur 
= round(time.perf_counter() - start_time, 4) log.debug( f"{mp.current_process().name:16} search {self.duration}s: " - f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}" - ) + f"actual_dur={total_dur}s, count={count}, qps in this process: {round(count / total_dur, 4):3}", + ) return count - diff --git a/vectordb_bench/backend/runner/rate_runner.py b/vectordb_bench/backend/runner/rate_runner.py index d77c0fd15..0145af4ce 100644 --- a/vectordb_bench/backend/runner/rate_runner.py +++ b/vectordb_bench/backend/runner/rate_runner.py @@ -1,36 +1,36 @@ +import concurrent import logging +import multiprocessing as mp import time -import concurrent from concurrent.futures import ThreadPoolExecutor -import multiprocessing as mp - +from vectordb_bench import config from vectordb_bench.backend.clients import api from vectordb_bench.backend.dataset import DataSetIterator from vectordb_bench.backend.utils import time_it -from vectordb_bench import config from .util import get_data + log = logging.getLogger(__name__) class RatedMultiThreadingInsertRunner: def __init__( self, - rate: int, # numRows per second + rate: int, # numRows per second db: api.VectorDB, dataset_iter: DataSetIterator, normalize: bool = False, timeout: float | None = None, ): - self.timeout = timeout if isinstance(timeout, (int, float)) else None + self.timeout = timeout if isinstance(timeout, int | float) else None self.dataset = dataset_iter self.db = db self.normalize = normalize self.insert_rate = rate self.batch_rate = rate // config.NUM_PER_BATCH - def send_insert_task(self, db, emb: list[list[float]], metadata: list[str]): + def send_insert_task(self, db: api.VectorDB, emb: list[list[float]], metadata: list[str]): db.insert_embeddings(emb, metadata) @time_it @@ -43,7 +43,9 @@ def submit_by_rate() -> bool: rate = self.batch_rate for data in self.dataset: emb, metadata = get_data(data, self.normalize) - executing_futures.append(executor.submit(self.send_insert_task, self.db, emb, metadata)) + executing_futures.append( + executor.submit(self.send_insert_task, self.db, emb, metadata), + ) rate -= 1 if rate == 0: @@ -66,19 +68,26 @@ def submit_by_rate() -> bool: done, not_done = concurrent.futures.wait( executing_futures, timeout=wait_interval, - return_when=concurrent.futures.FIRST_EXCEPTION) + return_when=concurrent.futures.FIRST_EXCEPTION, + ) if len(not_done) > 0: - log.warning(f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round") + log.warning( + f"Failed to finish all tasks in 1s, [{len(not_done)}/{len(executing_futures)}] ", + f"tasks are not done, waited={wait_interval:.2f}, trying to wait in the next round", + ) executing_futures = list(not_done) else: - log.debug(f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} task in 1s, wait_interval={wait_interval:.2f}") + log.debug( + f"Finished {len(executing_futures)} insert-{config.NUM_PER_BATCH} ", + f"task in 1s, wait_interval={wait_interval:.2f}", + ) executing_futures = [] except Exception as e: - log.warn(f"task error, terminating, err={e}") - q.put(None, block=True) - executor.shutdown(wait=True, cancel_futures=True) - raise e + log.warning(f"task error, terminating, err={e}") + q.put(None, block=True) + executor.shutdown(wait=True, cancel_futures=True) + raise e from e dur = time.perf_counter() - start_time if dur < 1: @@ -87,10 +96,12 @@ def submit_by_rate() -> bool: # wait for all tasks in executing_futures 
to complete if len(executing_futures) > 0: try: - done, _ = concurrent.futures.wait(executing_futures, - return_when=concurrent.futures.FIRST_EXCEPTION) + done, _ = concurrent.futures.wait( + executing_futures, + return_when=concurrent.futures.FIRST_EXCEPTION, + ) except Exception as e: - log.warn(f"task error, terminating, err={e}") + log.warning(f"task error, terminating, err={e}") q.put(None, block=True) executor.shutdown(wait=True, cancel_futures=True) - raise e + raise e from e diff --git a/vectordb_bench/backend/runner/read_write_runner.py b/vectordb_bench/backend/runner/read_write_runner.py index fd425b227..e916f45d6 100644 --- a/vectordb_bench/backend/runner/read_write_runner.py +++ b/vectordb_bench/backend/runner/read_write_runner.py @@ -1,16 +1,18 @@ +import concurrent import logging -from typing import Iterable +import math import multiprocessing as mp -import concurrent +from collections.abc import Iterable + import numpy as np -import math -from .mp_runner import MultiProcessingSearchRunner -from .serial_runner import SerialSearchRunner -from .rate_runner import RatedMultiThreadingInsertRunner from vectordb_bench.backend.clients import api from vectordb_bench.backend.dataset import DatasetManager +from .mp_runner import MultiProcessingSearchRunner +from .rate_runner import RatedMultiThreadingInsertRunner +from .serial_runner import SerialSearchRunner + log = logging.getLogger(__name__) @@ -24,8 +26,14 @@ def __init__( k: int = 100, filters: dict | None = None, concurrencies: Iterable[int] = (1, 15, 50), - search_stage: Iterable[float] = (0.5, 0.6, 0.7, 0.8, 0.9), # search from insert portion, 0.0 means search from the start - read_dur_after_write: int = 300, # seconds, search duration when insertion is done + search_stage: Iterable[float] = ( + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + ), # search from insert portion, 0.0 means search from the start + read_dur_after_write: int = 300, # seconds, search duration when insertion is done timeout: float | None = None, ): self.insert_rate = insert_rate @@ -36,7 +44,10 @@ def __init__( self.search_stage = sorted(search_stage) self.read_dur_after_write = read_dur_after_write - log.info(f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, stage_search_dur={read_dur_after_write}") + log.info( + f"Init runner, concurencys={concurrencies}, search_stage={search_stage}, ", + f"stage_search_dur={read_dur_after_write}", + ) test_emb = np.stack(dataset.test_data["emb"]) if normalize: @@ -76,8 +87,13 @@ def run_search(self): log.info("Search after write - Serial search start") res, ssearch_dur = self.serial_search_runner.run() recall, ndcg, p99_latency = res - log.info(f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}") - log.info(f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}") + log.info( + f"Search after write - Serial search - recall={recall}, ndcg={ndcg}, p99={p99_latency}, ", + f"dur={ssearch_dur:.4f}", + ) + log.info( + f"Search after wirte - Conc search start, dur for each conc={self.read_dur_after_write}", + ) max_qps = self.run_by_dur(self.read_dur_after_write) log.info(f"Search after wirte - Conc search finished, max_qps={max_qps}") @@ -86,7 +102,10 @@ def run_search(self): def run_read_write(self): with mp.Manager() as m: q = m.Queue() - with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context("spawn"), max_workers=2) as executor: + with concurrent.futures.ProcessPoolExecutor( + 
mp_context=mp.get_context("spawn"), + max_workers=2, + ) as executor: read_write_futures = [] read_write_futures.append(executor.submit(self.run_with_rate, q)) read_write_futures.append(executor.submit(self.run_search_by_sig, q)) @@ -107,10 +126,10 @@ def run_read_write(self): except Exception as e: log.warning(f"Read and write error: {e}") executor.shutdown(wait=True, cancel_futures=True) - raise e + raise e from e log.info("Concurrent read write all done") - def run_search_by_sig(self, q): + def run_search_by_sig(self, q: mp.Queue): """ Args: q: multiprocessing queue @@ -122,15 +141,14 @@ def run_search_by_sig(self, q): total_batch = math.ceil(self.data_volume / self.insert_rate) recall, ndcg, p99_latency = None, None, None - def wait_next_target(start, target_batch) -> bool: + def wait_next_target(start: int, target_batch: int) -> bool: """Return False when receive True or None""" while start < target_batch: sig = q.get(block=True) if sig is None or sig is True: return False - else: - start += 1 + start += 1 return True for idx, stage in enumerate(self.search_stage): @@ -139,19 +157,24 @@ def wait_next_target(start, target_batch) -> bool: got = wait_next_target(start_batch, target_batch) if got is False: - log.warning(f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}") - return + log.warning( + f"Abnormal exit, target_batch={target_batch}, start_batch={start_batch}", + ) + return None log.info(f"Insert {perc}% done, total batch={total_batch}") log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% start") res, ssearch_dur = self.serial_search_runner.run() recall, ndcg, p99_latency = res - log.info(f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}") + log.info( + f"[{target_batch}/{total_batch}] Serial search - {perc}% done, recall={recall}, ", + f"ndcg={ndcg}, p99={p99_latency}, dur={ssearch_dur:.4f}", + ) # Search duration for non-last search stage is carefully calculated. # If duration for each concurrency is less than 30s, runner will raise error. if idx < len(self.search_stage) - 1: - total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate + total_dur_between_stages = self.data_volume * (self.search_stage[idx + 1] - stage) // self.insert_rate csearch_dur = total_dur_between_stages - ssearch_dur # Try to leave room for init process executors @@ -159,14 +182,19 @@ def wait_next_target(start, target_batch) -> bool: each_conc_search_dur = csearch_dur / len(self.concurrencies) if each_conc_search_dur < 30: - warning_msg = f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}." 
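A minimal sketch, with illustrative values, of why the separating comma matters when long f-strings are wrapped the way warning_msg is reflowed below: adjacent literals with no comma are implicitly concatenated into one str, while a comma between them builds a tuple (the same distinction applies to the multi-argument log.info(f"...", f"...") reflows elsewhere in this patch).

dur, cost = 120, 30
msg_str = (
    f"Results might be inaccurate, duration[{dur}] left for conc-search is too short, "
    f"serial_search_cost={cost}."
)  # adjacent f-strings, no comma: implicit concatenation -> one str
msg_tuple = (
    f"Results might be inaccurate, duration[{dur}] left for conc-search is too short, ",
    f"serial_search_cost={cost}.",
)  # commas between the literals: a tuple of two strings, not one message
assert isinstance(msg_str, str)
assert isinstance(msg_tuple, tuple)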
+ warning_msg = ( + f"Results might be inaccurate, duration[{csearch_dur:.4f}] left for conc-search is too short, ", + f"total available dur={total_dur_between_stages}, serial_search_cost={ssearch_dur}.", + ) log.warning(warning_msg) # The last stage else: each_conc_search_dur = 60 - log.info(f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}") + log.info( + f"[{target_batch}/{total_batch}] Concurrent search - {perc}% start, dur={each_conc_search_dur:.4f}", + ) max_qps = self.run_by_dur(each_conc_search_dur) result.append((perc, max_qps, recall, ndcg, p99_latency)) diff --git a/vectordb_bench/backend/runner/serial_runner.py b/vectordb_bench/backend/runner/serial_runner.py index 13270e21a..7eb59432b 100644 --- a/vectordb_bench/backend/runner/serial_runner.py +++ b/vectordb_bench/backend/runner/serial_runner.py @@ -1,20 +1,21 @@ -import time -import logging -import traceback import concurrent -import multiprocessing as mp +import logging import math -import psutil +import multiprocessing as mp +import time +import traceback import numpy as np import pandas as pd +import psutil -from ..clients import api +from vectordb_bench.backend.dataset import DatasetManager + +from ... import config from ...metric import calc_ndcg, calc_recall, get_ideal_dcg from ...models import LoadTimeoutError, PerformanceTimeoutError from .. import utils -from ... import config -from vectordb_bench.backend.dataset import DatasetManager +from ..clients import api NUM_PER_BATCH = config.NUM_PER_BATCH LOAD_MAX_TRY_COUNT = 10 @@ -22,9 +23,16 @@ log = logging.getLogger(__name__) + class SerialInsertRunner: - def __init__(self, db: api.VectorDB, dataset: DatasetManager, normalize: bool, timeout: float | None = None): - self.timeout = timeout if isinstance(timeout, (int, float)) else None + def __init__( + self, + db: api.VectorDB, + dataset: DatasetManager, + normalize: bool, + timeout: float | None = None, + ): + self.timeout = timeout if isinstance(timeout, int | float) else None self.dataset = dataset self.db = db self.normalize = normalize @@ -32,18 +40,20 @@ def __init__(self, db: api.VectorDB, dataset: DatasetManager, normalize: bool, t def task(self) -> int: count = 0 with self.db.init(): - log.info(f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}") + log.info( + f"({mp.current_process().name:16}) Start inserting embeddings in batch {config.NUM_PER_BATCH}", + ) start = time.perf_counter() for data_df in self.dataset: - all_metadata = data_df['id'].tolist() + all_metadata = data_df["id"].tolist() - emb_np = np.stack(data_df['emb']) + emb_np = np.stack(data_df["emb"]) if self.normalize: log.debug("normalize the 100k train data") all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist() else: all_embeddings = emb_np.tolist() - del(emb_np) + del emb_np log.debug(f"batch dataset size: {len(all_embeddings)}, {len(all_metadata)}") insert_count, error = self.db.insert_embeddings( @@ -56,30 +66,41 @@ def task(self) -> int: assert insert_count == len(all_metadata) count += insert_count if count % 100_000 == 0: - log.info(f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB") + log.info( + f"({mp.current_process().name:16}) Loaded {count} embeddings into VectorDB", + ) - log.info(f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, dur={time.perf_counter()-start}") + log.info( + f"({mp.current_process().name:16}) Finish loading all dataset into VectorDB, ", + 
f"dur={time.perf_counter()-start}", + ) return count - def endless_insert_data(self, all_embeddings, all_metadata, left_id: int = 0) -> int: + def endless_insert_data(self, all_embeddings: list, all_metadata: list, left_id: int = 0) -> int: with self.db.init(): # unique id for endlessness insertion - all_metadata = [i+left_id for i in all_metadata] + all_metadata = [i + left_id for i in all_metadata] - NUM_BATCHES = math.ceil(len(all_embeddings)/NUM_PER_BATCH) - log.info(f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}") + num_batches = math.ceil(len(all_embeddings) / NUM_PER_BATCH) + log.info( + f"({mp.current_process().name:16}) Start inserting {len(all_embeddings)} ", + f"embeddings in batch {NUM_PER_BATCH}", + ) count = 0 - for batch_id in range(NUM_BATCHES): + for batch_id in range(num_batches): retry_count = 0 already_insert_count = 0 - metadata = all_metadata[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH] - embeddings = all_embeddings[batch_id*NUM_PER_BATCH : (batch_id+1)*NUM_PER_BATCH] + metadata = all_metadata[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH] + embeddings = all_embeddings[batch_id * NUM_PER_BATCH : (batch_id + 1) * NUM_PER_BATCH] - log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Start inserting {len(metadata)} embeddings") + log.debug( + f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ", + f"Start inserting {len(metadata)} embeddings", + ) while retry_count < LOAD_MAX_TRY_COUNT: insert_count, error = self.db.insert_embeddings( - embeddings=embeddings[already_insert_count :], - metadata=metadata[already_insert_count :], + embeddings=embeddings[already_insert_count:], + metadata=metadata[already_insert_count:], ) already_insert_count += insert_count if error is not None: @@ -91,17 +112,26 @@ def endless_insert_data(self, all_embeddings, all_metadata, left_id: int = 0) -> raise error else: break - log.debug(f"({mp.current_process().name:16}) batch [{batch_id:3}/{NUM_BATCHES}], Finish inserting {len(metadata)} embeddings") + log.debug( + f"({mp.current_process().name:16}) batch [{batch_id:3}/{num_batches}], ", + f"Finish inserting {len(metadata)} embeddings", + ) assert already_insert_count == len(metadata) count += already_insert_count - log.info(f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in batch {NUM_PER_BATCH}") + log.info( + f"({mp.current_process().name:16}) Finish inserting {len(all_embeddings)} embeddings in ", + f"batch {NUM_PER_BATCH}", + ) return count @utils.time_it def _insert_all_batches(self) -> int: """Performance case only""" - with concurrent.futures.ProcessPoolExecutor(mp_context=mp.get_context('spawn'), max_workers=1) as executor: + with concurrent.futures.ProcessPoolExecutor( + mp_context=mp.get_context("spawn"), + max_workers=1, + ) as executor: future = executor.submit(self.task) try: count = future.result(timeout=self.timeout) @@ -121,8 +151,11 @@ def run_endlessness(self) -> int: """run forever util DB raises exception or crash""" # datasets for load tests are quite small, can fit into memory # only 1 file - data_df = [data_df for data_df in self.dataset][0] - all_embeddings, all_metadata = np.stack(data_df["emb"]).tolist(), data_df['id'].tolist() + data_df = next(iter(self.dataset)) + all_embeddings, all_metadata = ( + np.stack(data_df["emb"]).tolist(), + data_df["id"].tolist(), + ) start_time = time.perf_counter() max_load_count, times = 0, 0 @@ -130,18 +163,26 @@ def 
run_endlessness(self) -> int: with self.db.init(): self.db.ready_to_load() while time.perf_counter() - start_time < self.timeout: - count = self.endless_insert_data(all_embeddings, all_metadata, left_id=max_load_count) + count = self.endless_insert_data( + all_embeddings, + all_metadata, + left_id=max_load_count, + ) max_load_count += count times += 1 - log.info(f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, {max_load_count}") + log.info( + f"Loaded {times} entire dataset, current max load counts={utils.numerize(max_load_count)}, ", + f"{max_load_count}", + ) except Exception as e: - log.info(f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, {max_load_count}, err={e}") + log.info( + f"Capacity case load reach limit, insertion counts={utils.numerize(max_load_count)}, ", + f"{max_load_count}, err={e}", + ) traceback.print_exc() return max_load_count else: - msg = f"capacity case load timeout in {self.timeout}s" - log.info(msg) - raise LoadTimeoutError(msg) + raise LoadTimeoutError(self.timeout) def run(self) -> int: count, dur = self._insert_all_batches() @@ -168,7 +209,9 @@ def __init__( self.ground_truth = ground_truth def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: - log.info(f"{mp.current_process().name:14} start search the entire test_data to get recall and latency") + log.info( + f"{mp.current_process().name:14} start search the entire test_data to get recall and latency", + ) with self.db.init(): test_data, ground_truth = args ideal_dcg = get_ideal_dcg(self.k) @@ -193,13 +236,15 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: latencies.append(time.perf_counter() - s) - gt = ground_truth['neighbors_id'][idx] - recalls.append(calc_recall(self.k, gt[:self.k], results)) - ndcgs.append(calc_ndcg(gt[:self.k], results, ideal_dcg)) - + gt = ground_truth["neighbors_id"][idx] + recalls.append(calc_recall(self.k, gt[: self.k], results)) + ndcgs.append(calc_ndcg(gt[: self.k], results, ideal_dcg)) if len(latencies) % 100 == 0: - log.debug(f"({mp.current_process().name:14}) search_count={len(latencies):3}, latest_latency={latencies[-1]}, latest recall={recalls[-1]}") + log.debug( + f"({mp.current_process().name:14}) search_count={len(latencies):3}, ", + f"latest_latency={latencies[-1]}, latest recall={recalls[-1]}", + ) avg_latency = round(np.mean(latencies), 4) avg_recall = round(np.mean(recalls), 4) @@ -213,16 +258,14 @@ def search(self, args: tuple[list, pd.DataFrame]) -> tuple[float, float, float]: f"avg_recall={avg_recall}, " f"avg_ndcg={avg_ndcg}," f"avg_latency={avg_latency}, " - f"p99={p99}" - ) + f"p99={p99}", + ) return (avg_recall, avg_ndcg, p99) - def _run_in_subprocess(self) -> tuple[float, float]: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: future = executor.submit(self.search, (self.test_data, self.ground_truth)) - result = future.result() - return result + return future.result() @utils.time_it def run(self) -> tuple[float, float, float]: diff --git a/vectordb_bench/backend/runner/util.py b/vectordb_bench/backend/runner/util.py index ba1888167..50e91b04b 100644 --- a/vectordb_bench/backend/runner/util.py +++ b/vectordb_bench/backend/runner/util.py @@ -1,13 +1,14 @@ import logging -from pandas import DataFrame import numpy as np +from pandas import DataFrame log = logging.getLogger(__name__) + def get_data(data_df: DataFrame, normalize: bool) -> tuple[list[list[float]], list[str]]: - all_metadata = 
data_df['id'].tolist() - emb_np = np.stack(data_df['emb']) + all_metadata = data_df["id"].tolist() + emb_np = np.stack(data_df["emb"]) if normalize: log.debug("normalize the 100k train data") all_embeddings = (emb_np / np.linalg.norm(emb_np, axis=1)[:, np.newaxis]).tolist() diff --git a/vectordb_bench/backend/task_runner.py b/vectordb_bench/backend/task_runner.py index 568152ae0..e24d74f03 100644 --- a/vectordb_bench/backend/task_runner.py +++ b/vectordb_bench/backend/task_runner.py @@ -1,24 +1,20 @@ +import concurrent import logging -import psutil import traceback -import concurrent -import numpy as np from enum import Enum, auto -from . import utils -from .cases import Case, CaseLabel -from ..base import BaseModel -from ..models import TaskConfig, PerformanceTimeoutError, TaskStage +import numpy as np +import psutil -from .clients import ( - api, - MetricType -) -from ..metric import Metric -from .runner import MultiProcessingSearchRunner -from .runner import SerialSearchRunner, SerialInsertRunner -from .data_source import DatasetSource +from vectordb_bench.base import BaseModel +from vectordb_bench.metric import Metric +from vectordb_bench.models import PerformanceTimeoutError, TaskConfig, TaskStage +from . import utils +from .cases import Case, CaseLabel +from .clients import MetricType, api +from .data_source import DatasetSource +from .runner import MultiProcessingSearchRunner, SerialInsertRunner, SerialSearchRunner log = logging.getLogger(__name__) @@ -53,24 +49,39 @@ class CaseRunner(BaseModel): search_runner: MultiProcessingSearchRunner | None = None final_search_runner: MultiProcessingSearchRunner | None = None - def __eq__(self, obj): + def __eq__(self, obj: any): if isinstance(obj, CaseRunner): - return self.ca.label == CaseLabel.Performance and \ - self.config.db == obj.config.db and \ - self.config.db_case_config == obj.config.db_case_config and \ - self.ca.dataset == obj.ca.dataset + return ( + self.ca.label == CaseLabel.Performance + and self.config.db == obj.config.db + and self.config.db_case_config == obj.config.db_case_config + and self.ca.dataset == obj.ca.dataset + ) return False def display(self) -> dict: - c_dict = self.ca.dict(include={'label':True, 'filters': True,'dataset':{'data': {'name': True, 'size': True, 'dim': True, 'metric_type': True, 'label': True}} }) - c_dict['db'] = self.config.db_name + c_dict = self.ca.dict( + include={ + "label": True, + "filters": True, + "dataset": { + "data": { + "name": True, + "size": True, + "dim": True, + "metric_type": True, + "label": True, + }, + }, + }, + ) + c_dict["db"] = self.config.db_name return c_dict @property def normalize(self) -> bool: assert self.db - return self.db.need_normalize_cosine() and \ - self.ca.dataset.data.metric_type == MetricType.COSINE + return self.db.need_normalize_cosine() and self.ca.dataset.data.metric_type == MetricType.COSINE def init_db(self, drop_old: bool = True) -> None: db_cls = self.config.db.init_cls @@ -80,8 +91,7 @@ def init_db(self, drop_old: bool = True) -> None: db_config=self.config.db_config.to_dict(), db_case_config=self.config.db_case_config, drop_old=drop_old, - ) # type:ignore - + ) def _pre_run(self, drop_old: bool = True): try: @@ -89,12 +99,9 @@ def _pre_run(self, drop_old: bool = True): self.ca.dataset.prepare(self.dataset_source, filters=self.ca.filter_rate) except ModuleNotFoundError as e: log.warning( - f"pre run case error: please install client for db: {self.config.db}, error={e}" + f"pre run case error: please install client for db: {self.config.db}, 
error={e}", ) raise e from None - except Exception as e: - log.warning(f"pre run case error: {e}") - raise e from None def run(self, drop_old: bool = True) -> Metric: log.info("Starting run") @@ -103,12 +110,11 @@ def run(self, drop_old: bool = True) -> Metric: if self.ca.label == CaseLabel.Load: return self._run_capacity_case() - elif self.ca.label == CaseLabel.Performance: + if self.ca.label == CaseLabel.Performance: return self._run_perf_case(drop_old) - else: - msg = f"unknown case type: {self.ca.label}" - log.warning(msg) - raise ValueError(msg) + msg = f"unknown case type: {self.ca.label}" + log.warning(msg) + raise ValueError(msg) def _run_capacity_case(self) -> Metric: """run capacity cases @@ -120,7 +126,10 @@ def _run_capacity_case(self) -> Metric: log.info("Start capacity case") try: runner = SerialInsertRunner( - self.db, self.ca.dataset, self.normalize, self.ca.load_timeout + self.db, + self.ca.dataset, + self.normalize, + self.ca.load_timeout, ) count = runner.run_endlessness() except Exception as e: @@ -128,7 +137,7 @@ def _run_capacity_case(self) -> Metric: raise e from None else: log.info( - f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}" + f"Capacity case loading dataset reaches VectorDB's limit: max capacity = {count}", ) return Metric(max_load_count=count) @@ -138,7 +147,7 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: Returns: Metric: load_duration, recall, serial_latency_p99, and, qps """ - ''' + """ if drop_old: _, load_dur = self._load_train_data() build_dur = self._optimize() @@ -153,38 +162,40 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list = self._conc_search() m.recall, m.serial_latency_p99 = self._serial_search() - ''' + """ log.info("Start performance case") try: m = Metric() if drop_old: if TaskStage.LOAD in self.config.stages: - # self._load_train_data() _, load_dur = self._load_train_data() build_dur = self._optimize() m.load_duration = round(load_dur + build_dur, 4) log.info( f"Finish loading the entire dataset into VectorDB," f" insert_duration={load_dur}, optimize_duration={build_dur}" - f" load_duration(insert + optimize) = {m.load_duration}" + f" load_duration(insert + optimize) = {m.load_duration}", ) else: log.info("Data loading skipped") - if ( - TaskStage.SEARCH_SERIAL in self.config.stages - or TaskStage.SEARCH_CONCURRENT in self.config.stages - ): + if TaskStage.SEARCH_SERIAL in self.config.stages or TaskStage.SEARCH_CONCURRENT in self.config.stages: self._init_search_runner() if TaskStage.SEARCH_CONCURRENT in self.config.stages: search_results = self._conc_search() - m.qps, m.conc_num_list, m.conc_qps_list, m.conc_latency_p99_list, m.conc_latency_avg_list = search_results + ( + m.qps, + m.conc_num_list, + m.conc_qps_list, + m.conc_latency_p99_list, + m.conc_latency_avg_list, + ) = search_results if TaskStage.SEARCH_SERIAL in self.config.stages: search_results = self._serial_search() - ''' + """ m.recall = search_results.recall m.serial_latencies = search_results.serial_latencies - ''' + """ m.recall, m.ndcg, m.serial_latency_p99 = search_results except Exception as e: @@ -199,7 +210,12 @@ def _run_perf_case(self, drop_old: bool = True) -> Metric: def _load_train_data(self): """Insert train data and get the insert_duration""" try: - runner = SerialInsertRunner(self.db, self.ca.dataset, self.normalize, self.ca.load_timeout) + runner = SerialInsertRunner( + self.db, + self.ca.dataset, + self.normalize, + 
self.ca.load_timeout, + ) runner.run() except Exception as e: raise e from None @@ -215,11 +231,12 @@ def _serial_search(self) -> tuple[float, float, float]: """ try: results, _ = self.serial_search_runner.run() - return results except Exception as e: - log.warning(f"search error: {str(e)}, {e}") + log.warning(f"search error: {e!s}, {e}") self.stop() - raise e from None + raise e from e + else: + return results def _conc_search(self): """Performance concurrency tests, search the test data endlessness @@ -231,7 +248,7 @@ def _conc_search(self): try: return self.search_runner.run() except Exception as e: - log.warning(f"search error: {str(e)}, {e}") + log.warning(f"search error: {e!s}, {e}") raise e from None finally: self.stop() @@ -250,7 +267,7 @@ def _optimize(self) -> float: log.warning(f"VectorDB optimize timeout in {self.ca.optimize_timeout}") for pid, _ in executor._processes.items(): psutil.Process(pid).kill() - raise PerformanceTimeoutError("Performance case optimize timeout") from e + raise PerformanceTimeoutError from e except Exception as e: log.warning(f"VectorDB optimize error: {e}") raise e from None @@ -286,6 +303,16 @@ def stop(self): self.search_runner.stop() +DATA_FORMAT = " %-14s | %-12s %-20s %7s | %-10s" +TITLE_FORMAT = (" %-14s | %-12s %-20s %7s | %-10s") % ( + "DB", + "CaseType", + "Dataset", + "Filter", + "task_label", +) + + class TaskRunner(BaseModel): run_id: str task_label: str @@ -304,18 +331,8 @@ def _get_num_by_status(self, status: RunningStatus) -> int: return sum([1 for c in self.case_runners if c.status == status]) def display(self) -> None: - DATA_FORMAT = (" %-14s | %-12s %-20s %7s | %-10s") - TITLE_FORMAT = (" %-14s | %-12s %-20s %7s | %-10s") % ( - "DB", "CaseType", "Dataset", "Filter", "task_label") - fmt = [TITLE_FORMAT] - fmt.append(DATA_FORMAT%( - "-"*11, - "-"*12, - "-"*20, - "-"*7, - "-"*7 - )) + fmt.append(DATA_FORMAT % ("-" * 11, "-" * 12, "-" * 20, "-" * 7, "-" * 7)) for f in self.case_runners: if f.ca.filter_rate != 0.0: @@ -326,13 +343,16 @@ def display(self) -> None: filters = "None" ds_str = f"{f.ca.dataset.data.name}-{f.ca.dataset.data.label}-{utils.numerize(f.ca.dataset.data.size)}" - fmt.append(DATA_FORMAT%( - f.config.db_name, - f.ca.label.name, - ds_str, - filters, - self.task_label, - )) + fmt.append( + DATA_FORMAT + % ( + f.config.db_name, + f.ca.label.name, + ds_str, + filters, + self.task_label, + ), + ) tmp_logger = logging.getLogger("no_color") for f in fmt: diff --git a/vectordb_bench/backend/utils.py b/vectordb_bench/backend/utils.py index ea11e6461..86c4faf5e 100644 --- a/vectordb_bench/backend/utils.py +++ b/vectordb_bench/backend/utils.py @@ -2,7 +2,7 @@ from functools import wraps -def numerize(n) -> str: +def numerize(n: int) -> str: """display positive number n for readability Examples: @@ -16,32 +16,34 @@ def numerize(n) -> str: "K": 1e6, "M": 1e9, "B": 1e12, - "END": float('inf'), + "END": float("inf"), } display_n, sufix = n, "" for s, base in sufix2upbound.items(): # number >= 1000B will alway have sufix 'B' if s == "END": - display_n = int(n/1e9) + display_n = int(n / 1e9) sufix = "B" break if n < base: sufix = "" if s == "EMPTY" else s - display_n = int(n/(base/1e3)) + display_n = int(n / (base / 1e3)) break return f"{display_n}{sufix}" -def time_it(func): - """ returns result and elapsed time""" +def time_it(func: any): + """returns result and elapsed time""" + @wraps(func) def inner(*args, **kwargs): pref = time.perf_counter() result = func(*args, **kwargs) delta = time.perf_counter() - pref return result, delta + 
return inner @@ -62,14 +64,19 @@ def compose_train_files(train_count: int, use_shuffled: bool) -> list[str]: return train_files -def compose_gt_file(filters: int | float | str | None = None) -> str: +ONE_PERCENT = 0.01 +NINETY_NINE_PERCENT = 0.99 + + +def compose_gt_file(filters: float | str | None = None) -> str: if filters is None: return "neighbors.parquet" - if filters == 0.01: + if filters == ONE_PERCENT: return "neighbors_head_1p.parquet" - if filters == 0.99: + if filters == NINETY_NINE_PERCENT: return "neighbors_tail_1p.parquet" - raise ValueError(f"Filters not supported: {filters}") + msg = f"Filters not supported: {filters}" + raise ValueError(msg) diff --git a/vectordb_bench/base.py b/vectordb_bench/base.py index 3c71fb5a7..502d5fa49 100644 --- a/vectordb_bench/base.py +++ b/vectordb_bench/base.py @@ -3,4 +3,3 @@ class BaseModel(PydanticBaseModel, arbitrary_types_allowed=True): pass - diff --git a/vectordb_bench/cli/cli.py b/vectordb_bench/cli/cli.py index e5b9a5fe2..3bb7763d8 100644 --- a/vectordb_bench/cli/cli.py +++ b/vectordb_bench/cli/cli.py @@ -1,27 +1,27 @@ import logging +import os import time +from collections.abc import Callable from concurrent.futures import wait from datetime import datetime from pprint import pformat from typing import ( Annotated, - Callable, - List, - Optional, - Type, + Any, TypedDict, Unpack, get_origin, get_type_hints, - Dict, - Any, ) + import click +from yaml import load from vectordb_bench.backend.clients.api import MetricType + from .. import config from ..backend.clients import DB -from ..interface import benchMarkRunner, global_result_future +from ..interface import benchmark_runner, global_result_future from ..models import ( CaseConfig, CaseType, @@ -31,8 +31,7 @@ TaskConfig, TaskStage, ) -import os -from yaml import load + try: from yaml import CLoader as Loader except ImportError: @@ -46,8 +45,8 @@ def click_get_defaults_from_file(ctx, param, value): else: input_file = os.path.join(config.CONFIG_LOCAL_DIR, value) try: - with open(input_file, 'r') as f: - _config: Dict[str, Dict[str, Any]] = load(f.read(), Loader=Loader) + with open(input_file) as f: + _config: dict[str, dict[str, Any]] = load(f.read(), Loader=Loader) ctx.default_map = _config.get(ctx.command.name, {}) except Exception as e: raise click.BadParameter(f"Failed to load config file: {e}") @@ -55,7 +54,7 @@ def click_get_defaults_from_file(ctx, param, value): def click_parameter_decorators_from_typed_dict( - typed_dict: Type, + typed_dict: type, ) -> Callable[[click.decorators.FC], click.decorators.FC]: """A convenience method decorator that will read in a TypedDict with parameters defined by Annotated types. from .models import CaseConfig, CaseType, DBCaseConfig, DBConfig, TaskConfig, TaskStage @@ -91,15 +90,12 @@ def foo(**parameters: Unpack[FooTypedDict]): decorators = [] for _, t in get_type_hints(typed_dict, include_extras=True).items(): assert get_origin(t) is Annotated - if ( - len(t.__metadata__) == 1 - and t.__metadata__[0].__module__ == "click.decorators" - ): + if len(t.__metadata__) == 1 and t.__metadata__[0].__module__ == "click.decorators": # happy path -- only accept Annotated[..., Union[click.option,click.argument,...]] with no additional metadata defined (len=1) decorators.append(t.__metadata__[0]) else: raise RuntimeError( - "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. See docstring" + "Click-TypedDict decorator parsing must only contain root type and a click decorator like click.option. 
See docstring", ) def deco(f): @@ -132,11 +128,11 @@ def parse_task_stages( load: bool, search_serial: bool, search_concurrent: bool, -) -> List[TaskStage]: +) -> list[TaskStage]: stages = [] if load and not drop_old: raise RuntimeError("Dropping old data cannot be skipped if loading data") - elif drop_old and not load: + if drop_old and not load: raise RuntimeError("Load cannot be skipped if dropping old data") if drop_old: stages.append(TaskStage.DROP_OLD) @@ -149,12 +145,19 @@ def parse_task_stages( return stages -def check_custom_case_parameters(ctx, param, value): - if ctx.params.get("case_type") == "PerformanceCustomDataset": - if value is None: - raise click.BadParameter("Custom case parameters\ - \n--custom-case-name\n--custom-dataset-name\n--custom-dataset-dir\n--custom-dataset-size \ - \n--custom-dataset-dim\n--custom-dataset-file-count\n are required") +# ruff: noqa +def check_custom_case_parameters(ctx: any, param: any, value: any): + if ctx.params.get("case_type") == "PerformanceCustomDataset" and value is None: + raise click.BadParameter( + """ Custom case parameters +--custom-case-name +--custom-dataset-name +--custom-dataset-dir +--custom-dataset-sizes +--custom-dataset-dim +--custom-dataset-file-count +are required """, + ) return value @@ -175,7 +178,7 @@ def get_custom_case_config(parameters: dict) -> dict: "file_count": parameters["custom_dataset_file_count"], "use_shuffled": parameters["custom_dataset_use_shuffled"], "with_gt": parameters["custom_dataset_with_gt"], - } + }, } return custom_case_config @@ -186,12 +189,14 @@ def get_custom_case_config(parameters: dict) -> dict: class CommonTypedDict(TypedDict): config_file: Annotated[ bool, - click.option('--config-file', - type=click.Path(), - callback=click_get_defaults_from_file, - is_eager=True, - expose_value=False, - help='Read configuration from yaml file'), + click.option( + "--config-file", + type=click.Path(), + callback=click_get_defaults_from_file, + is_eager=True, + expose_value=False, + help="Read configuration from yaml file", + ), ] drop_old: Annotated[ bool, @@ -246,9 +251,11 @@ class CommonTypedDict(TypedDict): db_label: Annotated[ str, click.option( - "--db-label", type=str, help="Db label, default: date in ISO format", + "--db-label", + type=str, + help="Db label, default: date in ISO format", show_default=True, - default=datetime.now().isoformat() + default=datetime.now().isoformat(), ), ] dry_run: Annotated[ @@ -282,7 +289,7 @@ class CommonTypedDict(TypedDict): ), ] num_concurrency: Annotated[ - List[str], + list[str], click.option( "--num-concurrency", type=str, @@ -298,7 +305,7 @@ class CommonTypedDict(TypedDict): "--custom-case-name", help="Custom dataset case name", callback=check_custom_case_parameters, - ) + ), ] custom_case_description: Annotated[ str, @@ -307,7 +314,7 @@ class CommonTypedDict(TypedDict): help="Custom dataset case description", default="This is a customized dataset.", show_default=True, - ) + ), ] custom_case_load_timeout: Annotated[ int, @@ -316,7 +323,7 @@ class CommonTypedDict(TypedDict): help="Custom dataset case load timeout", default=36000, show_default=True, - ) + ), ] custom_case_optimize_timeout: Annotated[ int, @@ -325,7 +332,7 @@ class CommonTypedDict(TypedDict): help="Custom dataset case optimize timeout", default=36000, show_default=True, - ) + ), ] custom_dataset_name: Annotated[ str, @@ -397,60 +404,60 @@ class CommonTypedDict(TypedDict): class HNSWBaseTypedDict(TypedDict): - m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m")] + m: 
Annotated[int | None, click.option("--m", type=int, help="hnsw m")] ef_construction: Annotated[ - Optional[int], + int | None, click.option("--ef-construction", type=int, help="hnsw ef-construction"), ] class HNSWBaseRequiredTypedDict(TypedDict): - m: Annotated[Optional[int], click.option("--m", type=int, help="hnsw m", required=True)] + m: Annotated[int | None, click.option("--m", type=int, help="hnsw m", required=True)] ef_construction: Annotated[ - Optional[int], + int | None, click.option("--ef-construction", type=int, help="hnsw ef-construction", required=True), ] class HNSWFlavor1(HNSWBaseTypedDict): ef_search: Annotated[ - Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", is_eager=True) + int | None, + click.option("--ef-search", type=int, help="hnsw ef-search", is_eager=True), ] class HNSWFlavor2(HNSWBaseTypedDict): ef_runtime: Annotated[ - Optional[int], click.option("--ef-runtime", type=int, help="hnsw ef-runtime") + int | None, + click.option("--ef-runtime", type=int, help="hnsw ef-runtime"), ] class HNSWFlavor3(HNSWBaseRequiredTypedDict): ef_search: Annotated[ - Optional[int], click.option("--ef-search", type=int, help="hnsw ef-search", required=True) + int | None, + click.option("--ef-search", type=int, help="hnsw ef-search", required=True), ] class IVFFlatTypedDict(TypedDict): - lists: Annotated[ - Optional[int], click.option("--lists", type=int, help="ivfflat lists") - ] - probes: Annotated[ - Optional[int], click.option("--probes", type=int, help="ivfflat probes") - ] + lists: Annotated[int | None, click.option("--lists", type=int, help="ivfflat lists")] + probes: Annotated[int | None, click.option("--probes", type=int, help="ivfflat probes")] class IVFFlatTypedDictN(TypedDict): nlist: Annotated[ - Optional[int], click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True) + int | None, + click.option("--lists", "nlist", type=int, help="ivfflat lists", required=True), ] nprobe: Annotated[ - Optional[int], click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True) + int | None, + click.option("--probes", "nprobe", type=int, help="ivfflat probes", required=True), ] @click.group() -def cli(): - ... +def cli(): ... 
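A minimal sketch, using hypothetical FooParams and foo names, of the mechanism the Annotated TypedDicts above rely on: each entry stores its click.option decorator in __metadata__, where click_parameter_decorators_from_typed_dict can pick it up and apply it to a command.

from typing import Annotated, TypedDict, get_type_hints

import click


class FooParams(TypedDict):
    m: Annotated[int | None, click.option("--m", type=int, help="hnsw m")]


hints = get_type_hints(FooParams, include_extras=True)
m_option = hints["m"].__metadata__[0]  # the stored click.option(...) decorator


@click.command()
@m_option  # applied as if it were a hand-written @click.option("--m", ...)
def foo(m: int | None):
    click.echo(f"m={m}")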
def run( @@ -482,9 +489,7 @@ def run( custom_case=get_custom_case_config(parameters), ), stages=parse_task_stages( - ( - False if not parameters["load"] else parameters["drop_old"] - ), # only drop old data if loading new data + (False if not parameters["load"] else parameters["drop_old"]), # only drop old data if loading new data parameters["load"], parameters["search_serial"], parameters["search_concurrent"], @@ -493,7 +498,7 @@ def run( log.info(f"Task:\n{pformat(task)}\n") if not parameters["dry_run"]: - benchMarkRunner.run([task]) + benchmark_runner.run([task]) time.sleep(5) if global_result_future: wait([global_result_future]) diff --git a/vectordb_bench/cli/vectordbbench.py b/vectordb_bench/cli/vectordbbench.py index f9ad69ceb..5e3798691 100644 --- a/vectordb_bench/cli/vectordbbench.py +++ b/vectordb_bench/cli/vectordbbench.py @@ -1,16 +1,15 @@ -from ..backend.clients.pgvector.cli import PgVectorHNSW +from ..backend.clients.alloydb.cli import AlloyDBScaNN +from ..backend.clients.aws_opensearch.cli import AWSOpenSearch +from ..backend.clients.memorydb.cli import MemoryDB +from ..backend.clients.milvus.cli import MilvusAutoIndex +from ..backend.clients.pgdiskann.cli import PgDiskAnn from ..backend.clients.pgvecto_rs.cli import PgVectoRSHNSW, PgVectoRSIVFFlat +from ..backend.clients.pgvector.cli import PgVectorHNSW from ..backend.clients.pgvectorscale.cli import PgVectorScaleDiskAnn -from ..backend.clients.pgdiskann.cli import PgDiskAnn from ..backend.clients.redis.cli import Redis -from ..backend.clients.memorydb.cli import MemoryDB from ..backend.clients.test.cli import Test from ..backend.clients.weaviate_cloud.cli import Weaviate from ..backend.clients.zilliz_cloud.cli import ZillizAutoIndex -from ..backend.clients.milvus.cli import MilvusAutoIndex -from ..backend.clients.aws_opensearch.cli import AWSOpenSearch -from ..backend.clients.alloydb.cli import AlloyDBScaNN - from .cli import cli cli.add_command(PgVectorHNSW) diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py index 9e869b479..0e74d2752 100644 --- a/vectordb_bench/frontend/components/check_results/charts.py +++ b/vectordb_bench/frontend/components/check_results/charts.py @@ -1,8 +1,7 @@ -from vectordb_bench.backend.cases import Case from vectordb_bench.frontend.components.check_results.expanderStyle import ( initMainExpanderStyle, ) -from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap +from vectordb_bench.metric import metric_order, isLowerIsBetterMetric, metric_unit_map from vectordb_bench.frontend.config.styles import * from vectordb_bench.models import ResultLabel import plotly.express as px @@ -21,9 +20,7 @@ def drawCharts(st, allData, failedTasks, caseNames: list[str]): def showFailedDBs(st, errorDBs): failedDBs = [db for db, label in errorDBs.items() if label == ResultLabel.FAILED] - timeoutDBs = [ - db for db, label in errorDBs.items() if label == ResultLabel.OUTOFRANGE - ] + timeoutDBs = [db for db, label in errorDBs.items() if label == ResultLabel.OUTOFRANGE] showFailedText(st, "Failed", failedDBs) showFailedText(st, "Timeout", timeoutDBs) @@ -41,7 +38,7 @@ def drawChart(data, st, key_prefix: str): metricsSet = set() for d in data: metricsSet = metricsSet.union(d["metricsSet"]) - showMetrics = [metric for metric in metricOrder if metric in metricsSet] + showMetrics = [metric for metric in metric_order if metric in metricsSet] for i, metric in enumerate(showMetrics): container = st.container() @@ -72,9 
+69,7 @@ def getLabelToShapeMap(data): else: usedShapes.add(labelIndexMap[label] % len(PATTERN_SHAPES)) - labelToShapeMap = { - label: getPatternShape(index) for label, index in labelIndexMap.items() - } + labelToShapeMap = {label: getPatternShape(index) for label, index in labelIndexMap.items()} return labelToShapeMap @@ -96,11 +91,9 @@ def drawMetricChart(data, metric, st, key: str): xpadding = (xmax - xmin) / 16 xpadding_multiplier = 1.8 xrange = [xmin, xmax + xpadding * xpadding_multiplier] - unit = metricUnitMap.get(metric, "") + unit = metric_unit_map.get(metric, "") labelToShapeMap = getLabelToShapeMap(dataWithMetric) - categoryorder = ( - "total descending" if isLowerIsBetterMetric(metric) else "total ascending" - ) + categoryorder = "total descending" if isLowerIsBetterMetric(metric) else "total ascending" fig = px.bar( dataWithMetric, x=metric, @@ -137,18 +130,14 @@ def drawMetricChart(data, metric, st, key: str): color="#333", size=12, ), - marker=dict( - pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7) - ), + marker=dict(pattern=dict(fillmode="overlay", fgcolor="#fff", fgopacity=1, size=7)), texttemplate="%{x:,.4~r}" + unit, ) fig.update_layout( margin=dict(l=0, r=0, t=48, b=12, pad=8), bargap=0.25, showlegend=False, - legend=dict( - orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title="" - ), + legend=dict(orientation="h", yanchor="bottom", y=1, xanchor="right", x=1, title=""), # legend=dict(orientation="v", title=""), yaxis={"categoryorder": categoryorder}, title=dict( diff --git a/vectordb_bench/frontend/components/check_results/data.py b/vectordb_bench/frontend/components/check_results/data.py index b3cac21e1..94d1b4eab 100644 --- a/vectordb_bench/frontend/components/check_results/data.py +++ b/vectordb_bench/frontend/components/check_results/data.py @@ -1,6 +1,5 @@ from collections import defaultdict from dataclasses import asdict -from vectordb_bench.backend.cases import Case from vectordb_bench.metric import isLowerIsBetterMetric from vectordb_bench.models import CaseResult, ResultLabel @@ -24,10 +23,7 @@ def getFilterTasks( task for task in tasks if task.task_config.db_name in dbNames - and task.task_config.case_config.case_id.case_cls( - task.task_config.case_config.custom_case - ).name - in caseNames + and task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case).name in caseNames ] return filterTasks @@ -39,9 +35,7 @@ def mergeTasks(tasks: list[CaseResult]): db = task.task_config.db.value db_label = task.task_config.db_config.db_label or "" version = task.task_config.db_config.version or "" - case = task.task_config.case_config.case_id.case_cls( - task.task_config.case_config.custom_case - ) + case = task.task_config.case_config.case_id.case_cls(task.task_config.case_config.custom_case) dbCaseMetricsMap[db_name][case.name] = { "db": db, "db_label": db_label, @@ -86,9 +80,7 @@ def mergeTasks(tasks: list[CaseResult]): def mergeMetrics(metrics_1: dict, metrics_2: dict) -> dict: metrics = {**metrics_1} for key, value in metrics_2.items(): - metrics[key] = ( - getBetterMetric(key, value, metrics[key]) if key in metrics else value - ) + metrics[key] = getBetterMetric(key, value, metrics[key]) if key in metrics else value return metrics @@ -99,11 +91,7 @@ def getBetterMetric(metric, value_1, value_2): return value_2 if value_2 < 1e-7: return value_1 - return ( - min(value_1, value_2) - if isLowerIsBetterMetric(metric) - else max(value_1, value_2) - ) + return min(value_1, value_2) if 
isLowerIsBetterMetric(metric) else max(value_1, value_2) except Exception: return value_1 diff --git a/vectordb_bench/frontend/components/check_results/filters.py b/vectordb_bench/frontend/components/check_results/filters.py index e60efb2e1..129c1d5ae 100644 --- a/vectordb_bench/frontend/components/check_results/filters.py +++ b/vectordb_bench/frontend/components/check_results/filters.py @@ -20,23 +20,17 @@ def getshownData(results: list[TestResult], st): shownResults = getshownResults(results, st) showDBNames, showCaseNames = getShowDbsAndCases(shownResults, st) - shownData, failedTasks = getChartData( - shownResults, showDBNames, showCaseNames) + shownData, failedTasks = getChartData(shownResults, showDBNames, showCaseNames) return shownData, failedTasks, showCaseNames def getshownResults(results: list[TestResult], st) -> list[CaseResult]: resultSelectOptions = [ - result.task_label - if result.task_label != result.run_id - else f"res-{result.run_id[:4]}" - for result in results + result.task_label if result.task_label != result.run_id else f"res-{result.run_id[:4]}" for result in results ] if len(resultSelectOptions) == 0: - st.write( - "There are no results to display. Please wait for the task to complete or run a new task." - ) + st.write("There are no results to display. Please wait for the task to complete or run a new task.") return [] selectedResultSelectedOptions = st.multiselect( @@ -58,13 +52,12 @@ def getShowDbsAndCases(result: list[CaseResult], st) -> tuple[list[str], list[st allDbNames = list(set({res.task_config.db_name for res in result})) allDbNames.sort() allCases: list[Case] = [ - res.task_config.case_config.case_id.case_cls( - res.task_config.case_config.custom_case) - for res in result + res.task_config.case_config.case_id.case_cls(res.task_config.case_config.custom_case) for res in result ] allCaseNameSet = set({case.name for case in allCases}) - allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + \ - [case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER] + allCaseNames = [case_name for case_name in CASE_NAME_ORDER if case_name in allCaseNameSet] + [ + case_name for case_name in allCaseNameSet if case_name not in CASE_NAME_ORDER + ] # DB Filter dbFilterContainer = st.container() @@ -120,8 +113,7 @@ def filterView(container, header, options, col, optionLables=None): ) if optionLables is None: optionLables = options - isActive = {option: st.session_state[selectAllState] - for option in optionLables} + isActive = {option: st.session_state[selectAllState] for option in optionLables} for i, option in enumerate(optionLables): isActive[option] = columns[i % col].checkbox( optionLables[i], diff --git a/vectordb_bench/frontend/components/check_results/nav.py b/vectordb_bench/frontend/components/check_results/nav.py index ad070ab35..f95e43d7a 100644 --- a/vectordb_bench/frontend/components/check_results/nav.py +++ b/vectordb_bench/frontend/components/check_results/nav.py @@ -7,15 +7,15 @@ def NavToRunTest(st): navClick = st.button("Run Your Test   >") if navClick: switch_page("run test") - - + + def NavToQuriesPerDollar(st): st.subheader("Compare qps with price.") navClick = st.button("QP$ (Quries per Dollar)   >") if navClick: switch_page("quries_per_dollar") - - + + def NavToResults(st, key="nav-to-results"): navClick = st.button("<   Back to Results", key=key) if navClick: diff --git a/vectordb_bench/frontend/components/check_results/priceTable.py 
b/vectordb_bench/frontend/components/check_results/priceTable.py index 06d7a6a6b..f2c0ae001 100644 --- a/vectordb_bench/frontend/components/check_results/priceTable.py +++ b/vectordb_bench/frontend/components/check_results/priceTable.py @@ -7,9 +7,7 @@ def priceTable(container, data): - dbAndLabelSet = { - (d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value - } + dbAndLabelSet = {(d["db"], d["db_label"]) for d in data if d["db"] != DB.Milvus.value} dbAndLabelList = list(dbAndLabelSet) dbAndLabelList.sort() diff --git a/vectordb_bench/frontend/components/check_results/stPageConfig.py b/vectordb_bench/frontend/components/check_results/stPageConfig.py index c03cceab4..9521be285 100644 --- a/vectordb_bench/frontend/components/check_results/stPageConfig.py +++ b/vectordb_bench/frontend/components/check_results/stPageConfig.py @@ -9,10 +9,11 @@ def initResultsPageConfig(st): # initial_sidebar_state="collapsed", ) + def initRunTestPageConfig(st): st.set_page_config( page_title=PAGE_TITLE, page_icon=FAVICON, # layout="wide", initial_sidebar_state="collapsed", - ) \ No newline at end of file + ) diff --git a/vectordb_bench/frontend/components/concurrent/charts.py b/vectordb_bench/frontend/components/concurrent/charts.py index 11379c4b3..83e2961d6 100644 --- a/vectordb_bench/frontend/components/concurrent/charts.py +++ b/vectordb_bench/frontend/components/concurrent/charts.py @@ -14,24 +14,24 @@ def drawChartsByCase(allData, showCaseNames: list[str], st, latency_type: str): data = [ { "conc_num": caseData["conc_num_list"][i], - "qps": caseData["conc_qps_list"][i] - if 0 <= i < len(caseData["conc_qps_list"]) - else 0, - "latency_p99": caseData["conc_latency_p99_list"][i] * 1000 - if 0 <= i < len(caseData["conc_latency_p99_list"]) - else 0, - "latency_avg": caseData["conc_latency_avg_list"][i] * 1000 - if 0 <= i < len(caseData["conc_latency_avg_list"]) - else 0, + "qps": (caseData["conc_qps_list"][i] if 0 <= i < len(caseData["conc_qps_list"]) else 0), + "latency_p99": ( + caseData["conc_latency_p99_list"][i] * 1000 + if 0 <= i < len(caseData["conc_latency_p99_list"]) + else 0 + ), + "latency_avg": ( + caseData["conc_latency_avg_list"][i] * 1000 + if 0 <= i < len(caseData["conc_latency_avg_list"]) + else 0 + ), "db_name": caseData["db_name"], "db": caseData["db"], } for caseData in caseDataList for i in range(len(caseData["conc_num_list"])) ] - drawChart( - data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type - ) + drawChart(data, chartContainer, key=f"{caseName}-qps-p99", x_metric=latency_type) def getRange(metric, data, padding_multipliers): diff --git a/vectordb_bench/frontend/components/custom/displayCustomCase.py b/vectordb_bench/frontend/components/custom/displayCustomCase.py index 3f5266051..ac111c883 100644 --- a/vectordb_bench/frontend/components/custom/displayCustomCase.py +++ b/vectordb_bench/frontend/components/custom/displayCustomCase.py @@ -1,4 +1,3 @@ - from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig @@ -6,26 +5,33 @@ def displayCustomCase(customCase: CustomCaseConfig, st, key): columns = st.columns([1, 2]) customCase.dataset_config.name = columns[0].text_input( - "Name", key=f"{key}_name", value=customCase.dataset_config.name) + "Name", key=f"{key}_name", value=customCase.dataset_config.name + ) customCase.name = f"{customCase.dataset_config.name} (Performace Case)" customCase.dataset_config.dir = columns[1].text_input( - "Folder Path", key=f"{key}_dir", value=customCase.dataset_config.dir) + "Folder Path", 
key=f"{key}_dir", value=customCase.dataset_config.dir + ) columns = st.columns(4) customCase.dataset_config.dim = columns[0].number_input( - "dim", key=f"{key}_dim", value=customCase.dataset_config.dim) + "dim", key=f"{key}_dim", value=customCase.dataset_config.dim + ) customCase.dataset_config.size = columns[1].number_input( - "size", key=f"{key}_size", value=customCase.dataset_config.size) + "size", key=f"{key}_size", value=customCase.dataset_config.size + ) customCase.dataset_config.metric_type = columns[2].selectbox( - "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"]) + "metric type", key=f"{key}_metric_type", options=["L2", "Cosine", "IP"] + ) customCase.dataset_config.file_count = columns[3].number_input( - "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count) + "train file count", key=f"{key}_file_count", value=customCase.dataset_config.file_count + ) columns = st.columns(4) customCase.dataset_config.use_shuffled = columns[0].checkbox( - "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled) + "use shuffled data", key=f"{key}_use_shuffled", value=customCase.dataset_config.use_shuffled + ) customCase.dataset_config.with_gt = columns[1].checkbox( - "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt) + "with groundtruth", key=f"{key}_with_gt", value=customCase.dataset_config.with_gt + ) - customCase.description = st.text_area( - "description", key=f"{key}_description", value=customCase.description) + customCase.description = st.text_area("description", key=f"{key}_description", value=customCase.description) diff --git a/vectordb_bench/frontend/components/custom/displaypPrams.py b/vectordb_bench/frontend/components/custom/displaypPrams.py index a300b45e1..b677e5909 100644 --- a/vectordb_bench/frontend/components/custom/displaypPrams.py +++ b/vectordb_bench/frontend/components/custom/displaypPrams.py @@ -1,5 +1,6 @@ def displayParams(st): - st.markdown(""" + st.markdown( + """ - `Folder Path` - The path to the folder containing all the files. Please ensure that all files in the folder are in the `Parquet` format. - Vectors data files: The file must be named `train.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. - Query test vectors: The file must be named `test.parquet` and should have two columns: `id` as an incrementing `int` and `emb` as an array of `float32`. @@ -8,4 +9,5 @@ def displayParams(st): - `Train File Count` - If the vector file is too large, you can consider splitting it into multiple files. The naming format for the split files should be `train-[index]-of-[file_count].parquet`. For example, `train-01-of-10.parquet` represents the second file (0-indexed) among 10 split files. - `Use Shuffled Data` - If you check this option, the vector data files need to be modified. VectorDBBench will load the data labeled with `shuffle`. For example, use `shuffle_train.parquet` instead of `train.parquet` and `shuffle_train-04-of-10.parquet` instead of `train-04-of-10.parquet`. The `id` column in the shuffled data can be in any order. 
-""") +""" + ) diff --git a/vectordb_bench/frontend/components/custom/getCustomConfig.py b/vectordb_bench/frontend/components/custom/getCustomConfig.py index ede664b6a..668bfb6d5 100644 --- a/vectordb_bench/frontend/components/custom/getCustomConfig.py +++ b/vectordb_bench/frontend/components/custom/getCustomConfig.py @@ -32,8 +32,7 @@ def get_custom_configs(): def save_custom_configs(custom_configs: list[CustomDatasetConfig]): with open(config.CUSTOM_CONFIG_DIR, "w") as f: - json.dump([custom_config.dict() - for custom_config in custom_configs], f, indent=4) + json.dump([custom_config.dict() for custom_config in custom_configs], f, indent=4) def generate_custom_case(): diff --git a/vectordb_bench/frontend/components/custom/initStyle.py b/vectordb_bench/frontend/components/custom/initStyle.py index 0e5129248..7ecbbf24e 100644 --- a/vectordb_bench/frontend/components/custom/initStyle.py +++ b/vectordb_bench/frontend/components/custom/initStyle.py @@ -12,4 +12,4 @@ def initStyle(st): */ """, unsafe_allow_html=True, - ) \ No newline at end of file + ) diff --git a/vectordb_bench/frontend/components/get_results/saveAsImage.py b/vectordb_bench/frontend/components/get_results/saveAsImage.py index 9820e9575..4be2baec1 100644 --- a/vectordb_bench/frontend/components/get_results/saveAsImage.py +++ b/vectordb_bench/frontend/components/get_results/saveAsImage.py @@ -9,10 +9,12 @@ def load_unpkg(src: str) -> str: return requests.get(src).text + def getResults(container, pageName="vectordb_bench"): container.subheader("Get results") saveAsImage(container, pageName) + def saveAsImage(container, pageName): html2canvasJS = load_unpkg(HTML_2_CANVAS_URL) container.write() diff --git a/vectordb_bench/frontend/components/run_test/caseSelector.py b/vectordb_bench/frontend/components/run_test/caseSelector.py index b25618271..e3d28238d 100644 --- a/vectordb_bench/frontend/components/run_test/caseSelector.py +++ b/vectordb_bench/frontend/components/run_test/caseSelector.py @@ -1,6 +1,4 @@ - from vectordb_bench.frontend.config.styles import * -from vectordb_bench.backend.cases import CaseType from vectordb_bench.frontend.config.dbCaseConfigs import * from collections import defaultdict @@ -23,8 +21,7 @@ def caseSelector(st, activedDbList: list[DB]): dbToCaseConfigs = defaultdict(lambda: defaultdict(dict)) caseClusters = UI_CASE_CLUSTERS + [get_custom_case_cluter()] for caseCluster in caseClusters: - activedCaseList += caseClusterExpander( - st, caseCluster, dbToCaseClusterConfigs, activedDbList) + activedCaseList += caseClusterExpander(st, caseCluster, dbToCaseClusterConfigs, activedDbList) for db in dbToCaseClusterConfigs: for uiCaseItem in dbToCaseClusterConfigs[db]: for case in uiCaseItem.cases: @@ -40,8 +37,7 @@ def caseClusterExpander(st, caseCluster: UICaseItemCluster, dbToCaseClusterConfi if uiCaseItem.isLine: addHorizontalLine(expander) else: - activedCases += caseItemCheckbox(expander, - dbToCaseClusterConfigs, uiCaseItem, activedDbList) + activedCases += caseItemCheckbox(expander, dbToCaseClusterConfigs, uiCaseItem, activedDbList) return activedCases @@ -53,9 +49,7 @@ def caseItemCheckbox(st, dbToCaseClusterConfigs, uiCaseItem: UICaseItem, actived ) if selected: - caseConfigSetting( - st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList - ) + caseConfigSetting(st.container(), dbToCaseClusterConfigs, uiCaseItem, activedDbList) return uiCaseItem.cases if selected else [] diff --git a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py 
b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py index 257608413..800e6dede 100644 --- a/vectordb_bench/frontend/components/run_test/dbConfigSetting.py +++ b/vectordb_bench/frontend/components/run_test/dbConfigSetting.py @@ -42,10 +42,7 @@ def dbConfigSettingItem(st, activeDb: DB): # db config (unique) for key, property in properties.items(): - if ( - key not in dbConfigClass.common_short_configs() - and key not in dbConfigClass.common_long_configs() - ): + if key not in dbConfigClass.common_short_configs() and key not in dbConfigClass.common_long_configs(): column = columns[idx % DB_CONFIG_SETTING_COLUMNS] idx += 1 dbConfig[key] = column.text_input( diff --git a/vectordb_bench/frontend/components/run_test/dbSelector.py b/vectordb_bench/frontend/components/run_test/dbSelector.py index ccf0168c6..e20ee5059 100644 --- a/vectordb_bench/frontend/components/run_test/dbSelector.py +++ b/vectordb_bench/frontend/components/run_test/dbSelector.py @@ -22,7 +22,7 @@ def dbSelector(st): dbIsActived[db] = column.checkbox(db.name) try: column.image(DB_TO_ICON.get(db, "")) - except MediaFileStorageError as e: + except MediaFileStorageError: column.warning(f"{db.name} image not available") pass activedDbList = [db for db in DB_LIST if dbIsActived[db]] diff --git a/vectordb_bench/frontend/components/run_test/generateTasks.py b/vectordb_bench/frontend/components/run_test/generateTasks.py index 828913f30..d8a678ffc 100644 --- a/vectordb_bench/frontend/components/run_test/generateTasks.py +++ b/vectordb_bench/frontend/components/run_test/generateTasks.py @@ -7,13 +7,13 @@ def generate_tasks(activedDbList: list[DB], dbConfigs, activedCaseList: list[Cas for db in activedDbList: for case in activedCaseList: task = TaskConfig( - db=db.value, - db_config=dbConfigs[db], - case_config=case, - db_case_config=db.case_config_cls( - allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None) - )(**{key.value: value for key, value in allCaseConfigs[db][case].items()}), - ) + db=db.value, + db_config=dbConfigs[db], + case_config=case, + db_case_config=db.case_config_cls(allCaseConfigs[db][case].get(CaseConfigParamType.IndexType, None))( + **{key.value: value for key, value in allCaseConfigs[db][case].items()} + ), + ) tasks.append(task) - + return tasks diff --git a/vectordb_bench/frontend/components/run_test/submitTask.py b/vectordb_bench/frontend/components/run_test/submitTask.py index f824dd9d9..426095397 100644 --- a/vectordb_bench/frontend/components/run_test/submitTask.py +++ b/vectordb_bench/frontend/components/run_test/submitTask.py @@ -1,6 +1,6 @@ from datetime import datetime -from vectordb_bench.frontend.config.styles import * -from vectordb_bench.interface import benchMarkRunner +from vectordb_bench.frontend.config import styles +from vectordb_bench.interface import benchmark_runner def submitTask(st, tasks, isAllValid): @@ -27,10 +27,8 @@ def submitTask(st, tasks, isAllValid): def taskLabelInput(st): defaultTaskLabel = datetime.now().strftime("%Y%m%d%H") - columns = st.columns(TASK_LABEL_INPUT_COLUMNS) - taskLabel = columns[0].text_input( - "task_label", defaultTaskLabel, label_visibility="collapsed" - ) + columns = st.columns(styles.TASK_LABEL_INPUT_COLUMNS) + taskLabel = columns[0].text_input("task_label", defaultTaskLabel, label_visibility="collapsed") return taskLabel @@ -46,10 +44,8 @@ def advancedSettings(st): ) container = st.columns([1, 2]) - k = container[0].number_input("k",min_value=1, value=100, label_visibility="collapsed") - container[1].caption( - "K value for number of 
nearest neighbors to search" - ) + k = container[0].number_input("k", min_value=1, value=100, label_visibility="collapsed") + container[1].caption("K value for number of nearest neighbors to search") return index_already_exists, use_aliyun, k @@ -58,20 +54,20 @@ def controlPanel(st, tasks, taskLabel, isAllValid): index_already_exists, use_aliyun, k = advancedSettings(st) def runHandler(): - benchMarkRunner.set_drop_old(not index_already_exists) + benchmark_runner.set_drop_old(not index_already_exists) for task in tasks: task.case_config.k = k - benchMarkRunner.set_download_address(use_aliyun) - benchMarkRunner.run(tasks, taskLabel) + benchmark_runner.set_download_address(use_aliyun) + benchmark_runner.run(tasks, taskLabel) def stopHandler(): - benchMarkRunner.stop_running() + benchmark_runner.stop_running() - isRunning = benchMarkRunner.has_running() + isRunning = benchmark_runner.has_running() if isRunning: - currentTaskId = benchMarkRunner.get_current_task_id() - tasksCount = benchMarkRunner.get_tasks_count() + currentTaskId = benchmark_runner.get_current_task_id() + tasksCount = benchmark_runner.get_tasks_count() text = f":running: Running Task {currentTaskId} / {tasksCount}" st.progress(currentTaskId / tasksCount, text=text) @@ -89,7 +85,7 @@ def stopHandler(): ) else: - errorText = benchMarkRunner.latest_error or "" + errorText = benchmark_runner.latest_error or "" if len(errorText) > 0: st.error(errorText) disabled = True if len(tasks) == 0 or not isAllValid else False diff --git a/vectordb_bench/frontend/components/tables/data.py b/vectordb_bench/frontend/components/tables/data.py index 96134c7ff..fbe83f197 100644 --- a/vectordb_bench/frontend/components/tables/data.py +++ b/vectordb_bench/frontend/components/tables/data.py @@ -1,12 +1,11 @@ from dataclasses import asdict -from vectordb_bench.backend.cases import CaseType -from vectordb_bench.interface import benchMarkRunner +from vectordb_bench.interface import benchmark_runner from vectordb_bench.models import CaseResult, ResultLabel import pandas as pd def getNewResults(): - allResults = benchMarkRunner.get_results() + allResults = benchmark_runner.get_results() newResults: list[CaseResult] = [] for res in allResults: @@ -14,7 +13,6 @@ def getNewResults(): for result in results: if result.label == ResultLabel.NORMAL: newResults.append(result) - df = pd.DataFrame(formatData(newResults)) return df @@ -26,7 +24,6 @@ def formatData(caseResults: list[CaseResult]): db = caseResult.task_config.db.value db_label = caseResult.task_config.db_config.db_label case_config = caseResult.task_config.case_config - db_case_config = caseResult.task_config.db_case_config case = case_config.case_id.case_cls() filter_rate = case.filter_rate dataset = case.dataset.data.name @@ -41,4 +38,4 @@ def formatData(caseResults: list[CaseResult]): **metrics, } ) - return data \ No newline at end of file + return data diff --git a/vectordb_bench/frontend/config/dbCaseConfigs.py b/vectordb_bench/frontend/config/dbCaseConfigs.py index 7f076a2dd..e004f2ba7 100644 --- a/vectordb_bench/frontend/config/dbCaseConfigs.py +++ b/vectordb_bench/frontend/config/dbCaseConfigs.py @@ -33,9 +33,9 @@ class UICaseItem(BaseModel): def __init__( self, isLine: bool = False, - case_id: CaseType = None, - custom_case: dict = {}, - cases: list[CaseConfig] = [], + case_id: CaseType | None = None, + custom_case: dict | None = None, + cases: list[CaseConfig] | None = None, label: str = "", description: str = "", caseLabel: CaseLabel = CaseLabel.Performance, @@ -70,17 +70,13 @@ class 
UICaseItemCluster(BaseModel): def get_custom_case_items() -> list[UICaseItem]: custom_configs = get_custom_configs() return [ - UICaseItem( - case_id=CaseType.PerformanceCustomDataset, custom_case=custom_config.dict() - ) + UICaseItem(case_id=CaseType.PerformanceCustomDataset, custom_case=custom_config.dict()) for custom_config in custom_configs ] def get_custom_case_cluter() -> UICaseItemCluster: - return UICaseItemCluster( - label="Custom Search Performance Test", uiCaseItems=get_custom_case_items() - ) + return UICaseItemCluster(label="Custom Search Performance Test", uiCaseItems=get_custom_case_items()) UI_CASE_CLUSTERS: list[UICaseItemCluster] = [ @@ -224,8 +220,7 @@ class CaseConfigInput(BaseModel): "max": 300, "value": 32, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value, ) CaseConfigParamInput_l_value_ib = CaseConfigInput( @@ -236,8 +231,7 @@ class CaseConfigInput(BaseModel): "max": 300, "value": 50, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value, ) CaseConfigParamInput_l_value_is = CaseConfigInput( @@ -248,8 +242,7 @@ class CaseConfigInput(BaseModel): "max": 300, "value": 40, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value, ) CaseConfigParamInput_num_neighbors = CaseConfigInput( @@ -260,8 +253,7 @@ class CaseConfigInput(BaseModel): "max": 300, "value": 50, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.STREAMING_DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value, ) CaseConfigParamInput_search_list_size = CaseConfigInput( @@ -272,8 +264,7 @@ class CaseConfigInput(BaseModel): "max": 300, "value": 100, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.STREAMING_DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value, ) CaseConfigParamInput_max_alpha = CaseConfigInput( @@ -284,8 +275,7 @@ class CaseConfigInput(BaseModel): "max": 2.0, "value": 1.2, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.STREAMING_DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value, ) CaseConfigParamInput_num_dimensions = CaseConfigInput( @@ -296,8 +286,7 @@ class CaseConfigInput(BaseModel): "max": 2000, "value": 0, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.STREAMING_DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value, ) CaseConfigParamInput_query_search_list_size = CaseConfigInput( @@ -308,8 +297,7 @@ class CaseConfigInput(BaseModel): "max": 150, "value": 100, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.STREAMING_DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value, ) @@ -321,8 +309,7 @@ class 
CaseConfigInput(BaseModel): "max": 150, "value": 50, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.STREAMING_DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.STREAMING_DISKANN.value, ) CaseConfigParamInput_IndexType_PgVector = CaseConfigInput( @@ -358,8 +345,7 @@ class CaseConfigInput(BaseModel): "max": 64, "value": 30, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) CaseConfigParamInput_m = CaseConfigInput( @@ -370,8 +356,7 @@ class CaseConfigInput(BaseModel): "max": 64, "value": 16, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) @@ -383,8 +368,7 @@ class CaseConfigInput(BaseModel): "max": 512, "value": 360, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] - == IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, ) CaseConfigParamInput_EFConstruction_Weaviate = CaseConfigInput( @@ -480,8 +464,7 @@ class CaseConfigInput(BaseModel): "max": 2000, "value": 300, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) CaseConfigParamInput_EFSearch_PgVectoRS = CaseConfigInput( @@ -492,8 +475,7 @@ class CaseConfigInput(BaseModel): "max": 65535, "value": 100, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) CaseConfigParamInput_EFConstruction_PgVector = CaseConfigInput( @@ -504,8 +486,7 @@ class CaseConfigInput(BaseModel): "max": 1024, "value": 256, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] - == IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, ) @@ -537,8 +518,7 @@ class CaseConfigInput(BaseModel): "max": MAX_STREAMLIT_INT, "value": 100, }, - isDisplayed=lambda config: config[CaseConfigParamType.IndexType] - == IndexType.HNSW.value, + isDisplayed=lambda config: config[CaseConfigParamType.IndexType] == IndexType.HNSW.value, ) CaseConfigParamInput_EF_Weaviate = CaseConfigInput( @@ -565,8 +545,7 @@ class CaseConfigInput(BaseModel): "max": MAX_STREAMLIT_INT, "value": 100, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.DISKANN.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.DISKANN.value, ) CaseConfigParamInput_Nlist = CaseConfigInput( @@ -611,8 +590,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 0, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_IVF_PQ.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_IVF_PQ.value], ) @@ -624,8 +602,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 8, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_IVF_PQ.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, 
None) in [IndexType.GPU_IVF_PQ.value], ) CaseConfigParamInput_intermediate_graph_degree = CaseConfigInput( @@ -636,8 +613,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 64, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_graph_degree = CaseConfigInput( @@ -648,8 +624,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 32, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_itopk_size = CaseConfigInput( @@ -660,8 +635,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 128, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_team_size = CaseConfigInput( @@ -672,8 +646,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 0, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_search_width = CaseConfigInput( @@ -684,8 +657,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 4, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_min_iterations = CaseConfigInput( @@ -696,8 +668,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 0, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_max_iterations = CaseConfigInput( @@ -708,8 +679,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 0, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) CaseConfigParamInput_build_algo = CaseConfigInput( @@ -718,8 +688,7 @@ class CaseConfigInput(BaseModel): inputConfig={ "options": ["IVF_PQ", "NN_DESCENT"], }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.GPU_CAGRA.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.GPU_CAGRA.value], ) @@ -762,8 +731,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 10, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - in [IndexType.IVFFlat.value], + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) in [IndexType.IVFFlat.value], ) CaseConfigParamInput_Probes = CaseConfigInput( @@ -784,8 +752,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 10, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.IVFFlat.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, 
None) == IndexType.IVFFlat.value, ) CaseConfigParamInput_Probes_PgVector = CaseConfigInput( @@ -796,8 +763,7 @@ class CaseConfigInput(BaseModel): "max": 65536, "value": 1, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.IVFFlat.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.IVFFlat.value, ) CaseConfigParamInput_EFSearch_PgVector = CaseConfigInput( @@ -808,8 +774,7 @@ class CaseConfigInput(BaseModel): "max": 2048, "value": 256, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) - == IndexType.HNSW.value, + isDisplayed=lambda config: config.get(CaseConfigParamType.IndexType, None) == IndexType.HNSW.value, ) @@ -845,8 +810,7 @@ class CaseConfigInput(BaseModel): inputConfig={ "options": ["x4", "x8", "x16", "x32", "x64"], }, - isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) - == "product" + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "product" and config.get(CaseConfigParamType.IndexType, None) in [ IndexType.HNSW.value, @@ -885,8 +849,7 @@ class CaseConfigInput(BaseModel): inputConfig={ "value": False, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) - == "bit" + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "bit", ) CaseConfigParamInput_quantized_fetch_limit_PgVector = CaseConfigInput( @@ -899,8 +862,8 @@ class CaseConfigInput(BaseModel): "max": 1000, "value": 200, }, - isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) - == "bit" and config.get(CaseConfigParamType.reranking, False) + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "bit" + and config.get(CaseConfigParamType.reranking, False), ) @@ -908,12 +871,10 @@ class CaseConfigInput(BaseModel): label=CaseConfigParamType.rerankingMetric, inputType=InputType.Option, inputConfig={ - "options": [ - metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"] - ], + "options": [metric.value for metric in MetricType if metric.value not in ["HAMMING", "JACCARD"]], }, - isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) - == "bit" and config.get(CaseConfigParamType.reranking, False) + isDisplayed=lambda config: config.get(CaseConfigParamType.quantizationType, None) == "bit" + and config.get(CaseConfigParamType.reranking, False), ) @@ -1131,7 +1092,10 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_NumCandidates_ES, ] -AWSOpensearchLoadingConfig = [CaseConfigParamInput_EFConstruction_AWSOpensearch, CaseConfigParamInput_M_AWSOpensearch] +AWSOpensearchLoadingConfig = [ + CaseConfigParamInput_EFConstruction_AWSOpensearch, + CaseConfigParamInput_M_AWSOpensearch, +] AWSOpenSearchPerformanceConfig = [ CaseConfigParamInput_EFConstruction_AWSOpensearch, CaseConfigParamInput_M_AWSOpensearch, @@ -1250,7 +1214,10 @@ class CaseConfigInput(BaseModel): CaseConfigParamInput_max_parallel_workers_AlloyDB, ] -AliyunElasticsearchLoadingConfig = [CaseConfigParamInput_EFConstruction_AliES, CaseConfigParamInput_M_AliES] +AliyunElasticsearchLoadingConfig = [ + CaseConfigParamInput_EFConstruction_AliES, + CaseConfigParamInput_M_AliES, +] AliyunElasticsearchPerformanceConfig = [ CaseConfigParamInput_EFConstruction_AliES, CaseConfigParamInput_M_AliES, diff --git a/vectordb_bench/frontend/pages/concurrent.py b/vectordb_bench/frontend/pages/concurrent.py index 
941675436..045239494 100644 --- a/vectordb_bench/frontend/pages/concurrent.py +++ b/vectordb_bench/frontend/pages/concurrent.py @@ -9,7 +9,7 @@ from vectordb_bench.frontend.components.concurrent.charts import drawChartsByCase from vectordb_bench.frontend.components.get_results.saveAsImage import getResults from vectordb_bench.frontend.config.styles import FAVICON -from vectordb_bench.interface import benchMarkRunner +from vectordb_bench.interface import benchmark_runner from vectordb_bench.models import TestResult @@ -25,7 +25,7 @@ def main(): # header drawHeaderIcon(st) - allResults = benchMarkRunner.get_results() + allResults = benchmark_runner.get_results() def check_conc_data(res: TestResult): case_results = res.results @@ -57,9 +57,7 @@ def check_conc_data(res: TestResult): # main latency_type = st.radio("Latency Type", options=["latency_p99", "latency_avg"]) - drawChartsByCase( - shownData, showCaseNames, st.container(), latency_type=latency_type - ) + drawChartsByCase(shownData, showCaseNames, st.container(), latency_type=latency_type) # footer footer(st.container()) diff --git a/vectordb_bench/frontend/pages/custom.py b/vectordb_bench/frontend/pages/custom.py index 28c249f78..4f6beed91 100644 --- a/vectordb_bench/frontend/pages/custom.py +++ b/vectordb_bench/frontend/pages/custom.py @@ -1,13 +1,21 @@ +from functools import partial import streamlit as st from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon -from vectordb_bench.frontend.components.custom.displayCustomCase import displayCustomCase +from vectordb_bench.frontend.components.custom.displayCustomCase import ( + displayCustomCase, +) from vectordb_bench.frontend.components.custom.displaypPrams import displayParams -from vectordb_bench.frontend.components.custom.getCustomConfig import CustomCaseConfig, generate_custom_case, get_custom_configs, save_custom_configs +from vectordb_bench.frontend.components.custom.getCustomConfig import ( + CustomCaseConfig, + generate_custom_case, + get_custom_configs, + save_custom_configs, +) from vectordb_bench.frontend.components.custom.initStyle import initStyle from vectordb_bench.frontend.config.styles import FAVICON, PAGE_TITLE -class CustomCaseManager(): +class CustomCaseManager: customCaseItems: list[CustomCaseConfig] def __init__(self): @@ -52,12 +60,25 @@ def main(): columns = expander.columns(8) columns[0].button( - "Save", key=f"{key}_", type="secondary", on_click=lambda: customCaseManager.save()) - columns[1].button(":red[Delete]", key=f"{key}_delete", type="secondary", - on_click=lambda: customCaseManager.deleteCase(idx)) - - st.button("\+ New Dataset", key=f"add_custom_configs", - type="primary", on_click=lambda: customCaseManager.addCase()) + "Save", + key=f"{key}_", + type="secondary", + on_click=lambda: customCaseManager.save(), + ) + columns[1].button( + ":red[Delete]", + key=f"{key}_delete", + type="secondary", + # B023 + on_click=partial(lambda idx: customCaseManager.deleteCase(idx), idx=idx), + ) + + st.button( + "\+ New Dataset", + key="add_custom_configs", + type="primary", + on_click=lambda: customCaseManager.addCase(), + ) if __name__ == "__main__": diff --git a/vectordb_bench/frontend/pages/quries_per_dollar.py b/vectordb_bench/frontend/pages/quries_per_dollar.py index 4a45181de..2822f2864 100644 --- a/vectordb_bench/frontend/pages/quries_per_dollar.py +++ b/vectordb_bench/frontend/pages/quries_per_dollar.py @@ -15,8 +15,8 @@ from vectordb_bench.frontend.components.check_results.charts import drawMetricChart from 
vectordb_bench.frontend.components.check_results.filters import getshownData from vectordb_bench.frontend.components.get_results.saveAsImage import getResults -from vectordb_bench.frontend.config.styles import * -from vectordb_bench.interface import benchMarkRunner + +from vectordb_bench.interface import benchmark_runner from vectordb_bench.metric import QURIES_PER_DOLLAR_METRIC @@ -27,7 +27,7 @@ def main(): # header drawHeaderIcon(st) - allResults = benchMarkRunner.get_results() + allResults = benchmark_runner.get_results() st.title("Vector DB Benchmark (QP$)") diff --git a/vectordb_bench/frontend/pages/run_test.py b/vectordb_bench/frontend/pages/run_test.py index 1297743ae..3da8ea2c0 100644 --- a/vectordb_bench/frontend/pages/run_test.py +++ b/vectordb_bench/frontend/pages/run_test.py @@ -15,10 +15,10 @@ def main(): # set page config initRunTestPageConfig(st) - + # init style initStyle(st) - + # header drawHeaderIcon(st) @@ -48,11 +48,7 @@ def main(): activedCaseList, allCaseConfigs = caseSelector(caseSelectorContainer, activedDbList) # generate tasks - tasks = ( - generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs) - if isAllValid - else [] - ) + tasks = generate_tasks(activedDbList, dbConfigs, activedCaseList, allCaseConfigs) if isAllValid else [] # submit submitContainer = st.container() diff --git a/vectordb_bench/frontend/utils.py b/vectordb_bench/frontend/utils.py index 787b67d03..407dd497d 100644 --- a/vectordb_bench/frontend/utils.py +++ b/vectordb_bench/frontend/utils.py @@ -18,5 +18,5 @@ def addHorizontalLine(st): def generate_random_string(length): letters = string.ascii_letters + string.digits - result = ''.join(random.choice(letters) for _ in range(length)) + result = "".join(random.choice(letters) for _ in range(length)) return result diff --git a/vectordb_bench/frontend/vdb_benchmark.py b/vectordb_bench/frontend/vdb_benchmark.py index c76a2f3de..cf261bf1d 100644 --- a/vectordb_bench/frontend/vdb_benchmark.py +++ b/vectordb_bench/frontend/vdb_benchmark.py @@ -11,8 +11,8 @@ from vectordb_bench.frontend.components.check_results.charts import drawCharts from vectordb_bench.frontend.components.check_results.filters import getshownData from vectordb_bench.frontend.components.get_results.saveAsImage import getResults -from vectordb_bench.frontend.config.styles import * -from vectordb_bench.interface import benchMarkRunner + +from vectordb_bench.interface import benchmark_runner def main(): @@ -22,7 +22,7 @@ def main(): # header drawHeaderIcon(st) - allResults = benchMarkRunner.get_results() + allResults = benchmark_runner.get_results() st.title("Vector Database Benchmark") st.caption( @@ -32,9 +32,7 @@ def main(): # results selector and filter resultSelectorContainer = st.sidebar.container() - shownData, failedTasks, showCaseNames = getshownData( - allResults, resultSelectorContainer - ) + shownData, failedTasks, showCaseNames = getshownData(allResults, resultSelectorContainer) resultSelectorContainer.divider() diff --git a/vectordb_bench/interface.py b/vectordb_bench/interface.py index c765d0d63..615a9600d 100644 --- a/vectordb_bench/interface.py +++ b/vectordb_bench/interface.py @@ -5,6 +5,7 @@ import signal import traceback import uuid +from collections.abc import Callable from enum import Enum from multiprocessing.connection import Connection @@ -16,8 +17,15 @@ from .backend.result_collector import ResultCollector from .backend.task_runner import TaskRunner from .metric import Metric -from .models import (CaseResult, LoadTimeoutError, 
PerformanceTimeoutError, - ResultLabel, TaskConfig, TaskStage, TestResult) +from .models import ( + CaseResult, + LoadTimeoutError, + PerformanceTimeoutError, + ResultLabel, + TaskConfig, + TaskStage, + TestResult, +) log = logging.getLogger(__name__) @@ -37,11 +45,9 @@ def __init__(self): self.drop_old: bool = True self.dataset_source: DatasetSource = DatasetSource.S3 - def set_drop_old(self, drop_old: bool): self.drop_old = drop_old - def set_download_address(self, use_aliyun: bool): if use_aliyun: self.dataset_source = DatasetSource.AliyunOSS @@ -59,7 +65,9 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool: log.warning("Empty tasks submitted") return False - log.debug(f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}") + log.debug( + f"tasks: {tasks}, task_label: {task_label}, dataset source: {self.dataset_source}", + ) # Generate run_id run_id = uuid.uuid4().hex @@ -70,7 +78,12 @@ def run(self, tasks: list[TaskConfig], task_label: str | None = None) -> bool: self.latest_error = "" try: - self.running_task = Assembler.assemble_all(run_id, task_label, tasks, self.dataset_source) + self.running_task = Assembler.assemble_all( + run_id, + task_label, + tasks, + self.dataset_source, + ) self.running_task.display() except ModuleNotFoundError as e: msg = f"Please install client for database, error={e}" @@ -119,7 +132,7 @@ def get_tasks_count(self) -> int: return 0 def get_current_task_id(self) -> int: - """ the index of current running task + """the index of current running task return -1 if not running """ if not self.running_task: @@ -153,18 +166,18 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non task_config=runner.config, ) - # drop_old = False if latest_runner and runner == latest_runner else config.DROP_OLD - # drop_old = config.DROP_OLD drop_old = TaskStage.DROP_OLD in runner.config.stages - if latest_runner and runner == latest_runner: - drop_old = False - elif not self.drop_old: + if (latest_runner and runner == latest_runner) or not self.drop_old: drop_old = False try: - log.info(f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}") + log.info( + f"[{idx+1}/{running_task.num_cases()}] start case: {runner.display()}, drop_old={drop_old}", + ) case_res.metrics = runner.run(drop_old) - log.info(f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, " - f"result={case_res.metrics}, label={case_res.label}") + log.info( + f"[{idx+1}/{running_task.num_cases()}] finish case: {runner.display()}, " + f"result={case_res.metrics}, label={case_res.label}", + ) # cache the latest succeeded runner latest_runner = runner @@ -176,12 +189,16 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non if not drop_old: case_res.metrics.load_duration = cached_load_duration if cached_load_duration else 0.0 except (LoadTimeoutError, PerformanceTimeoutError) as e: - log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}") + log.warning( + f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}", + ) case_res.label = ResultLabel.OUTOFRANGE continue except Exception as e: - log.warning(f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}") + log.warning( + f"[{idx+1}/{running_task.num_cases()}] case {runner.display()} failed to run, reason={e}", + ) traceback.print_exc() case_res.label = ResultLabel.FAILED continue @@ 
-200,10 +217,14 @@ def _async_task_v2(self, running_task: TaskRunner, send_conn: Connection) -> Non send_conn.send((SIGNAL.SUCCESS, None)) send_conn.close() - log.info(f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}") + log.info( + f"Success to finish task: label={running_task.task_label}, run_id={running_task.run_id}", + ) except Exception as e: - err_msg = f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}" + err_msg = ( + f"An error occurs when running task={running_task.task_label}, run_id={running_task.run_id}, err={e}" + ) traceback.print_exc() log.warning(err_msg) send_conn.send((SIGNAL.ERROR, err_msg)) @@ -226,16 +247,26 @@ def _clear_running_task(self): self.receive_conn.close() self.receive_conn = None - def _run_async(self, conn: Connection) -> bool: - log.info(f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, case number: {len(self.running_task.case_runners)}") + log.info( + f"task submitted: id={self.running_task.run_id}, {self.running_task.task_label}, " + f"case number: {len(self.running_task.case_runners)}", + ) global global_result_future - executor = concurrent.futures.ProcessPoolExecutor(max_workers=1, mp_context=mp.get_context("spawn")) + executor = concurrent.futures.ProcessPoolExecutor( + max_workers=1, + mp_context=mp.get_context("spawn"), + ) global_result_future = executor.submit(self._async_task_v2, self.running_task, conn) return True - def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None): + def kill_proc_tree( + self, + sig: int = signal.SIGTERM, + timeout: float | None = None, + on_terminate: Callable | None = None, + ): """Kill a process tree (including grandchildren) with signal "sig" and return a (gone, still_alive) tuple.
"on_terminate", if specified, is a callback function which is @@ -248,12 +279,11 @@ def kill_proc_tree(self, sig=signal.SIGTERM, timeout=None, on_terminate=None): p.send_signal(sig) except psutil.NoSuchProcess: pass - gone, alive = psutil.wait_procs(children, timeout=timeout, - callback=on_terminate) + gone, alive = psutil.wait_procs(children, timeout=timeout, callback=on_terminate) for p in alive: log.warning(f"force killing child process: {p}") p.kill() -benchMarkRunner = BenchMarkRunner() +benchmark_runner = BenchMarkRunner() diff --git a/vectordb_bench/log_util.py b/vectordb_bench/log_util.py index b923bdcd2..d75688137 100644 --- a/vectordb_bench/log_util.py +++ b/vectordb_bench/log_util.py @@ -1,102 +1,97 @@ import logging from logging import config -def init(log_level): - LOGGING = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'default': { - 'format': '%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)', + +def init(log_level: str): + log_config = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(asctime)s | %(levelname)s |%(message)s (%(filename)s:%(lineno)s)", }, - 'colorful_console': { - 'format': '%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)', - '()': ColorfulFormatter, + "colorful_console": { + "format": "%(asctime)s | %(levelname)s: %(message)s (%(filename)s:%(lineno)s) (%(process)s)", + "()": ColorfulFormatter, }, }, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'formatter': 'colorful_console', + "handlers": { + "console": { + "class": "logging.StreamHandler", + "formatter": "colorful_console", }, - 'no_color_console': { - 'class': 'logging.StreamHandler', - 'formatter': 'default', + "no_color_console": { + "class": "logging.StreamHandler", + "formatter": "default", }, }, - 'loggers': { - 'vectordb_bench': { - 'handlers': ['console'], - 'level': log_level, - 'propagate': False + "loggers": { + "vectordb_bench": { + "handlers": ["console"], + "level": log_level, + "propagate": False, }, - 'no_color': { - 'handlers': ['no_color_console'], - 'level': log_level, - 'propagate': False + "no_color": { + "handlers": ["no_color_console"], + "level": log_level, + "propagate": False, }, }, - 'propagate': False, + "propagate": False, } - config.dictConfig(LOGGING) + config.dictConfig(log_config) -class colors: - HEADER= '\033[95m' - INFO= '\033[92m' - DEBUG= '\033[94m' - WARNING= '\033[93m' - ERROR= '\033[95m' - CRITICAL= '\033[91m' - ENDC= '\033[0m' +class colors: + HEADER = "\033[95m" + INFO = "\033[92m" + DEBUG = "\033[94m" + WARNING = "\033[93m" + ERROR = "\033[95m" + CRITICAL = "\033[91m" + ENDC = "\033[0m" COLORS = { - 'INFO': colors.INFO, - 'INFOM': colors.INFO, - 'DEBUG': colors.DEBUG, - 'DEBUGM': colors.DEBUG, - 'WARNING': colors.WARNING, - 'WARNINGM': colors.WARNING, - 'CRITICAL': colors.CRITICAL, - 'CRITICALM': colors.CRITICAL, - 'ERROR': colors.ERROR, - 'ERRORM': colors.ERROR, - 'ENDC': colors.ENDC, + "INFO": colors.INFO, + "INFOM": colors.INFO, + "DEBUG": colors.DEBUG, + "DEBUGM": colors.DEBUG, + "WARNING": colors.WARNING, + "WARNINGM": colors.WARNING, + "CRITICAL": colors.CRITICAL, + "CRITICALM": colors.CRITICAL, + "ERROR": colors.ERROR, + "ERRORM": colors.ERROR, + "ENDC": colors.ENDC, } class ColorFulFormatColMixin: - def format_col(self, message_str, level_name): - if level_name in COLORS.keys(): - message_str = COLORS[level_name] + message_str + COLORS['ENDC'] - return message_str - - def formatTime(self, record, 
datefmt=None): - ret = super().formatTime(record, datefmt) - return ret + def format_col(self, message: str, level_name: str): + if level_name in COLORS: + message = COLORS[level_name] + message + COLORS["ENDC"] + return message class ColorfulLogRecordProxy(logging.LogRecord): - def __init__(self, record): + def __init__(self, record: any): self._record = record - msg_level = record.levelname + 'M' + msg_level = record.levelname + "M" self.msg = f"{COLORS[msg_level]}{record.msg}{COLORS['ENDC']}" self.filename = record.filename - self.lineno = f'{record.lineno}' - self.process = f'{record.process}' + self.lineno = f"{record.lineno}" + self.process = f"{record.process}" self.levelname = f"{COLORS[record.levelname]}{record.levelname}{COLORS['ENDC']}" - def __getattr__(self, attr): + def __getattr__(self, attr: any): if attr not in self.__dict__: return getattr(self._record, attr) return getattr(self, attr) class ColorfulFormatter(ColorFulFormatColMixin, logging.Formatter): - def format(self, record): + def format(self, record: any): proxy = ColorfulLogRecordProxy(record) - message_str = super().format(proxy) - - return message_str + return super().format(proxy) diff --git a/vectordb_bench/metric.py b/vectordb_bench/metric.py index 9f083a5c6..e0b6cff0e 100644 --- a/vectordb_bench/metric.py +++ b/vectordb_bench/metric.py @@ -1,8 +1,7 @@ import logging -import numpy as np - from dataclasses import dataclass, field +import numpy as np log = logging.getLogger(__name__) @@ -33,19 +32,19 @@ class Metric: QPS_METRIC = "qps" RECALL_METRIC = "recall" -metricUnitMap = { +metric_unit_map = { LOAD_DURATION_METRIC: "s", SERIAL_LATENCY_P99_METRIC: "ms", MAX_LOAD_COUNT_METRIC: "K", QURIES_PER_DOLLAR_METRIC: "K", } -lowerIsBetterMetricList = [ +lower_is_better_metrics = [ LOAD_DURATION_METRIC, SERIAL_LATENCY_P99_METRIC, ] -metricOrder = [ +metric_order = [ QPS_METRIC, RECALL_METRIC, LOAD_DURATION_METRIC, @@ -55,7 +54,7 @@ class Metric: def isLowerIsBetterMetric(metric: str) -> bool: - return metric in lowerIsBetterMetricList + return metric in lower_is_better_metrics def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float: @@ -70,7 +69,7 @@ def calc_recall(count: int, ground_truth: list[int], got: list[int]) -> float: def get_ideal_dcg(k: int): ideal_dcg = 0 for i in range(k): - ideal_dcg += 1 / np.log2(i+2) + ideal_dcg += 1 / np.log2(i + 2) return ideal_dcg @@ -78,8 +77,8 @@ def get_ideal_dcg(k: int): def calc_ndcg(ground_truth: list[int], got: list[int], ideal_dcg: float) -> float: dcg = 0 ground_truth = list(ground_truth) - for id in set(got): - if id in ground_truth: - idx = ground_truth.index(id) - dcg += 1 / np.log2(idx+2) + for got_id in set(got): + if got_id in ground_truth: + idx = ground_truth.index(got_id) + dcg += 1 / np.log2(idx + 2) return dcg / ideal_dcg diff --git a/vectordb_bench/models.py b/vectordb_bench/models.py index 648fb1727..49bb04ae0 100644 --- a/vectordb_bench/models.py +++ b/vectordb_bench/models.py @@ -2,29 +2,31 @@ import pathlib from datetime import date, datetime from enum import Enum, StrEnum, auto -from typing import List, Self +from typing import Self import ujson +from . import config +from .backend.cases import CaseType from .backend.clients import ( DB, - DBConfig, DBCaseConfig, + DBConfig, ) -from .backend.cases import CaseType from .base import BaseModel -from . 
import config from .metric import Metric log = logging.getLogger(__name__) class LoadTimeoutError(TimeoutError): - pass + def __init__(self, duration: int): + super().__init__(f"capacity case load timeout in {duration}s") class PerformanceTimeoutError(TimeoutError): - pass + def __init__(self): + super().__init__("Performance case optimize timeout") class CaseConfigParamType(Enum): @@ -92,7 +94,7 @@ class CustomizedCase(BaseModel): class ConcurrencySearchConfig(BaseModel): - num_concurrency: List[int] = config.NUM_CONCURRENCY + num_concurrency: list[int] = config.NUM_CONCURRENCY concurrency_duration: int = config.CONCURRENCY_DURATION @@ -146,7 +148,7 @@ class TaskConfig(BaseModel): db_config: DBConfig db_case_config: DBCaseConfig case_config: CaseConfig - stages: List[TaskStage] = ALL_TASK_STAGES + stages: list[TaskStage] = ALL_TASK_STAGES @property def db_name(self): @@ -210,26 +212,23 @@ def write_db_file(self, result_dir: pathlib.Path, partial: Self, db: str): log.info(f"local result directory not exist, creating it: {result_dir}") result_dir.mkdir(parents=True) - file_name = self.file_fmt.format( - date.today().strftime("%Y%m%d"), partial.task_label, db - ) + file_name = self.file_fmt.format(date.today().strftime("%Y%m%d"), partial.task_label, db) result_file = result_dir.joinpath(file_name) if result_file.exists(): - log.warning( - f"Replacing existing result with the same file_name: {result_file}" - ) + log.warning(f"Replacing existing result with the same file_name: {result_file}") log.info(f"write results to disk {result_file}") - with open(result_file, "w") as f: + with pathlib.Path(result_file).open("w") as f: b = partial.json(exclude={"db_config": {"password", "api_key"}}) f.write(b) @classmethod def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: if not full_path.exists(): - raise ValueError(f"No such file: {full_path}") + msg = f"No such file: {full_path}" + raise ValueError(msg) - with open(full_path) as f: + with pathlib.Path(full_path).open("r") as f: test_result = ujson.loads(f.read()) if "task_label" not in test_result: test_result["task_label"] = test_result["run_id"] @@ -248,19 +247,16 @@ def read_file(cls, full_path: pathlib.Path, trans_unit: bool = False) -> Self: if trans_unit: cur_max_count = case_result["metrics"]["max_load_count"] case_result["metrics"]["max_load_count"] = ( - cur_max_count / 1000 - if int(cur_max_count) > 0 - else cur_max_count + cur_max_count / 1000 if int(cur_max_count) > 0 else cur_max_count ) cur_latency = case_result["metrics"]["serial_latency_p99"] case_result["metrics"]["serial_latency_p99"] = ( cur_latency * 1000 if cur_latency > 0 else cur_latency ) - c = TestResult.validate(test_result) - - return c + return TestResult.validate(test_result) + # ruff: noqa def display(self, dbs: list[DB] | None = None): filter_list = dbs if dbs and isinstance(dbs, list) else None sorted_results = sorted( @@ -273,31 +269,18 @@ def display(self, dbs: list[DB] | None = None): reverse=True, ) - filtered_results = [ - r - for r in sorted_results - if not filter_list or r.task_config.db not in filter_list - ] + filtered_results = [r for r in sorted_results if not filter_list or r.task_config.db not in filter_list] - def append_return(x, y): + def append_return(x: any, y: any): x.append(y) return x max_db = max(map(len, [f.task_config.db.name for f in filtered_results])) - max_db_labels = ( - max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) - + 3 - ) - max_case = max( - map(len, 
[f.task_config.case_config.case_id.name for f in filtered_results]) - ) - max_load_dur = ( - max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3 - ) + max_db_labels = max(map(len, [f.task_config.db_config.db_label for f in filtered_results])) + 3 + max_case = max(map(len, [f.task_config.case_config.case_id.name for f in filtered_results])) + max_load_dur = max(map(len, [str(f.metrics.load_duration) for f in filtered_results])) + 3 max_qps = max(map(len, [str(f.metrics.qps) for f in filtered_results])) + 3 - max_recall = ( - max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3 - ) + max_recall = max(map(len, [str(f.metrics.recall) for f in filtered_results])) + 3 max_db_labels = 8 if max_db_labels < 8 else max_db_labels max_load_dur = 11 if max_load_dur < 11 else max_load_dur @@ -356,7 +339,7 @@ def append_return(x, y): f.metrics.recall, f.metrics.max_load_count, f.label.value, - ) + ), ) tmp_logger = logging.getLogger("no_color")