diff --git a/test/common/uc_eval/task.py b/test/common/uc_eval/task.py new file mode 100644 index 000000000..760235283 --- /dev/null +++ b/test/common/uc_eval/task.py @@ -0,0 +1,157 @@ +import time +from abc import ABC, abstractmethod +from typing import Any + +from common.uc_eval.utils.config_loader import ConfigLoader, TaskFactory +from common.uc_eval.utils.data_class import ( + BenchmarkModeType, + EvalConfig, + ModelConfig, + PerfConfig, + SynthericParams, +) +from common.uc_eval.utils.utils import get_logger + +BAD_COMPLETION_TOKENS_THR = 20 +logger = get_logger() + + +class BaseTask(ABC): + def __init__( + self, + model_config: ModelConfig, + perf_config: PerfConfig = None, + eval_config: EvalConfig = None, + ): + ConfigLoader(model_config, perf_config, eval_config) + self.model_config = model_config + self.perf_config = perf_config + self.eval_config = eval_config + + self.dataset, self.client, self.benchmark = TaskFactory.create_task( + model_config, perf_config, eval_config + ) + + @abstractmethod + def process(self) -> Any: + raise NotImplementedError + + +class SyntheticPerfTask(BaseTask): + def __init__(self, model_config: ModelConfig, perf_config: PerfConfig): + super().__init__(model_config, perf_config) + self.enable_clear_hbm = model_config.enable_clear_hbm + self.enable_prefix_cache = perf_config.enable_prefix_cache + self.parallel_num = perf_config.parallel_num + self.prompt_tokens = perf_config.prompt_tokens + self.output_tokens = perf_config.output_tokens + self.prefix_cache_num = perf_config.prefix_cache_num + self.benchmark_mode = perf_config.benchmark_mode + self.stable_perf = perf_config.benchmark_mode == BenchmarkModeType.STABLE_PREF + self.prompt_seed = 0 if self.enable_prefix_cache else -1 + + def process(self): + logger.info( + "-------------------------------------------------------------------" + ) + logger.info( + f"Starting synthetic performance benchmark, the benchmark mode is {self.benchmark_mode}" + ) + result = [] + for parallel_num in self.parallel_num: + for idx in range(len(self.prompt_tokens)): + syntheric_params = SynthericParams() + syntheric_params.parallel_num = parallel_num + if self.stable_perf: + syntheric_params.parallel_num *= 5 + if self.enable_prefix_cache: + syntheric_params.seeds = [ + self.prompt_seed + i + for i in range(syntheric_params.parallel_num) + ] + self.prompt_seed += syntheric_params.parallel_num + else: + syntheric_params.seeds = [ + self.prompt_seed + ] * syntheric_params.parallel_num + syntheric_params.prompt_tokens = self.prompt_tokens[idx] + syntheric_params.prefix_cache_tokens = ( + int(self.prefix_cache_num[idx] * syntheric_params.prompt_tokens) + if self.enable_prefix_cache + else 0 + ) + logger.info( + f"Performance benchmark running with: enable prefix cache: ({self.enable_prefix_cache}), {syntheric_params=}" + ) + if self.enable_prefix_cache and self.prefix_cache_num[idx] > 0: + logger.info(f"Begin build kvcache...") + input_data = self.dataset.prepare_data(syntheric_params) + self.client.handle_requests_with_pool( + input_data, parallel_num, BAD_COMPLETION_TOKENS_THR + ) + logger.info( + "To ensure thal all kvcache is offload2ssd, sleep for 10 seconds" + ) + time.sleep(10) + + if self.enable_clear_hbm: + self.client.clear_hbm() + + logger.info(f"Begin post cases...") + input_data = self.dataset.prepare_data(syntheric_params) + request_records = self.client.handle_requests_with_pool( + input_data, parallel_num, self.output_tokens[idx] + ) + latency_statistics = self.benchmark.perf_show( + request_records, 
parallel_num + ) + result.append(latency_statistics) + return result + + +class MultiPerfTask(BaseTask): + def __init__(self, model_config: ModelConfig, perf_config: PerfConfig): + super().__init__(model_config, perf_config) + self.data_type = perf_config.data_type + self.dataset_file_path = perf_config.dataset_file_path + self.benchmark_mode = perf_config.benchmark_mode + self.parallel_num = perf_config.parallel_num + + def process(self): + logger.info( + f"Begin test, the data type: {self.data_type}, the benchmark mode: {self.benchmark_mode}" + ) + cases = self.dataset.prepare_data(self.dataset_file_path) + records = self.client.handle_requests_with_pool(cases, self.parallel_num) + all_records = [r for record in records for r in record] + latency_statistics = self.benchmark.perf_show(all_records, self.parallel_num) + return latency_statistics + + +class DocQaPerfTask(BaseTask): + def __init__(self, model_config: ModelConfig, perf_config: PerfConfig): + super().__init__(model_config, perf_config) + self.data_type = perf_config.data_type + self.dataset_file_path = perf_config.dataset_file_path + self.enable_prefix_cache = perf_config.enable_prefix_cache + self.parallel_num = perf_config.parallel_num + self.max_tokens = model_config.payload.get("max_tokens") + self.benchmark_mode = perf_config.benchmark_mode + + def process(self): + logger.info( + f"Begin test, the data type: {self.data_type}, the benchmark mode: {self.benchmark_mode}" + ) + cases_list = self.dataset.prepare_data(self.dataset_file_path) + if self.enable_prefix_cache: + logger.info("Begin build kvcache...") + self.client.handle_requests_with_pool( + cases_list, self.parallel_num, BAD_COMPLETION_TOKENS_THR + ) + + logger.info("Begin post cases...") + record = self.client.handle_requests_with_pool( + cases_list, self.parallel_num, self.max_tokens + ) + latency_statistics = self.benchmark.perf_show(record, self.parallel_num) + return latency_statistics diff --git a/test/common/uc_eval/utils/benchmark.py b/test/common/uc_eval/utils/benchmark.py new file mode 100644 index 000000000..a596b760e --- /dev/null +++ b/test/common/uc_eval/utils/benchmark.py @@ -0,0 +1,253 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import numpy as np +from common.uc_eval.utils.data_class import ( + EvalConfig, + LatencyStatistics, + MultiTurnDialogRecord, + RequestRecord, +) +from common.uc_eval.utils.utils import get_logger +from tqdm import tqdm + +logger = get_logger() +MS_SCALE = 1000 +# the max wave rate for stable perf +MAX_WAVE_RATE = 0.05 + + +class BenchmarkBase(ABC): + def __init__(self, eval_config: Optional[EvalConfig], stable_perf: bool = False): + self.stable_perf = stable_perf + + def get_success_request(self, data: List[RequestRecord | MultiTurnDialogRecord]): + """ + Get the successful request from the record + """ + success_request = [] + for request in data: + if request.is_success: + success_request.append(request) + if len(success_request) == 0: + logger.warning(f"No success request found, please check the result") + return success_request + + def result_to_column_dict( + self, data: List[RequestRecord | MultiTurnDialogRecord] + ) -> Dict[str, List[Any]]: + """ + format: list[dict] ---> dict[list] + """ + if not data: + return {} + keys = list(data[0].to_dict().keys()) + result = {key: [] for key in keys} + for item in data: + for key in keys: + result[key].append(item.to_dict()[key]) + return result + + @abstractmethod + def perf_show(self, records: Any, parallel_num: int = 1): + raise 
NotImplementedError + + +class EvaluatorBenchmark(BenchmarkBase): + def __init__(self, stable_perf: bool, eval_class: str): + self.stable_perf = stable_perf + self.metric_method = eval_class + + def perf_show(self, records: List[RequestRecord | MultiTurnDialogRecord]): + pass + + +class PerformanceBenchmark(BenchmarkBase): + def __init__(self, stable_perf: bool): + super().__init__(stable_perf) + self.stable_perf = stable_perf + self.stable_work_time = [0, 0] + + def perf_show( + self, + input_data_lists: List[RequestRecord | MultiTurnDialogRecord], + parallel_num: int, + ) -> LatencyStatistics: + logger.info(f"Begin calculate latency...") + success_request = self.get_success_request(input_data_lists) + request_record_dict = self.result_to_column_dict(success_request) + if self.stable_perf: + request_ids = self._get_stable_request_id(request_record_dict, parallel_num) + else: + request_ids = request_record_dict.get("request_id") + records = [ + record for record in input_data_lists if record.request_id in request_ids + ] + perf_result = self._get_performance_data(records) + return perf_result + + def _get_performance_data( + self, record_list: List[RequestRecord | MultiTurnDialogRecord] + ) -> LatencyStatistics: + """ + After all requests are completed, get the performance data + """ + if len(record_list) == 0: + logger.warning(f"there is no request_id in the record_list, please check") + latency = LatencyStatistics() + record_dict = self.result_to_column_dict(record_list) + + e2e_latency_all = ( + max(record_dict["end_time"]) - min(record_dict["start_time"]) + ) * MS_SCALE + latency.e2e_latency_all = round(e2e_latency_all, 2) + logger.debug("All request latencies: %.4f ms", e2e_latency_all) + + total_output_tokens = sum(record_dict["output_tokens"]) + output_token_throughput = total_output_tokens / e2e_latency_all * MS_SCALE + latency.output_token_throughput = round(output_token_throughput, 2) + logger.debug( + "Total output token throughput: %.4f tokens/s", output_token_throughput + ) + + throughputs = [] + for tokens, cost in zip(record_dict["output_tokens"], record_dict["req_cost"]): + if cost > 0: + throughputs.append(tokens / cost) + if throughputs: + token_throughput_per_request = np.mean(throughputs) + latency.token_throughput_per_request = round( + token_throughput_per_request, 2 + ) + logger.debug( + "Average per-request throughput: %.4f tokens/s", + token_throughput_per_request, + ) + else: + logger.warning("No valid requests for throughput calculation") + + prefill_latency_list = [record_dict["prefill_latency"]] + p50_prefill_latency = np.percentile(prefill_latency_list, 50) * MS_SCALE + latency.p50_prefill_latency = round(p50_prefill_latency, 2) + logger.debug("Time to First token latency P50: %.4f ms", p50_prefill_latency) + + p90_prefill_latency = np.percentile(prefill_latency_list, 90) * MS_SCALE + latency.p90_prefill_latency = round(p90_prefill_latency, 2) + logger.debug("Time to First token latency TP90: %.4f ms", p90_prefill_latency) + + p99_prefill_latency = np.percentile(prefill_latency_list, 99) * MS_SCALE + latency.p99_prefill_latency = round(p99_prefill_latency, 2) + logger.debug("Time to First token latency TP99: %.4f ms", p99_prefill_latency) + + max_prefill_latency = np.max(prefill_latency_list) * MS_SCALE + latency.max_prefill_latency = round(max_prefill_latency, 2) + logger.debug( + "Maximum time to first token latency: %.4f ms", max_prefill_latency + ) + + avg_prefill_latency = np.mean(prefill_latency_list) * MS_SCALE + latency.avg_prefill_latency = 
round(avg_prefill_latency, 2) + logger.debug( + "Average time to first token latency: %.4f ms", avg_prefill_latency + ) + + # tbt_list is a list[list[float]]; take the mean over all values (the first two entries of each request's list are dropped before averaging the per-decode-token time) + decode_latency_list = [] + for tmp_list in record_dict["tbt_list"]: + decode_latency_list.extend(tmp_list[2:]) + + p50_decode_latency = np.percentile(decode_latency_list, 50) * MS_SCALE + latency.p50_decode_latency = round(p50_decode_latency, 2) + logger.debug("Decode (time between tokens) latency TP50: %.4f ms", p50_decode_latency) + + p90_decode_latency = np.percentile(decode_latency_list, 90) * MS_SCALE + latency.p90_decode_latency = round(p90_decode_latency, 2) + logger.debug("Decode (time between tokens) latency TP90: %.4f ms", p90_decode_latency) + + p99_decode_latency = np.percentile(decode_latency_list, 99) * MS_SCALE + latency.p99_decode_latency = round(p99_decode_latency, 2) + logger.debug("Decode (time between tokens) latency TP99: %.4f ms", p99_decode_latency) + + max_decode_latency = np.max(decode_latency_list) * MS_SCALE + latency.max_decode_latency = round(max_decode_latency, 2) + logger.debug("Maximum decode latency: %.4f ms", max_decode_latency) + + avg_decode_latency = np.mean(decode_latency_list) * MS_SCALE + latency.avg_decode_latency = round(avg_decode_latency, 2) + logger.debug("Average decode latency: %.4f ms", avg_decode_latency) + + return latency.to_dict() + + def _get_stable_request_id( + self, result: Dict[str, List[Any]], target_concurrency: int + ): + """ + Get steady-state request ids via start_time vs. end_time delta + """ + # the number of concurrent requests at each request start and end + request_num = len(result.get("request_id", [])) + concurrent_levels = [0] * 2 * request_num + request_events = [] + for idx in range(request_num): + request_events.append( + { + "request_id": result.get("request_id", [])[idx], + "event_type": "start", + "timestamp": result.get("start_time", [])[idx], + } + ) + request_events.append( + { + "request_id": result.get("request_id", [])[idx], + "event_type": "end", + "timestamp": result.get("end_time", [])[idx], + } + ) + sorted_events = sorted(request_events, key=lambda x: x["timestamp"]) + stable_stage_requests = [] + logger.info("Start calculating stable request id") + used_request_num = 0 + for idx, item in enumerate( + tqdm(sorted_events, desc="search stable request id") + ): + if item["event_type"] == "start": + used_request_num += 1 + concurrent_levels[idx] = ( + concurrent_levels[idx - 1] + 1 if idx > 0 else 1 + ) + else: + concurrent_levels[idx] = concurrent_levels[idx - 1] - 1 + if ( + item["event_type"] == "start" + and concurrent_levels[idx] == target_concurrency + ): + stable_stage_requests.append(item["request_id"]) + if len(stable_stage_requests) == 2: + self.stable_work_time[0] = item["timestamp"] + elif ( + item["event_type"] == "start" + and concurrent_levels[idx] + >= int(target_concurrency * (1 - MAX_WAVE_RATE)) + and len(stable_stage_requests) > 2 + ): + stable_stage_requests.append(item["request_id"]) + elif used_request_num == request_num and item["event_type"] == "end": + self.stable_work_time[1] = item["timestamp"] + break + elif ( + len(stable_stage_requests) > 1 + and item["event_type"] == "end" + and concurrent_levels[idx] + < int(target_concurrency * (1 - MAX_WAVE_RATE)) + ): + self.stable_work_time[1] = item["timestamp"] + break + + if len(stable_stage_requests) > 1: + # ignore first request + stable_stage_requests.pop(0) + if len(stable_stage_requests) == 0: + logger.error("cannot find
stable stage, please check your settings") + raise ValueError("cannot find stable stage, please check your settings") + logger.info(f"stable request id list: {stable_stage_requests=}") + return stable_stage_requests diff --git a/test/common/uc_eval/utils/client.py b/test/common/uc_eval/utils/client.py new file mode 100644 index 000000000..24d7e3a69 --- /dev/null +++ b/test/common/uc_eval/utils/client.py @@ -0,0 +1,489 @@ +import concurrent.futures +import copy +import json +import os +import time +import uuid +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, List, Optional, Union + +import requests +from common.uc_eval.utils.data_class import ( + ModelConfig, + MultiTurnDialogRecord, + RequestRecord, +) +from common.uc_eval.utils.utils import PathUtil, get_logger +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer +from typing_extensions import override + +logger = get_logger() +TIMEOUT = 6000 +HEADERS = {"User-Agent": "Benchmark Client", "Content-Type": "application/json"} +CHUNK_SIZE = 2**16 + + +def _excute_with_pool( + task_func: callable, + process_func: callable, + tasks: List, + parallel_num: int, + desc: str = "Processing Requests", +) -> List[RequestRecord | MultiTurnDialogRecord]: + record_results: List[RequestRecord | MultiTurnDialogRecord] = [] + if parallel_num > len(tasks): + logger.error( + f"The number of requests: {len(tasks)} is less than parallel_num: {parallel_num}, please check..." + ) + raise ValueError( + f"The number of requests: {len(tasks)} is less than parallel_num: {parallel_num}, please check..." + ) + logger.info(f"Start to send {len(tasks)} requests to server...") + with ThreadPoolExecutor(max_workers=parallel_num) as executor: + futures = [executor.submit(task_func, task) for task in tasks] + + with tqdm(total=len(futures), desc=desc, mininterval=0.5) as pbar: + for future in concurrent.futures.as_completed(futures): + try: + pbar.update(1) + result = process_func(future.result()) + record_results.append(result) + pbar.set_postfix( + { + "Completed": len(record_results), + "Pending": len(futures) - pbar.n, + } + ) + except Exception as e: + pbar.update(1) + logger.error(f"Requested failed: {str(e)}") + raise Exception(f"Requested failed: {str(e)}") + return record_results + + +class BaseClient: + def __init__( + self, + config: ModelConfig, + stream: bool = False, + enable_prefix_cache: bool = False, + ): + self.ip_ports = config.ip_ports + self.url = f"http://{self.ip_ports}/v1/chat/completions" + self.served_model_name = config.served_model_name + tokenizer_path = PathUtil.get_datasets_dir_path(config.tokenizer_path) + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + tokenizer_path + ) + self.session = requests.Session() + self.payload = config.payload + self.enable_prefix_cache = enable_prefix_cache + self.stream = stream + if self.stream: + self.payload.update( + {"stream": True, "ignore_eos": True, "temperature": 0.0} + ) + else: + self.payload.update( + {"stream": False, "ignore_eos": False, "temperature": 0.0} + ) + + def handle_requests_with_pool( + self, prompt_list: List, parallel_num: int, max_tokens: int + ) -> List[RequestRecord]: + return _excute_with_pool( + task_func=lambda prompt: self.send_request(prompt, max_tokens), + process_func=self.update_request_record, + tasks=prompt_list, + parallel_num=parallel_num, + ) + + def send_request(self, prompt, max_tokens) -> List[RequestRecord]: + """ + update payload and send request + """ + payload = 
self._update_payload(prompt, max_tokens) + if self.stream: + record = self.do_stream_request(payload) + else: + record = self.do_request(payload) + return record + + def _update_payload(self, prompt, max_tokens) -> Dict: + """ + update request payload + """ + payload = copy.deepcopy(self.payload) + payload.update({"model": self.served_model_name}) + # If payload already has default max_tokens, the input max_tokens will be set to 0 + if max_tokens > 0: + payload.update({"max_tokens": max_tokens}) + if isinstance(prompt, str): + message = [{"role": "user", "content": prompt}] + if isinstance(prompt, list): + # Multi-turn conversation - prompt already contains full message history. + # No need to update messages as they are already properly formatted + message = prompt + payload.update({"messages": message}) + + return payload + + def _create_record(self, prompt): + # If the prompt is not a dict, it must be a list of dicts for multi-turn dialogue. + if isinstance(prompt, dict): + record = RequestRecord(input_data=prompt["content"]) + else: + record = RequestRecord(input_data=str(prompt)) + + return record + + def update_request_record( + self, records: Union[RequestRecord, List[RequestRecord]] + ) -> Union[RequestRecord, List[RequestRecord]]: + """ + Get the number of input and output tokens for each request record + """ + if not records: + logger.warning("No records to update, please check...") + if isinstance(records, RequestRecord): + single_record = records + records = [single_record] + else: + single_record = None + + for record in records: + record.input_tokens = len(self.tokenizer.tokenize(record.input_data)) + record.output_tokens = len(self.tokenizer.tokenize(record.output_data)) + + return records[0] if single_record is not None else records + + def _requset(self, payload): + response = None + try: + response = self.session.post( + self.url, + headers=HEADERS, + json=payload, + timeout=TIMEOUT, + stream=self.stream, + ) + response.raise_for_status() + return response + except Exception as err: + raise self._handle_request_error(err) + + def do_request(self, payload: Dict) -> RequestRecord: + prompt = payload["messages"] + record = self._create_record(prompt) + record.start_time = time.time() + + response = self._requset(payload) + result = json.loads(response.text) + request_id = result.get("id", "request_id not found") + output = self._get_message_from_response(result) + + record.request_id = request_id + record.output_data = output + record.is_success = True + record.end_time = time.time() + record.req_cost = record.end_time - record.start_time + return record + + def _get_message_from_response(self, response) -> str: + message = response.get("choices", [])[0].get("message", {}) + output = "" + if message.get("content", "") is not None: + output += message.get("content", "") + elif message.get("reasoning_content", "") is not None: + output += message.get("reasoning_content", "") + return output + + def do_stream_request(self, payload: Dict) -> RequestRecord: + prompt = payload["messages"] + record = self._create_record(prompt) + while True: + all_chunks = [] + first_token = True + last_chunk = None + timeout_finish_reason = False + cur_time = last_time = time.perf_counter() + record.start_time = last_time + response = self._requset(payload) + for chunk in response.iter_content(chunk_size=CHUNK_SIZE): + all_chunks.append(chunk) + if len(chunk.strip()) == 0: + continue + last_chunk = chunk + cur_time = time.perf_counter() + time_diff = cur_time - last_time + if first_token: + 
record.prefill_latency = time_diff + first_token = False + else: + record.tbt_list.append(time_diff) + last_time = cur_time + chunk_output = chunk[5:].strip().decode("utf-8") + + # when the MindIE engine side timeout, it will return timeout information + if chunk.startswith(b"Engine callback timeout"): + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="Engine callback timeout", + ) + record.output_data = "TIMEOUT" + return record + if "[DONE]" in chunk_output: + logger.debug(f"Finished chunk: {chunk_output=}") + continue + output = self._get_message_from_stream_response( + json.loads(chunk_output) + ) + if record.request_id == "": + record.request_id = json.loads(chunk_output).get( + "id", "request_id not found" + ) + record.output_data += output + + # when the uc-vllm request timeout, finish_reason == "length" and the final output is empty + finish_reason = ( + json.loads(chunk_output) + .get("choices", [])[0] + .get("finish_reason", "") + ) + if finish_reason == "length": + timeout_finish_reason = True + + # handle the last chunk + if last_chunk.startswith(b"data:"): + chunk_output = last_chunk[5:].strip().decode("utf-8") + else: + chunk_output = last_chunk.strip().strip().decode("utf-8").rstrip("\0") + # while the last chunk meets the following conditions, the request is finished successfully + if "[DONE]" in chunk_output: + break + else: + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="request failed, please retry!!!", + ) + break + # while the request is done, we need to check the content to see if the request is successful + if record.output_data == "": + if timeout_finish_reason: + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="vllm server scheduling timeout, please check", + ) + return record + else: + self._print_request_info( + request_id=record.request_id, + chunk=chunk, + content=record.output_data, + all_chunks=all_chunks, + payload=payload, + msg="the request returned an empty message, which may be an unknown error on the engine side. Please check the specific reason!", + ) + return record + record.is_success = True + record.end_time = time.perf_counter() + record.req_cost = record.end_time - record.start_time + logger.debug(f"{record.request_id} finished, cost: {record.req_cost:.2f}s") + return record + + def _get_message_from_stream_response(self, response) -> str: + message = response.get("choices", [])[0].get("delta", {}) + output = "" + if message.get("content", "") is not None: + output += message.get("content", "") + elif message.get("reasoning_content", "") is not None: + output += message.get("reasoning_content", "") + return output + + def clear_hbm(self) -> bool: + """ + The API is used to clear HBM. It is available only when the serving backend is VLLM. 
+ """ + os.environ["NO_PROXY"] = "127.0.0.1, localhost, local, .local" + logger.info("Begin to clear HBM") + headers = {"Content-Type": "application/json"} + payload = {} + url = f"http://{self.ip_ports}/reset_prefix_cache" + try: + response = requests.post( + url, json=payload, headers=headers, timeout=TIMEOUT + ) + response.raise_for_status() + except Exception as err: + raise self._handle_request_error(err) + time.sleep(5) + logger.info("Clear HBM success") + return True + + def _handle_request_error(self, err: Exception) -> Exception: + """ + Used to handle request errors + """ + if isinstance(err, requests.exceptions.ConnectionError): + logger.error(f"Cannot connect to {self.url}, please check your network") + return ConnectionError(f"Cannot connect to {self.url}") + elif isinstance(err, requests.exceptions.Timeout): + logger.error("The request timed out, please check your server status") + return TimeoutError( + "The request timed out, please check your server status" + ) + elif isinstance(err, requests.exceptions.HTTPError): + status_code = err.response.status_code + if status_code == 404: + logger.error( + f"The requested resource does not exist, or the served model name is incorrect" + ) + else: + logger.error(f"HTTP error, status code: {status_code}") + return Exception(f"HTTP error, status code: {status_code}, err: {err}") + else: + logger.error(f"Other error: {err}") + return Exception(f"Other error: {err}") + + @staticmethod + def _print_request_info(**kwargs): + """print request info when a request fails""" + for key, value in kwargs.items(): + value = ( + json.dumps(value, ensure_ascii=False) + if isinstance(value, dict) + else value + ) + logger.error(f"{key} => {value}") + + +class MultiDialogClient(BaseClient): + def __init__(self, config: ModelConfig, stream: bool, enable_prefix_cache: bool): + super().__init__(config, stream, enable_prefix_cache) + self.uuid = uuid.uuid4().hex + + @override + def handle_requests_with_pool( + self, + cases: List[List[Union[str, Dict]]], + parallel_num: int, + max_tokens: int = -1, + ) -> List[List[MultiTurnDialogRecord]]: + return _excute_with_pool( + task_func=lambda case: self._send_multi_request(case, max_tokens), + process_func=self.update_request_record, + tasks=cases, + parallel_num=parallel_num, + ) + + def _send_multi_request( + self, case: List[Union[str, Dict]], max_tokens: int = -1 + ) -> List[MultiTurnDialogRecord]: + case_name, dialog = case + history, conv_record = [], [] + conversion = dialog["conversations"] + turns = self._convert_conversation_2_turns(conversion, 2) + for i, turn in enumerate(turns): + in_content, reply = turn[0]["content"], turn[1]["content"] + # update the payload with the conversation history, then send the request + prompt = self._update_request_body(history, in_content) + record: RequestRecord = self.send_request(prompt, max_tokens) + record.case_name = case_name + history = self._update_history(history, in_content, reply) + multi_turn_record: MultiTurnDialogRecord = ( + self._update_multi_turn_request_record(record, len(turns), i) + ) + conv_record.append(multi_turn_record) + return conv_record + + def _update_multi_turn_request_record( + self, record: RequestRecord, total_turns: int, turn_id: int + ) -> MultiTurnDialogRecord: + """ + Update multi-turn dialogue request record + """ + request_record = MultiTurnDialogRecord() + request_record.__dict__.update(record.__dict__) + request_record.total_turns = total_turns + request_record.turn_id = turn_id + return request_record + + @staticmethod + def _convert_conversation_2_turns(conversion_list:
list, chunk_size: int): + """ + Convert conversation list to turns + """ + if chunk_size < 0: + raise ValueError(f"the chunk size {chunk_size} must be greater than 0") + num_full_chunks = len(conversion_list) // chunk_size + return [ + conversion_list[i * chunk_size : (i + 1) * chunk_size] + for i in range(num_full_chunks) + ] + + def _update_request_body(self, history: Optional[List[Dict]], in_content: str): + """ + Multi turn dialogue request body + """ + history = copy.deepcopy(history) + if history and self.enable_prefix_cache: + # To make sure the prefix cache is unique + history[0]["content"] = f"uuid: [{self.uuid}]" + history[0]["content"] + if history and not self.enable_prefix_cache: + history[0]["content"] = ( + f"uuid: [{uuid.uuid4().hex}]" + history[0]["content"] + ) + + message = history + [{"role": "user", "content": in_content}] + return message + + @staticmethod + def _update_history( + history: Optional[List[Dict]], in_content: str, out_content: str + ) -> List[Dict]: + """ + Update conversation history + """ + history.append({"role": "user", "content": in_content}) + history.append({"role": "assistant", "content": out_content}) + return history + + +class DocQaClient(BaseClient): + def __init__(self, config: ModelConfig, stream: bool, enable_prefix_cache: bool): + super().__init__(config, stream, enable_prefix_cache) + + @override + def handle_requests_with_pool( + self, cases: List[Union[str, str, str]], parallel_num: int, max_tokens: int = -1 + ) -> List[List[MultiTurnDialogRecord]]: + return _excute_with_pool( + task_func=lambda case: self.send_qa_request(case, max_tokens), + process_func=self.update_request_record, + tasks=cases, + parallel_num=parallel_num, + ) + + def send_qa_request( + self, case: Union[str, str, str, str], max_tokens: int = -1 + ) -> RequestRecord: + case_name, context, question, answer = case + prompt = context + question + record: RequestRecord = self.send_request(prompt, max_tokens) + record.case_name = case_name + record.question = question + record.expected_output = answer + return record diff --git a/test/common/uc_eval/utils/config_loader.py b/test/common/uc_eval/utils/config_loader.py new file mode 100644 index 000000000..4eda8810e --- /dev/null +++ b/test/common/uc_eval/utils/config_loader.py @@ -0,0 +1,199 @@ +import dataclasses +import functools +import importlib +import json +from typing import Any, Optional, Tuple + +from common.uc_eval.utils.benchmark import ( + BenchmarkBase, + EvaluatorBenchmark, + PerformanceBenchmark, +) +from common.uc_eval.utils.client import BaseClient, DocQaClient, MultiDialogClient +from common.uc_eval.utils.data_class import ( + BenchmarkModeType, + DatasetType, + EvalConfig, + ModelConfig, + PerfConfig, +) +from common.uc_eval.utils.dataloader import ( + BaseDataset, + DocQADataset, + MultiTurnDialogueDataset, + SyntheticDataset, +) +from common.uc_eval.utils.utils import get_logger + +logger = get_logger() + + +def make_object(object_ref: str, *args: Any, **kwargs: Any) -> Any: + """create object based on class name""" + modname, qualname_separator, qualname = object_ref.partition(":") + obj = importlib.import_module(modname) + if qualname_separator: + for attr in qualname.split("."): + obj = getattr(obj, attr) + return functools.partial(obj, *args, **kwargs) + + +class ConfigLoader: + def __init__( + self, + model_config: ModelConfig, + perf_config: PerfConfig = None, + eval_config: EvalConfig = None, + ): + + self.model_config = model_config + self.perf_config = perf_config + self.eval_config = 
eval_config + self._valid_config() + + def _valid_config(self) -> bool: + logger.info("Validating config...") + if self.perf_config is not None and self.eval_config is not None: + raise ValueError( + "perf_config and eval_config are mutually exclusive – one must be None." + ) + if self.perf_config is None and self.eval_config is None: + raise ValueError( + "At least one of perf_config or eval_config must be provided." + ) + + result = ( + self._valid_model_config() and self._valid_perf_config() + if self.perf_config is not None + else self._valid_eval_config() + ) + logger.info("Complete validation...") + return result + + def _valid_model_config(self) -> bool: + payload = self.model_config.payload + if isinstance(payload, str): + try: + self.model_config.payload = json.loads(payload) + except Exception as e: + raise ValueError(f"Invalid payload JSON format: {e}") + + empty_fields = [] + field_names = [field.name for field in dataclasses.fields(ModelConfig)] + for field_name in field_names: + value = getattr(self.model_config, field_name) + if value is None or (isinstance(value, str) and not value.strip()): + empty_fields.append(field_name) + + if empty_fields: + raise ValueError( + f"The following model config fields can't be empty: {', '.join(empty_fields)}" + ) + + return True + + def _valid_perf_config(self) -> bool: + data_type = self.perf_config.data_type + benchmark_mode = self.perf_config.benchmark_mode + if benchmark_mode not in [ + BenchmarkModeType.DEFAULT_PERF, + BenchmarkModeType.STABLE_PREF, + ]: + raise ValueError( + f"Invalid benchmark mode: {benchmark_mode}. Valid modes are: {BenchmarkModeType.DEFAULT_PERF}, {BenchmarkModeType.STABLE_PREF}" + ) + prompt_fields = ["prompt_tokens", "output_tokens"] + ( + ["prefix_cache_num"] if self.perf_config.enable_prefix_cache else [] + ) + if data_type == DatasetType.SYNTHETIC: + invalid_fields = [] + for field in prompt_fields: + value = getattr(self.perf_config, field) + if not isinstance(value, list) or not value: + invalid_fields.append(field) + if invalid_fields: + raise ValueError( + f"The following dataset config fields must be non-empty list for synthetic data: {', '.join(invalid_fields)}" + ) + + length = { + field: len(getattr(self.perf_config, field)) for field in prompt_fields + } + if len(set(length.values())) > 1: + raise ValueError( + f"The following dataset config is not matched: {', '.join(length.keys())}" + ) + else: + if self.perf_config.dataset_file_path is None: + raise ValueError( + f"dataset_file_path is required for {data_type} data type" + ) + if not isinstance(self.perf_config.parallel_num, int): + raise TypeError( + f"parallel_num must be an integer for {data_type} data type" + ) + not_empty_fields = [ + field for field in prompt_fields if getattr(self.perf_config, field) + ] + if not_empty_fields: + raise ValueError( + f"The following dataset fields should be None for {data_type} data type: {not_empty_fields}" + ) + + return True + + def _valid_eval_config(self) -> bool: + data_type = self.perf_config.data_type + dataset_file_path = self.perf_config.dataset_file_path + benchmark_mode = self.perf_config.benchmark_mode + if benchmark_mode != BenchmarkModeType.EVAL: + raise ValueError( + f"Invalid benchmark mode: {benchmark_mode}. 
Valid modes are: {BenchmarkModeType.EVAL}" + ) + if data_type == DatasetType.SYNTHETIC or dataset_file_path is None: + raise ValueError( + f"Invalid dataset type: {data_type} or Invalid dataset file path: {dataset_file_path}" + ) + # TODO: add more validations + return True + + +class TaskFactory: + _dataset: BaseDataset = { + DatasetType.SYNTHETIC: SyntheticDataset, + DatasetType.MULTI_DIALOGUE: MultiTurnDialogueDataset, + DatasetType.DOC_QA: DocQADataset, + } + _client: BaseClient = { + DatasetType.SYNTHETIC: BaseClient, + DatasetType.MULTI_DIALOGUE: MultiDialogClient, + DatasetType.DOC_QA: DocQaClient, + } + _benchmark: BenchmarkBase = { + BenchmarkModeType.EVAL: EvaluatorBenchmark, + BenchmarkModeType.STABLE_PREF: PerformanceBenchmark, + BenchmarkModeType.DEFAULT_PERF: PerformanceBenchmark, + } + + @classmethod + def create_task( + cls, + model_config: ModelConfig, + perf_config: Optional[PerfConfig], + eval_config: Optional[EvalConfig], + ) -> Tuple[BaseDataset, BaseClient, BenchmarkBase]: + data_type = (perf_config or eval_config).data_type + tokenizer_path = model_config.tokenizer_path + enable_prefix_cache = perf_config.enable_prefix_cache + benchmark_mode = (perf_config or eval_config).benchmark_mode + stable = benchmark_mode == BenchmarkModeType.STABLE_PREF + if benchmark_mode in [ + BenchmarkModeType.STABLE_PREF, + BenchmarkModeType.DEFAULT_PERF, + ]: + stream = True + return ( + cls._dataset[data_type](tokenizer_path), + cls._client[data_type](model_config, stream, enable_prefix_cache), + cls._benchmark[benchmark_mode](stable if perf_config else eval_config), + ) diff --git a/test/common/uc_eval/utils/data_class.py b/test/common/uc_eval/utils/data_class.py new file mode 100644 index 000000000..817208b14 --- /dev/null +++ b/test/common/uc_eval/utils/data_class.py @@ -0,0 +1,160 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + + +class DatasetType(str, Enum): + """ + The dataset type of uc_eval, including synthetic, multi-turn dialogue, and document-QA. + """ + + SYNTHETIC = "synthetic" + MULTI_DIALOGUE = "multi_turn_dialogue" + DOC_QA = "doc_qa" + + +class BenchmarkModeType(str, Enum): + """ + The benchmark mode of uc_eval, including evaluate, stable-perf, and default-perf. 
+ """ + + EVAL = "evaluate" + STABLE_PREF = "stable-perf" + DEFAULT_PERF = "default-perf" + + +@dataclass +class ModelConfig: + ip_ports: str = "" + tokenizer_path: str = "" + served_model_name: str = "" + enable_clear_hbm: bool = False + payload: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class EvalConfig: + data_type: str = "" + dataset_file_path: str = "" + enable_prefix_cache: str = False + parallel_num: int = 1 + benchmark_mode: str = "evaluate" + metrics: Optional[List[str]] = field(default_factory=list) + eval_class: Optional[str] = None + + +@dataclass +class PerfConfig: + data_type: str = "" + dataset_file_path: str = "" + enable_prefix_cache: bool = False + parallel_num: int | List[int] = 1 + prompt_tokens: List[int] = field(default_factory=list) + output_tokens: List[int] = field(default_factory=list) + prefix_cache_num: List[float] = field(default_factory=list) + benchmark_mode: str = "" + + +@dataclass +class SynthericParams: + """ + The parameters for synthetic dataset + """ + + parallel_num: int = -1 + # The number of tokens for total prompts + prompt_tokens: int = -1 + # The number of tokens for prefix cache + prefix_cache_tokens: int = -1 + # List of seeds, to ensure the prefix cache is consistent between warmup and inference + seeds: list[int] = field(default_factory=list) + + def to_dict(self): + return vars(self) + + +@dataclass +class RequestRecord: + """ + The record for single request + """ + + case_name: str = "" + request_id: str = "" + input_data: Optional[str] = "" + input_tokens: int = 0 + # The real output + output_data: str = "" + output_tokens: int = 0 + # The expected output + expected_output: str = "" + # The question of the request + question: str = "" + start_time: float = 0.0 + end_time: float = 0.0 + # The cost of the request + req_cost: float = 0.0 + # Time to first token, cost of the prefill + prefill_latency: float = 0.0 + # Time between tokens + tbt_list: list[float] = field(default_factory=list) + # Whether the request is successful + is_success: bool = False + + def to_dict(self): + return vars(self) + + +@dataclass +class MultiTurnDialogRecord(RequestRecord): + """ + The record for multi-turn dialogue request + """ + + # The total turn of the conversation + total_turns: int = -1 + # The current turn of the dialog + turn_id: int = -1 + # The input content of this dialog, which deletes the history information + in_content: str = "" + # If this request belongs to QA dialog + is_qa: bool = False + + +@dataclass +class LatencyStatistics: + """ + the latency statistics of all requests + """ + + # The total latency of all requests(ms) + e2e_latency_all: float = -1 + # The end to end average throughput(tokens/s) + output_token_throughput: float = -1 + # The average throughput of all requests(tokens/s) + token_throughput_per_request: float = -1 + # The TP50 latency of time to first tokens(ms) + p50_prefill_latency: float = -1 + # The TP90 latency of time to first tokens(ms) + p90_prefill_latency: float = -1 + # The TP99 latency of time to first tokens(ms) + p99_prefill_latency: float = -1 + # The max latency of time to first tokens(ms) + max_prefill_latency: float = -1 + # The average latency of time to first tokens(ms) + avg_prefill_latency: float = -1 + # The TP50 latency of decoder latency(ms) + p50_decode_latency: float = -1 + # The TP90 latency of decoder latency(ms) + p90_decode_latency: float = -1 + # The TP99 latency of decoder latency(ms) + p99_decode_latency: float = -1 + # The max latency of decoder latency(ms) + 
max_decode_latency: float = -1 + # The average latency of decoder latency(ms) + avg_decode_latency: float = -1 + # The average confidence + avg_confidence: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self): + return vars(self) diff --git a/test/common/uc_eval/utils/dataloader.py b/test/common/uc_eval/utils/dataloader.py new file mode 100644 index 000000000..f59c3b485 --- /dev/null +++ b/test/common/uc_eval/utils/dataloader.py @@ -0,0 +1,214 @@ +import json +import random +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Union + +import numpy as np +from common.uc_eval.utils.data_class import SynthericParams +from common.uc_eval.utils.utils import PathUtil, get_logger +from tqdm import tqdm +from transformers import AutoTokenizer, PreTrainedTokenizer + +logger = get_logger() +EPOCH_NUM = 10 + + +class BaseDataset(ABC): + def __init__( + self, + tokenizer_path: str = None, + ): + tokenizer_path = PathUtil.get_datasets_dir_path(tokenizer_path) + self.tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained( + tokenizer_path + ) + + @abstractmethod + def prepare_data(self, param: Any): + raise NotImplementedError + + +class SyntheticDataset(BaseDataset): + def __init__(self, tokenizer_path: str): + super().__init__(tokenizer_path) + + def prepare_data(self, syntheric_params: SynthericParams) -> list[str]: + prompt_list = [] + for parallel_num in tqdm( + range(syntheric_params.parallel_num), + desc="Generate synthetic data", + unit="prompt", + ): + random_prompt_len = max( + 0, syntheric_params.prompt_tokens - syntheric_params.prefix_cache_tokens + ) + random_prompt = self.generate_random_str(random_prompt_len, time.time_ns()) + if syntheric_params.prefix_cache_tokens > 0: + pc_prompt = self.generate_random_str( + syntheric_params.prefix_cache_tokens, + syntheric_params.seeds[parallel_num], + ) + else: + pc_prompt = "" + final_prompt = pc_prompt + random_prompt + prompt_list.append(final_prompt) + return prompt_list + + def generate_random_str(self, length: int, seed: int) -> str: + """ + Sample random tokens from the tokenizer using a seed. + Use timestamp when cache hit is not required; otherwise use an incrementing seed. 
+ """ + if length <= 0: + return "" + vocab_size = self.tokenizer.vocab_size + random.seed(seed) + ids_list = random.choices(range(vocab_size // 4, vocab_size // 3), k=length) + ids = np.array(ids_list) + text = self.tokenizer.decode(ids) + completion_token_ids = self.tokenizer([text]).input_ids + logger.debug( + f"len(completion_token_ids[0]) = {len(completion_token_ids[0])}, length = {length}" + ) + + epoch = EPOCH_NUM + while len(completion_token_ids[0]) != length and epoch > 0: + epoch -= 1 + while len(completion_token_ids[0]) > length: + diff = len(completion_token_ids[0]) - length + now_length = ids.shape[0] - diff + ids = ids[:now_length] + text = self.tokenizer.decode(ids) + completion_token_ids = self.tokenizer([text]).input_ids + + while len(completion_token_ids[0]) < length: + diff = length - len(completion_token_ids[0]) + diff_ids_list = random.choices( + range(vocab_size // 4, vocab_size // 3), k=diff + ) + diff_ids = np.array(diff_ids_list) + ids = np.append(ids, diff_ids) + text = self.tokenizer.decode(ids) + completion_token_ids = self.tokenizer([text]).input_ids + + if len(completion_token_ids[0]) != length: + logger.warning( + "The length of completion token ids is not equal to the length of input token ids" + ) + logger.warning( + f"Generate tokens, target: {length}, actual: {len(completion_token_ids[0])}" + ) + + return text + + +class MultiTurnDialogueDataset(BaseDataset): + def __init__(self, tokenizer_path: str): + super().__init__(tokenizer_path) + + def prepare_data(self, dataset_file_path) -> List[List[Union[str, Dict]]]: + """ + Load a JSON file containing multi-turn dialogue dataset paths. + :param file_path: JSON file listing multi-turn dialogue dataset paths to traverse. + the multi-turn dataset format: {"kimi": [{"conversion": [{"role": "user", "content": "xxx"}, ...], "qa": [{"question": "xxx", "answer": "xxx"}, ...]}]} + """ + cases = [] + # the path of multiturndialog.json + json_path = PathUtil.get_datasets_dir_path(dataset_file_path) + mtd_data: dict = self.load_json_file(json_path) + for dataset_name, files_list in mtd_data.items(): + for file_name in files_list: + case_path = PathUtil.get_dirname(json_path).joinpath( + dataset_name, file_name + ) + if case_path.exists(): + dialogues = self.load_json_file(case_path) + cases.extend(self.process_single_case_file(dialogues)) + else: + logger.warning( + f"JSON file {case_path} does not exist, please check the file path" + ) + if len(cases) == 0: + logger.warning( + f"The file {json_path} does not contain multi-turn dialogue data" + ) + return cases + + def process_single_case_file(self, dialogues: dict) -> List[List[Union[str, Dict]]]: + cases = [] + for dialogue_name, dialogue_data in dialogues.items(): + for i, dialog in enumerate(dialogue_data): + dialog_tokens = len( + self.tokenizer.tokenize(str(dialog["conversations"])) + ) + logger.info( + f"Current dialogue {dialogue_name}-{i} token count: {dialog_tokens}" + ) + cases.append([f"{dialogue_name}-{i}", dialog]) + return cases + + def load_json_file(self, file_path): + try: + with open(file_path, "r", encoding="utf-8") as f: + data = json.load(f) + return data + except FileNotFoundError: + logger.error(f"JSON file not found: {file_path}") + raise FileNotFoundError(f"JSON file not found: {file_path}") + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in file {file_path}: {e}") + raise ValueError(f"Invalid JSON format in file {file_path}: {e}") + except Exception as e: + logger.error(f"Unexpected error while loading JSON file 
{file_path}: {e}") + raise ValueError(f"Failed to load JSON file {file_path}: {e}") + + +class DocQADataset(BaseDataset): + def __init__(self, tokenizer_path: str): + super().__init__(tokenizer_path) + + def prepare_data(self, dataset_file_path) -> List[Union[str, str, str]]: + cases_list = [] + case_data = self._load_jsonl_file(dataset_file_path) + for case in case_data: + context = case.get("context") + question = case.get("question") + answer = case.get("answers")[0] + case_name = case.get("dataset") + "_" + case.get("_id") + cases_list.append([case_name, context, question, answer]) + return cases_list + + def _load_jsonl_file(self, file_path: str) -> List[Dict[str, Any]]: + """ + Load a JSONL file containing doc_qa data + :param file_path: Path to the jsonl file + :return: List of doc_qa data + """ + case_data = [] + try: + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + # In doc_qa, one line per sample; each sample contains: question, context, answer, etc. + json_line = json.loads(line) + extracted_data = { + "question": json_line.get("input", None), + "context": json_line.get("context", None), + "answers": json_line.get("answers", None), + "length": json_line.get("length", None), + "dataset": json_line.get("dataset", None), + "language": json_line.get("language", None), + "all_classes": json_line.get("all_classes", None), + "_id": json_line.get("_id", None), + } + case_data.append(extracted_data) + return case_data + except FileNotFoundError: + logger.error(f"JSONL file not found: {file_path}") + raise FileNotFoundError(f"JSONL file not found: {file_path}") + except json.JSONDecodeError as e: + logger.error(f"JSONL decode error in file {file_path}: {e}") + raise ValueError(f"Invalid JSONL format in file {file_path}: {e}") + except Exception as e: + logger.error(f"Unexpected error while loading JSONL file {file_path}: {e}") + raise ValueError(f"Failed to load JSONL file {file_path}: {e}") diff --git a/test/common/uc_eval/utils/utils.py b/test/common/uc_eval/utils/utils.py new file mode 100644 index 000000000..67a8183c0 --- /dev/null +++ b/test/common/uc_eval/utils/utils.py @@ -0,0 +1,96 @@ +import logging +import logging.handlers +import os +import sys +import time +from pathlib import Path +from typing import Dict + +current_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_current_time() -> str: + return time.strftime("%Y%m%d_%H%M%S", time.localtime()) + + +class PathUtil(object): + + @staticmethod + def get_dirname(file_path: str | Path): + return Path(os.path.dirname(file_path)) + + @staticmethod + def get_root_dir_path() -> Path: + root_path = Path(current_dir).parent.parent + return root_path + + @staticmethod + def get_other_dir_path(other: str) -> Path: + root_path = PathUtil.get_root_dir_path() + other_path = Path.joinpath(root_path, other) + other_path.mkdir(parents=True, exist_ok=True) + return other_path + + @staticmethod + def _default_datasets_path() -> Path: + return PathUtil.get_other_dir_path("UC-Eval-datasets") + + @staticmethod + def get_datasets_dir_path(in_file_path: str) -> Path: + if not in_file_path or in_file_path == "": + return PathUtil._default_datasets_path() + input_path = Path(in_file_path) + if input_path.is_absolute(): + return Path(in_file_path) + else: + return PathUtil.get_other_dir_path(in_file_path) + + +class LoggerHandler(logging.Logger): + def __init__( + self, name: str, level: int = logging.INFO, log_path: str = None + ) -> None: + super().__init__(name, level) + # format of the log message + fmt = 
"%(asctime)s.%(msecs)03d %(levelname)s [pid:%(process)d] [%(threadName)s] [tid:%(thread)d] [%(filename)s:%(lineno)d %(funcName)s] %(message)s" + data_fmt = "%Y-%m-%d %H:%M:%S" + formatter = logging.Formatter(fmt, data_fmt) + + # using file handler to log to file + if log_path is not None: + file_handler = logging.handlers.RotatingFileHandler( + filename=log_path, + maxBytes=1024 * 1024 * 10, + backupCount=20, + delay=True, + encoding="utf-8", + ) + file_handler.setFormatter(formatter) + file_handler.setLevel(self.level) + self.addHandler(file_handler) + + console_handler = logging.StreamHandler(stream=sys.stdout) + console_handler.setFormatter(formatter) + console_handler.setLevel(self.level) + self.addHandler(console_handler) + + def setLevel(self, level) -> None: + super().setLevel(level) + for handler in self.handlers: + handler.setLevel(level) + + +# the global dictionary to store all the logger instances +_logger_instances: Dict[str, LoggerHandler] = {} + + +def get_logger( + name: str = "evals", level: int = logging.INFO, log_file: str = None +) -> logging.Logger: + if name in _logger_instances: + return _logger_instances[name] + + # create a new logger instance + logger = LoggerHandler(name, level, log_file) + _logger_instances[name] = logger + return logger diff --git a/test/config.yaml b/test/config.yaml index 88d00a610..445b0a45a 100644 --- a/test/config.yaml +++ b/test/config.yaml @@ -9,10 +9,17 @@ reports: database: backup: "results/" - enabled: true + enabled: false host: "127.0.0.1" port: 3306 name: "ucm_pytest" user: "root" password: "123456" - charset: "utf8mb4" \ No newline at end of file + charset: "utf8mb4" + +models: + ip_ports: "" + tokenizer_path: "" + served_model_name: "" + payload: '' + enable_clear_hbm: false \ No newline at end of file diff --git a/test/suites/E2E/test_performance.py b/test/suites/E2E/test_performance.py new file mode 100644 index 000000000..5f5b7694c --- /dev/null +++ b/test/suites/E2E/test_performance.py @@ -0,0 +1,103 @@ +import dataclasses + +import pytest +from common.capture_utils import export_vars +from common.config_utils import config_utils as config_instance +from common.uc_eval.task import DocQaPerfTask, MultiPerfTask, SyntheticPerfTask +from common.uc_eval.utils.data_class import ModelConfig, PerfConfig + + +@pytest.fixture(scope="session") +def model_config() -> ModelConfig: + cfg = config_instance.get_config("models") or {} + field_name = [field.name for field in dataclasses.fields(ModelConfig)] + kwargs = {k: v for k, v in cfg.items() if k in field_name and v is not None} + return ModelConfig(**kwargs) + + +sync_perf_cases = [ + pytest.param( + PerfConfig( + data_type="synthetic", + enable_prefix_cache=False, + parallel_num=[1, 4, 8], + prompt_tokens=[4000, 8000], + output_tokens=[1000, 1000], + benchmark_mode="default-perf", + ), + id="benchmark-complete-recalculate-default-perf", + ), + pytest.param( + PerfConfig( + data_type="synthetic", + enable_prefix_cache=True, + parallel_num=[1, 4, 8], + prompt_tokens=[4000, 8000], + output_tokens=[1000, 1000], + prefix_cache_num=[0.8, 0.8], + benchmark_mode="stable-perf", + ), + id="benchmark-prefix-cache-stable-perf", + ), +] + + +@pytest.mark.feature("perf_test") +@pytest.mark.parametrize("perf_config", sync_perf_cases) +@export_vars +def test_sync_perf( + perf_config: PerfConfig, model_config: ModelConfig, request: pytest.FixtureRequest +): + task = SyntheticPerfTask(model_config, perf_config) + result = task.process() + return {"_name": request.node.callspec.id, "_data": result} + + 
+multiturn_dialogue_perf_cases = [ + pytest.param( + PerfConfig( + data_type="multi_turn_dialogue", + dataset_file_path="test/uc_eval/datasets/multi_turn_dialogues/multiturndialog.json", + enable_prefix_cache=False, + parallel_num=1, + benchmark_mode="default-perf", + ), + id="multiturn-dialogue-complete-recalculate-default-perf", + ) +] + + +@pytest.mark.feature("perf_test") +@pytest.mark.parametrize("perf_config", multiturn_dialogue_perf_cases) +@export_vars +def test_multiturn_dialogue_perf( + perf_config: PerfConfig, model_config: ModelConfig, request: pytest.FixtureRequest +): + task = MultiPerfTask(model_config, perf_config) + result = task.process() + return {"_name": request.node.callspec.id, "_data": result} + + +doc_qa_perf_cases = [ + pytest.param( + PerfConfig( + data_type="doc_qa", + dataset_file_path="test/uc-eval-new/datasets/doc_qa/demo.jsonl", + enable_prefix_cache=False, + parallel_num=1, + benchmark_mode="default-perf", + ), + id="doc-qa-complete-recalculate-default-perf", + ) +] + + +@pytest.mark.feature("perf_test") +@pytest.mark.parametrize("perf_config", doc_qa_perf_cases) +@export_vars +def test_doc_qa_perf( + perf_config: PerfConfig, model_config: ModelConfig, request: pytest.FixtureRequest +): + task = DocQaPerfTask(model_config, perf_config) + result = task.process() + return {"_name": request.node.callspec.id, "_data": result}
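For quick local runs outside pytest, the same tasks can be driven directly. A minimal sketch, assuming a reachable OpenAI-compatible endpoint and a local tokenizer directory; the address, model name, paths, and payload values below are placeholders rather than values defined by this patch:

    from common.uc_eval.task import SyntheticPerfTask
    from common.uc_eval.utils.data_class import ModelConfig, PerfConfig

    # Placeholder endpoint, tokenizer path, and model name -- adjust to your deployment.
    model_config = ModelConfig(
        ip_ports="127.0.0.1:8000",
        tokenizer_path="/models/your-model",
        served_model_name="your-model",
        enable_clear_hbm=False,
        payload={"max_tokens": 1024},
    )
    perf_config = PerfConfig(
        data_type="synthetic",
        enable_prefix_cache=False,
        parallel_num=[1],
        prompt_tokens=[4000],
        output_tokens=[1000],
        benchmark_mode="default-perf",
    )

    # process() returns one latency-statistics dict per (parallel_num, prompt_tokens) pair.
    results = SyntheticPerfTask(model_config, perf_config).process()
    print(results)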