2 changes: 2 additions & 0 deletions pyproject.toml
@@ -51,6 +51,7 @@ test = [
"ruff",
"pytest",
]
restful = [ "flask" ]

all = [
"grpcio==1.53.0", # for qdrant-client and pymilvus
@@ -111,6 +112,7 @@ turbopuffer = [ "turbopuffer" ]

[project.scripts]
init_bench = "vectordb_bench.__main__:main"
init_bench_rest = "vectordb_bench.restful.app:main"
vectordbbench = "vectordb_bench.cli.vectordbbench:cli"

[tool.setuptools_scm]
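
The new restful optional dependency group pulls in Flask, and the init_bench_rest console script points at vectordb_bench.restful.app:main. As a rough sketch (assuming the extra is installed, e.g. pip install "vectordb-bench[restful]"), starting the service from Python is equivalent to running that script:

from vectordb_bench.restful.app import main

main()  # serves the Flask app defined in app.py below on 0.0.0.0:5000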
Empty file.
101 changes: 101 additions & 0 deletions vectordb_bench/restful/app.py
@@ -0,0 +1,101 @@
from typing import Any

from flask import Flask, jsonify, request

from vectordb_bench.backend.clients import DB
from vectordb_bench.interface import benchmark_runner
from vectordb_bench.models import ALL_TASK_STAGES, CaseConfig, TaskConfig, TaskStage
from vectordb_bench.restful.format_res import format_results

app = Flask(__name__)


def res_wrapper(code: int = 0, message: str = "", data: Any | None = None):
return jsonify({"code": code, "message": message, "data": data}), 200


def success_res(data: Any | None = None, message: str = "success"):
return res_wrapper(code=0, message=message, data=data)


def failed_res(data: Any | None = None, message: str = "failed"):
return res_wrapper(code=1, message=message, data=data)
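
# Every endpoint responds with HTTP 200 and a JSON envelope of the form
# {"code": 0 | 1, "message": ..., "data": ...}; code 0 means success, 1 means failure.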


@app.route("/get_res", methods=["GET"])
def get_res():
"""task label -> res"""
task_label = request.args.get("task_label", "standard")
all_results = benchmark_runner.get_results()
res = format_results(all_results, task_label=task_label)

return success_res(res)


@app.route("/get_status", methods=["GET"])
def get_status():
"running 5/18, not running"
    is_running = benchmark_runner.has_running()
    if is_running:
        tasks_count = benchmark_runner.get_tasks_count()
        cur_task_idx = benchmark_runner.get_current_task_id()
return success_res(
data={
"is_running": is_running,
"tasks_count": tasks_count,
"cur_task_idx": cur_task_idx,
}
)
return success_res(data={"is_running": is_running})


@app.route("/stop", methods=["GET"])
def stop():
benchmark_runner.stop_running()
return success_res(message="stopped")


@app.route("/run", methods=["post"])
def run():
if benchmark_runner.has_running():
return failed_res(message="There are already running tasks.")
data = request.get_json()
task_label = data.get("task_label", "test")
use_aliyun = data.get("use_aliyun", False)
task_configs: list[TaskConfig] = []
try:
tasks = data.get("tasks", [])
if len(tasks) == 0:
return failed_res(message="empty tasks")
for task in tasks:
db = DB(task["db"])
db_config = db.config_cls(**task["db_config"])
case_config = CaseConfig(**task["case_config"])
print(case_config) # noqa: T201
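            # db.case_config_cls(index_type=...) resolves the index-specific
            # DBCaseConfig class, which is then instantiated from the raw request dict.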
db_case_config = db.case_config_cls(index_type=task["db_case_config"].get("index", None))(
**task["db_case_config"]
)
stages = [TaskStage(stage) for stage in task.get("stages", ALL_TASK_STAGES)]
print(stages) # noqa: T201
task_config = TaskConfig(
db=db,
db_config=db_config,
case_config=case_config,
db_case_config=db_case_config,
stages=stages,
)
task_configs.append(task_config)
except Exception as e:
return failed_res(message=f"invalid tasks: {e}")

benchmark_runner.set_download_address(use_aliyun)
benchmark_runner.run(task_configs, task_label)

return success_res(message="start")


def main():
app.run(host="0.0.0.0", port=5000, debug=False) # noqa: S104


if __name__ == "__main__":
main()
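
A hedged client sketch against the endpoints above, assuming the service runs on localhost:5000 and the requests package is available; the payload values are illustrative placeholders, and the accepted names come from vectordb_bench's DB, db_config, CaseConfig, db_case_config and TaskStage definitions:

import time

import requests

BASE = "http://localhost:5000"

# /run expects task_label, use_aliyun and a tasks list; each task carries db,
# db_config, case_config, db_case_config (with an "index" key) and, optionally,
# stages (defaults to ALL_TASK_STAGES).
payload = {
    "task_label": "test",
    "use_aliyun": False,
    "tasks": [
        {
            "db": "Milvus",  # placeholder: any valid DB enum value
            "db_config": {"uri": "http://localhost:19530"},  # placeholder fields
            "case_config": {"case_id": 1},  # placeholder CaseConfig fields
            "db_case_config": {"index": "HNSW"},  # placeholder index config
        }
    ],
}

print(requests.post(f"{BASE}/run", json=payload).json())  # code 0 / "start" on success

# Poll /get_status until the runner is idle, then pull formatted results.
# requests.get(f"{BASE}/stop") would abort a running benchmark instead.
while requests.get(f"{BASE}/get_status").json()["data"]["is_running"]:
    time.sleep(10)
results = requests.get(f"{BASE}/get_res", params={"task_label": "test"}).json()["data"]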
76 changes: 76 additions & 0 deletions vectordb_bench/restful/format_res.py
@@ -0,0 +1,76 @@
from dataclasses import asdict

from pydantic import BaseModel

from vectordb_bench.backend.cases import CaseLabel
from vectordb_bench.models import TestResult


class FormatResult(BaseModel):
# db_config
task_label: str = ""
timestamp: int = 0
db: str = ""
db_label: str = "" # perf-x86
version: str = ""
note: str = ""

# params
params: dict = {}

# case_config
case_name: str = ""
dataset: str = ""
dim: int = 0
filter_type: str = "" # FilterType(Enum).value
filter_rate: float = 0
k: int = 100

# metrics
max_load_count: int = 0
load_duration: int = 0
qps: float = 0
serial_latency_p99: float = 0
recall: float = 0
ndcg: float = 0
conc_num_list: list[int] = []
conc_qps_list: list[float] = []
conc_latency_p99_list: list[float] = []
conc_latency_avg_list: list[float] = []


def format_results(test_results: list[TestResult], task_label: str) -> list[dict]:
results = []
for test_result in test_results:
if test_result.task_label == task_label:
for case_result in test_result.results:
task_config = case_result.task_config
case_config = task_config.case_config
case = case_config.case
if case.label == CaseLabel.Load:
continue
dataset = case.dataset.data
filter_ = case.filters
metrics = asdict(case_result.metrics)
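                # Round floats in list-valued metrics (e.g. the concurrent qps/latency
                # series) to 6 decimal places before returning them.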
for k, v in metrics.items():
if isinstance(v, list) and len(v) > 0:
metrics[k] = [round(d, 6) if isinstance(d, float) else d for d in v]
results.append(
FormatResult(
task_label=test_result.task_label,
timestamp=int(test_result.timestamp),
db=task_config.db.value,
db_label=task_config.db_config.db_label,
version=task_config.db_config.version,
note=task_config.db_config.note,
params=task_config.db_case_config.dict(),
case_name=case.name,
dataset=dataset.full_name,
dim=dataset.dim,
filter_type=filter_.type.name,
filter_rate=filter_.filter_rate,
k=task_config.case_config.k,
**metrics,
).dict()
)
return results
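
Each element of the list returned by /get_res is the dict form of a FormatResult. A small sketch of the flattened shape (values made up for illustration; unset fields keep the defaults declared above):

from vectordb_bench.restful.format_res import FormatResult

# All fields have defaults, so a partial construction is enough to show the shape.
row = FormatResult(
    task_label="standard",
    db="Milvus",  # illustrative
    db_label="perf-x86",
    case_name="example-case",  # illustrative
    dim=768,
    k=100,
    qps=1234.5,
    recall=0.98,
    conc_num_list=[1, 5, 10],
    conc_qps_list=[300.0, 900.0, 1200.0],
)
print(row.dict())  # one flat dict per (db, case) pair, ready for JSON serialization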