From 907dce5f85838fda262eef07e44869dbaff604f7 Mon Sep 17 00:00:00 2001
From: "min.tian"
Date: Wed, 16 Apr 2025 09:38:02 +0800
Subject: [PATCH] support RESTful API

Signed-off-by: min.tian
---
 pyproject.toml                       |   2 +
 vectordb_bench/restful/__init__.py   |   0
 vectordb_bench/restful/app.py        | 119 ++++++++++++++++++++++++++++
 vectordb_bench/restful/format_res.py |  82 ++++++++++++++++++
 4 files changed, 203 insertions(+)
 create mode 100644 vectordb_bench/restful/__init__.py
 create mode 100644 vectordb_bench/restful/app.py
 create mode 100644 vectordb_bench/restful/format_res.py

diff --git a/pyproject.toml b/pyproject.toml
index bdb7768c0..ef4792bee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,6 +51,7 @@ test = [
     "ruff",
     "pytest",
 ]
+restful = [ "flask" ]
 
 all = [
     "grpcio==1.53.0",  # for qdrant-client and pymilvus
@@ -111,6 +112,7 @@ turbopuffer = [ "turbopuffer" ]
 
 [project.scripts]
 init_bench = "vectordb_bench.__main__:main"
+init_bench_rest = "vectordb_bench.restful.app:main"
 vectordbbench = "vectordb_bench.cli.vectordbbench:cli"
 
 [tool.setuptools_scm]
diff --git a/vectordb_bench/restful/__init__.py b/vectordb_bench/restful/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/vectordb_bench/restful/app.py b/vectordb_bench/restful/app.py
new file mode 100644
index 000000000..ad0336501
--- /dev/null
+++ b/vectordb_bench/restful/app.py
@@ -0,0 +1,119 @@
+import logging
+from typing import Any
+
+from flask import Flask, jsonify, request
+
+from vectordb_bench.backend.clients import DB
+from vectordb_bench.interface import benchmark_runner
+from vectordb_bench.models import ALL_TASK_STAGES, CaseConfig, TaskConfig, TaskStage
+from vectordb_bench.restful.format_res import format_results
+
+log = logging.getLogger(__name__)
+
+app = Flask(__name__)
+
+
+def res_wrapper(code: int = 0, message: str = "", data: Any = None):
+    """Uniform envelope: HTTP status is always 200; `code` signals success (0) or failure (1)."""
+    return jsonify({"code": code, "message": message, "data": data}), 200
+
+
+def success_res(data: Any = None, message: str = "success"):
+    return res_wrapper(code=0, message=message, data=data)
+
+
+def failed_res(data: Any = None, message: str = "failed"):
+    return res_wrapper(code=1, message=message, data=data)
+
+
+@app.route("/get_res", methods=["GET"])
+def get_res():
+    """Return the formatted results stored under the given task label."""
+    task_label = request.args.get("task_label", "standard")
+    all_results = benchmark_runner.get_results()
+    res = format_results(all_results, task_label=task_label)
+
+    return success_res(res)
+
+
+@app.route("/get_status", methods=["GET"])
+def get_status():
+    """Report runner progress, e.g. running task 5/18, or not running."""
+    is_running = benchmark_runner.has_running()
+    if is_running:
+        tasks_count = benchmark_runner.get_tasks_count()
+        cur_task_idx = benchmark_runner.get_current_task_id()
+        return success_res(
+            data={
+                "is_running": is_running,
+                "tasks_count": tasks_count,
+                "cur_task_idx": cur_task_idx,
+            }
+        )
+    return success_res(data={"is_running": is_running})
+
+
+@app.route("/stop", methods=["GET"])
+def stop():
+    benchmark_runner.stop_running()
+    return success_res(message="stopped")
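+
+
+# POST /run request body (illustrative sketch; the exact "db_config" and
+# "db_case_config" fields depend on the chosen database):
+#   task_label: str, groups results for later /get_res queries
+#   use_aliyun: bool, download datasets from the Aliyun mirror
+#   tasks: list of task objects:
+#     db:             DB enum value, e.g. "Milvus"
+#     db_config:      kwargs for db.config_cls, e.g. {"uri": ...}
+#     case_config:    kwargs for CaseConfig, e.g. {"case_id": 5, "k": 100}
+#     db_case_config: kwargs for db.case_config_cls; "index" selects the
+#                     index type, the remaining keys are index params
+#     stages:         optional list of TaskStage values, defaults to all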
+
+
+@app.route("/run", methods=["POST"])
+def run():
+    if benchmark_runner.has_running():
+        return failed_res(message="There are already running tasks.")
+    data = request.get_json(silent=True) or {}
+    task_label = data.get("task_label", "test")
+    use_aliyun = data.get("use_aliyun", False)
+    task_configs: list[TaskConfig] = []
+    try:
+        tasks = data.get("tasks", [])
+        if len(tasks) == 0:
+            return failed_res(message="empty tasks")
+        for task in tasks:
+            db = DB(task["db"])
+            db_config = db.config_cls(**task["db_config"])
+            case_config = CaseConfig(**task["case_config"])
+            log.info("case_config: %s", case_config)
+            db_case_config = db.case_config_cls(index_type=task["db_case_config"].get("index"))(
+                **task["db_case_config"]
+            )
+            stages = [TaskStage(stage) for stage in task.get("stages", ALL_TASK_STAGES)]
+            log.info("stages: %s", stages)
+            task_config = TaskConfig(
+                db=db,
+                db_config=db_config,
+                case_config=case_config,
+                db_case_config=db_case_config,
+                stages=stages,
+            )
+            task_configs.append(task_config)
+    except Exception as e:
+        return failed_res(message=f"invalid tasks: {e}")
+
+    benchmark_runner.set_download_address(use_aliyun)
+    benchmark_runner.run(task_configs, task_label)
+
+    return success_res(message="start")
+
+
+def main():
+    app.run(host="0.0.0.0", port=5000, debug=False)  # noqa: S104
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vectordb_bench/restful/format_res.py b/vectordb_bench/restful/format_res.py
new file mode 100644
index 000000000..2e289ec3b
--- /dev/null
+++ b/vectordb_bench/restful/format_res.py
@@ -0,0 +1,82 @@
+from dataclasses import asdict
+
+from pydantic import BaseModel
+
+from vectordb_bench.backend.cases import CaseLabel
+from vectordb_bench.models import TestResult
+
+
+class FormatResult(BaseModel):
+    """Flat, JSON-serializable view of a single case result."""
+
+    # db_config
+    task_label: str = ""
+    timestamp: int = 0
+    db: str = ""
+    db_label: str = ""  # e.g. "perf-x86"
+    version: str = ""
+    note: str = ""
+
+    # params
+    params: dict = {}
+
+    # case_config
+    case_name: str = ""
+    dataset: str = ""
+    dim: int = 0
+    filter_type: str = ""  # filter type name (Enum.name)
+    filter_rate: float = 0
+    k: int = 100
+
+    # metrics
+    max_load_count: int = 0
+    load_duration: int = 0
+    qps: float = 0
+    serial_latency_p99: float = 0
+    recall: float = 0
+    ndcg: float = 0
+    conc_num_list: list[int] = []
+    conc_qps_list: list[float] = []
+    conc_latency_p99_list: list[float] = []
+    conc_latency_avg_list: list[float] = []
+
+
+def format_results(test_results: list[TestResult], task_label: str) -> list[dict]:
+    """Flatten all finished results under task_label into JSON-ready dicts."""
+    results = []
+    for test_result in test_results:
+        if test_result.task_label != task_label:
+            continue
+        for case_result in test_result.results:
+            task_config = case_result.task_config
+            case_config = task_config.case_config
+            case = case_config.case
+            # Load-only cases carry no search metrics; skip them.
+            if case.label == CaseLabel.Load:
+                continue
+            dataset = case.dataset.data
+            filter_ = case.filters
+            metrics = asdict(case_result.metrics)
+            # Round float lists (e.g. per-concurrency latencies) to 6 digits.
+            for key, value in metrics.items():
+                if isinstance(value, list) and len(value) > 0:
+                    metrics[key] = [round(d, 6) if isinstance(d, float) else d for d in value]
+            results.append(
+                FormatResult(
+                    task_label=test_result.task_label,
+                    timestamp=int(test_result.timestamp),
+                    db=task_config.db.value,
+                    db_label=task_config.db_config.db_label,
+                    version=task_config.db_config.version,
+                    note=task_config.db_config.note,
+                    params=task_config.db_case_config.dict(),
+                    case_name=case.name,
+                    dataset=dataset.full_name,
+                    dim=dataset.dim,
+                    filter_type=filter_.type.name,
+                    filter_rate=filter_.filter_rate,
+                    k=task_config.case_config.k,
+                    **metrics,
+                ).dict()
+            )
+    return results
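
Usage sketch (smoke test): the snippet below exercises the three main
endpoints against a locally running server. It assumes the server was started
via the new `init_bench_rest` entry point on localhost:5000, that `requests`
is installed, and illustrative payload values ("Milvus", case_id 5, HNSW
params) that must be adapted to the target setup.

    import requests

    BASE = "http://localhost:5000"

    # Submit one task. /run answers {"code": 0, ...} on acceptance and
    # {"code": 1, ...} if tasks are already running or the body is invalid.
    payload = {
        "task_label": "test",
        "use_aliyun": False,
        "tasks": [
            {
                "db": "Milvus",
                "db_config": {"uri": "http://localhost:19530"},
                "case_config": {"case_id": 5, "k": 100},
                "db_case_config": {"index": "HNSW", "M": 30, "efConstruction": 360},
                "stages": ["drop_old", "load", "search_serial", "search_concurrent"],
            }
        ],
    }
    print(requests.post(f"{BASE}/run", json=payload, timeout=10).json())

    # Poll progress: includes tasks_count and cur_task_idx while running.
    print(requests.get(f"{BASE}/get_status", timeout=10).json())

    # Once finished, fetch the formatted results stored under the label.
    print(requests.get(f"{BASE}/get_res", params={"task_label": "test"}, timeout=10).json())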