From 1892a79dc40ac9bc240eb9fd20ef257827b2217e Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:17:49 +0000 Subject: [PATCH 01/10] initial implementation of leaderboard. Lots of stuff can be improved but this brings the core idea Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- pyproject.toml | 1 + requirements.txt | 4 +- src/instructlab/eval/leaderboard.py | 583 ++++++++++++++++++++++++++++ 3 files changed, 587 insertions(+), 1 deletion(-) create mode 100644 src/instructlab/eval/leaderboard.py diff --git a/pyproject.toml b/pyproject.toml index 03faef94..44a7de52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ issues = "https://github.com/instructlab/eval/issues" "mmlu_branch" = "instructlab.eval.mmlu:MMLUBranchEvaluator" "mt_bench" = "instructlab.eval.mt_bench:MTBenchEvaluator" "mt_bench_branch" = "instructlab.eval.mt_bench:MTBenchBranchEvaluator" +"leaderboard_v2" = "instructlab.eval.leaderboard:LeaderboardV2Evaluator" [tool.setuptools_scm] version_file = "src/instructlab/eval/_version.py" diff --git a/requirements.txt b/requirements.txt index 8abed2c9..07839a7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,8 @@ transformers accelerate pandas pandas-stubs -lm-eval>=0.4.4 +# All optional dependencies like this can be found in lm-eval: +# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/pyproject.toml +lm-eval[math,ifeval,sentencepiece,vllm]>=0.4.4 httpx ragas>=0.2.11 diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py new file mode 100644 index 00000000..2014e8f5 --- /dev/null +++ b/src/instructlab/eval/leaderboard.py @@ -0,0 +1,583 @@ +from lm_eval.evaluator import simple_evaluate +from .evaluator import Evaluator + + +from pathlib import Path +import json +from lm_eval.evaluator import simple_evaluate +import typing as t +import os +import torch.multiprocessing as mp +from accelerate import Accelerator +import torch.distributed as dist +import torch +from torch import cuda +import gc +from enum import StrEnum + + +class ParsedScores(t.TypedDict): + """ + Just an ordinary dict that contains both the overall score as well as per-subtask scores. + """ + + score: float + subtasks: t.NotRequired[t.Dict[str, float]] + + +class LeaderboardV2EvalResult(t.TypedDict): + overall_score: float + leaderboard_gpqa: t.NotRequired[ParsedScores] + leaderboard_ifeval: t.NotRequired[ParsedScores] + leaderboard_bbh: t.NotRequired[ParsedScores] + leaderboard_mmlu_pro: t.NotRequired[ParsedScores] + leaderboard_musr: t.NotRequired[ParsedScores] + leaderboard_math_hard: t.NotRequired[ParsedScores] + + +class LeaderboardV2Tasks(StrEnum): + MATH_HARD = "leaderboard_math_hard" + IFEVAL = "leaderboard_ifeval" + MMLU_PRO = "leaderboard_mmlu_pro" + GPQA = "leaderboard_gpqa" + MUSR = "leaderboard_musr" + BBH = "leaderboard_bbh" + + +class LeaderboardArgs(t.TypedDict): + model_path: str + num_gpus: int + tasks: t.List[str] + + +class TaskGrouping(t.TypedDict): + """ + Class used to group the tasks by their optimal runtime. + """ + + huggingface: t.List[str] + vllm: t.List[str] + + +# generative tasks go here +LEADERBOARD_V2_GENERATIVE_TASKS = [ + LeaderboardV2Tasks.MATH_HARD.value, + LeaderboardV2Tasks.IFEVAL.value, +] + +# all the MCQ-style tasks in leaderboard v2 +LEADERBOARD_V2_MCQ_TASKS = [ + LeaderboardV2Tasks.BBH.value, + LeaderboardV2Tasks.MUSR.value, + LeaderboardV2Tasks.GPQA.value, + LeaderboardV2Tasks.MMLU_PRO.value, +] + + +def evaluate_with_vllm(args: LeaderboardArgs) -> t.Dict[str, t.Any]: + os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "true" + results = simple_evaluate( + tasks=args["tasks"], + model="vllm", + model_args={ + "pretrained": args["model_path"], + "dtype": "float16", + "data_parallel_size": args["num_gpus"], + "gpu_memory_utilization": 0.8, + "max_model_len": 32768, + "disable_custom_all_reduce": True, + "enforce_eager": False, + }, + apply_chat_template=True, + fewshot_as_multiturn=True, + batch_size="auto", + ) + return results + + +def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): + os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "true" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" # hopefully nobody else is using this port + + accelerator = Accelerator() + device = accelerator.device + assert device.type == "cuda", f"device is not a cuda device: {device}" + + results = simple_evaluate( + model="hf", + model_args={ + "pretrained": args["model_path"], + "dtype": "float16", + "trust_remote_code": True, + }, + tasks=args["tasks"], + apply_chat_template=True, + fewshot_as_multiturn=True, + batch_size="auto", + device=f"cuda:{device.index}", + cache_requests=True, + ) + + print(f"Rank {rank} got results: {type(results)}, putting them in the bucket") + result_queue.put((rank, results)) + print(f"Rank {rank} done putting results in the bucket") + + # clear torch memory + gc.collect() + torch.cuda.empty_cache() + + print(f"Rank {rank} destroying process group") + dist.destroy_process_group() + + +def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: + # we need to use torch.multiprocessing to run each task in a separate process, + # and then combine the results + import torch.multiprocessing as mp + + num_processes = args["num_gpus"] + + # Create the context and queue within the same context + mp_ctx = mp.get_context("spawn") # Explicitly use spawn context + result_queue = mp_ctx.Queue() + + # Use the same context's Process + processes = [] + for rank in range(num_processes): + p = mp_ctx.Process( + target=worker, args=(rank, num_processes, args, result_queue) + ) + p.start() + processes.append(p) + + results = {} + for _ in range(num_processes): + print(f"[master] getting results from the bucket") + rank, result = result_queue.get() + results[rank] = result + print(f"[master] got results from rank {rank}") + + # Wait for all processes to complete + for p in processes: + p.join() + + # extract the result which is not None + assert len([res for res in results.values() if res is not None]) == 1, ( + "we expect exactly 1 process to return a results dict properly" + ) + results_dict = [res for res in results.values() if res is not None][0] + return results_dict + + +def get_score_by_metric(score_dict: t.Dict[str, t.Any], metric: str) -> t.Any: + extracted_value = None + for key, value in score_dict.items(): + if "," not in key: + continue + parsed_metric, _ = key.split(",") + if parsed_metric == metric: + extracted_value = value + break + + if not extracted_value: + if alias := score_dict.get("alias", None): + error_msg = ( + f"Failed to find a metric matching '{metric}' for task '{alias}'." + ) + else: + error_msg = f"Failed to find a metric matching '{metric}'." + error_msg += f"\nAvailable fields: {list(score_dict.keys())}" + raise ValueError(error_msg) + return extracted_value + + +def parse_multitask_results( + result_dict: t.Dict[str, t.Any], benchmark: str, metric: str +) -> ParsedScores: + """ + Parse out the results of the given benchmark into a single floating-point score that can be consumed. + + The rules are like this: for a multi-task benchmark, the entry matching the exact benchmark name contains nothing. + Everything else is a subtask and contains a score. + + The end result is an unweighted average of all the subtasks, as well as a per-subtask breakdown. + """ + parsed_scores = {"score": 0.0, "subtasks": {}} + subtask_scores = {} + target_subtasks = result_dict["group_subtasks"].get(benchmark) + if not target_subtasks: + raise ValueError(f"Couldnt find '{benchmark}' in the group_subtasks section") + + for subtask in target_subtasks: + # pull out the score + subtask_results = result_dict["results"][subtask] + subtask_score = get_score_by_metric(subtask_results, metric) + subtask_scores[subtask] = subtask_score + + # exit early, base case + if not subtask_scores: + return parsed_scores + + parsed_scores["score"] = sum(subtask_scores.values()) / len(subtask_scores) + parsed_scores["subtasks"] = subtask_scores + return parsed_scores + + +def parse_bbh(result_dict: t.Dict[str, t.Any]) -> ParsedScores: + """ + Parses out the bbh scores from the result dict + """ + parsed_scores = parse_multitask_results( + result_dict, LeaderboardV2Tasks.BBH.value, "acc_norm" + ) + assert len(parsed_scores["subtasks"]) == 24, ( + "there should be 24 subtasks of bbh run" + ) + return parsed_scores + + +def parse_mmlu_pro(result_dict: t.Dict[str, t.Any]) -> ParsedScores: + """ + Parses out the mmlu_pro scores from the result dict + """ + mmlu_pro_results = result_dict["results"].get("leaderboard_mmlu_pro", None) + return { + "score": get_score_by_metric(mmlu_pro_results, "acc"), + } + + +def parse_ifeval(result_dict: t.Dict[str, t.Any]) -> ParsedScores: + """ + Parses out the ifeval scores from the result dict. + In particular, we only compute the average between the strict prompts + """ + ifeval_results = result_dict["results"].get("leaderboard_ifeval", None) + if not ifeval_results: + raise ValueError( + f"Failed to find `leaderboard_ifeval` in scores. Available results: {list(result_dict.keys())}" + ) + + # The format of ifeval looks like this: + # { + # "alias": "leaderboard_ifeval", + # "prompt_level_strict_acc,none": 0.6876155268022182, + # "prompt_level_strict_acc_stderr,none": 0.019944386293758908, + # "inst_level_strict_acc,none": 0.7745803357314148, + # "inst_level_strict_acc_stderr,none": "N/A", + # "prompt_level_loose_acc,none": 0.722735674676525, + # "prompt_level_loose_acc_stderr,none": 0.019263706963479364, + # "inst_level_loose_acc,none": 0.8033573141486811, + # "inst_level_loose_acc_stderr,none": "N/A" + # } + # + + target_metrics = {"prompt_level_strict_acc", "inst_level_strict_acc"} + scores = [] + + for key, value in ifeval_results.items(): + if "," not in key or "stderr" in key: + continue + + metric, _ = key.split(",") + if metric in target_metrics: + scores.append(value) + target_metrics.remove(metric) + + assert len(scores) == 2, ( + f"there should only be 2 values extracted in ifeval, got: {len(scores)}" + ) + return { + "score": sum(scores) / 2, + } + + +def parse_musr(result_dict: t.Dict[str, t.Any]) -> ParsedScores: + """ + Parses out the musr scores from the result dict + """ + parsed_scores = parse_multitask_results( + result_dict, LeaderboardV2Tasks.MUSR.value, "acc_norm" + ) + assert len(parsed_scores["subtasks"]) == 3 + return parsed_scores + + +def parse_gpqa(result_dict: t.Dict[str, t.Any]) -> ParsedScores: + """ + Parses out the gpqa scores from the result dict + """ + parsed_scores = parse_multitask_results( + result_dict, LeaderboardV2Tasks.GPQA.value, "acc_norm" + ) + assert len(parsed_scores["subtasks"]) == 3, ( + f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" + ) + return parsed_scores + + +def parse_math_hard(result_dict: t.Dict[str, t.Any]) -> ParsedScores: + """ + Parses out the math_hard scores from the result dict. Result is an unweighted average. + """ + parsed_scores = parse_multitask_results( + result_dict, LeaderboardV2Tasks.MATH_HARD.value, "exact_match" + ) + assert len(parsed_scores["subtasks"]) == 7, ( + f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" + ) + return parsed_scores + + +def get_parser(subtask: str) -> t.Callable[[t.Dict, str, str], ParsedScores]: + parser_map = { + LeaderboardV2Tasks.BBH.value: parse_bbh, + LeaderboardV2Tasks.GPQA.value: parse_gpqa, + LeaderboardV2Tasks.IFEVAL.value: parse_ifeval, + LeaderboardV2Tasks.MATH_HARD.value: parse_math_hard, + LeaderboardV2Tasks.MMLU_PRO.value: parse_mmlu_pro, + LeaderboardV2Tasks.MUSR.value: parse_musr, + } + return parser_map[ + LeaderboardV2Tasks(subtask) + ] # this will either parse and map into the correct section, or error + + +def get_scores_from_result_dicts( + *result_dicts: t.List[t.Dict[str, t.Any]], +) -> t.Dict[str, ParsedScores]: + """ + Parse out the scores of all the subtasks of leaderboard and return. + """ + parsed_scores = {} + for result_dict in result_dicts: + benchmarks_we_got = set(result_dict["results"].keys()) + benchmarks_we_care_about = set( + LEADERBOARD_V2_GENERATIVE_TASKS + LEADERBOARD_V2_MCQ_TASKS + ) + benchmarks_to_parse = benchmarks_we_got & benchmarks_we_care_about + + # this is just a sanity check step + benchmarks_already_covered = set(parsed_scores.keys()) + overlapping_benchmarks = benchmarks_already_covered & benchmarks_to_parse + assert len(benchmarks_already_covered & benchmarks_to_parse) == 0, ( + f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" + ) + + # now actually add them + for benchmark in benchmarks_to_parse: + parse_benchmark_fn = get_parser(benchmark) + parsed_scores[benchmark] = parse_benchmark_fn(result_dict) + + return parsed_scores + + +def validate_output_path(output_file: str) -> None: + """ + Validates that we can write to the specified output path. + Creates parent directories if they don't exist. + + Args: + output_file: Path to the desired output file + + Raises: + ValueError: If the path is invalid or we don't have proper permissions + """ + if not output_file: + raise ValueError("Output file path cannot be empty") + + # Convert to Path object for easier handling + output_path = Path(output_file) + + try: + # Create parent directories if they don't exist + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Test if we can write to the file by opening it in append mode + # We don't actually write anything + output_path.open("a").close() + + except PermissionError: + raise ValueError(f"Permission denied: Cannot write to {output_file}") + except OSError as e: + raise ValueError(f"Invalid output path: {output_file}. Error: {str(e)}") + + +def validate_leaderboard_v2_tasks(tasks: t.List[str]): + invalid_tasks = set(tasks) - set( + LEADERBOARD_V2_GENERATIVE_TASKS + LEADERBOARD_V2_MCQ_TASKS + ) + if invalid_tasks: + raise ValueError( + f"the following tasks were provided but are not valid leaderboard tasks: {list(invalid_tasks)}.\n" + f"Supported tasks are: {LEADERBOARD_V2_GENERATIVE_TASKS + LEADERBOARD_V2_MCQ_TASKS}" + ) + + +def get_task_groupings(tasks: t.List[str]) -> TaskGrouping: + """ + Given a list of tasks, bucket them per their optimal runtime. + """ + task_grouping: TaskGrouping = { + "vllm": [task for task in tasks if task in LEADERBOARD_V2_GENERATIVE_TASKS], + "huggingface": [task for task in tasks if task in LEADERBOARD_V2_MCQ_TASKS], + } + overlapping_tasks = set(task_grouping["vllm"]) & set(task_grouping["huggingface"]) + assert not overlapping_tasks + return task_grouping + + +def calculate_overall_leaderboard_score(results: t.Dict[str, ParsedScores]) -> float: + """ + Given a dict with leaderboard metrics, compute the average of the scores + """ + all_scores = [res["score"] for res in results.values() if "score" in res] + + return sum(all_scores) / len(all_scores) if len(all_scores) > 0 else 0.0 + + +# here we assume that we can +class LeaderboardV2Evaluator(Evaluator): + """ + Evaluator for Open Leaderboard v2. + """ + + name = "leaderboard_v2" + + def __init__( + self, + model_path: str, + tasks: t.List[str] = None, + num_gpus: int = None, + output_file: str = None, + ): + self.model_path = model_path + if not cuda.is_available(): + raise ValueError( + "Running without CUDA is currently unsupported. Contributions are welcome." + ) + + # set whatever we need here + self.num_gpus = num_gpus + self.tasks = tasks + + # validate output file + self.output_file = output_file + self._results = None + self._lm_eval_results = [] # TODO: make it merge everything back into a single result + + @property + def results(self) -> LeaderboardV2EvalResult: + """ + Returns the results of the most reccent leaderboard evaluation. + + Returns: + LeaderboardV2EvalResult: A dict containing the overall leaderboard score and the breakdown per subtask. + """ + return self._results + + @property + def lm_eval_results(self) -> t.List[t.Dict[str, t.Any]]: + """ + Returns the results of the most recent leaderboard evaluation. + + Returns: + t.List[t.Dict[str, t.Any]]: A list of dicts containing the results of the most recent leaderboard evaluation. + """ + return self._lm_eval_results + + def save_to_file(self, output_file: str = None): + """ + Saves the results to a file. + + Args: + output_file: The path to the file to save the results to. + """ + if output_file is None: + output_file = self.output_file + if output_file is None: + raise ValueError("Output file path cannot be empty") + + # create the directory if it doesn't exist + output_dir = os.path.dirname(output_file) + os.makedirs(output_dir, exist_ok=True) + with open(output_file, "w") as f: + json.dump(self._results, f, indent=2) + + def run( + self, + model_path: str | None = None, + tasks: t.List[str] = None, + num_gpus: int = None, + output_file: str = None, + ) -> LeaderboardV2EvalResult: + """ + Run the Open LLM Leaderboard v2 evaluation. + + This function will use both HF transformers and inline vLLM to run the evaluation. + It will then parse the results and save them to a file. + + Args: + model_path: The path to the model to evaluate. + tasks: The tasks to evaluate. + num_gpus: The number of GPUs to use. + output_file: The path to the file to save the results to. + + Returns: + LeaderboardV2EvalResult: A dict containing the overall leaderboard score and the breakdown per subtask. + """ + model_path = self.model_path if model_path is None else model_path + tasks = self.tasks if not tasks else tasks + num_gpus = self.num_gpus if not num_gpus else num_gpus + output_file = self.output_file if not output_file else output_file + + # validation logic + # no need to validate model path -- the inference libraries will either be able to + # load it, or they won't + + validate_leaderboard_v2_tasks(tasks) + if not num_gpus: + num_gpus = cuda.device_count() + if num_gpus <= 0 or num_gpus > cuda.device_count(): + raise ValueError( + f"invalid value for num_gpus, must be between 1 and {cuda.device_count()}; got: {num_gpus}" + ) + if output_file: + validate_output_path(output_file) + + # now we just have to run the task group in their most appropriate runtime + # this is important because certain tasks like MCQ are better-suited to be + # excuted in raw transformers due to the lack of KV-Cache overhead, + # whereas generative tasks are better suited for vLLM due to their need for + # accessing previous tokens + + grouped_tasks = get_task_groupings(tasks) + self._lm_eval_results = [] + vllm_results, hf_results = None, None + if vllm_tasks := grouped_tasks["vllm"]: + args: LeaderboardArgs = { + "model_path": model_path, + "num_gpus": num_gpus, + "tasks": vllm_tasks, + } + vllm_results = evaluate_with_vllm(args) + self._lm_eval_results.append(vllm_results) + if hf_tasks := grouped_tasks["huggingface"]: + args: LeaderboardArgs = { + "model_path": model_path, + "num_gpus": num_gpus, + "tasks": hf_tasks, + } + hf_results = evaluate_with_hf(args) + self._lm_eval_results.append(hf_results) + + # convert the output of lm-eval into something that's already parsed + results = get_scores_from_result_dicts(*self._lm_eval_results) + results["overall_score"] = calculate_overall_leaderboard_score(results) + + self._results = results + self.save_to_file(output_file) + return results From 15e9f75bcd919c70bbd25f9d3856bc47711db887 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:21:56 +0000 Subject: [PATCH 02/10] formatting Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- src/instructlab/eval/leaderboard.py | 64 +++++++++++++++-------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 2014e8f5..fbd85327 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -1,19 +1,21 @@ -from lm_eval.evaluator import simple_evaluate -from .evaluator import Evaluator - - +# Standard +from enum import StrEnum from pathlib import Path +import gc import json -from lm_eval.evaluator import simple_evaluate -import typing as t import os -import torch.multiprocessing as mp +import typing as t + +# Third Party from accelerate import Accelerator -import torch.distributed as dist -import torch +from lm_eval.evaluator import simple_evaluate from torch import cuda -import gc -from enum import StrEnum +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +# Local +from .evaluator import Evaluator class ParsedScores(t.TypedDict): @@ -137,6 +139,7 @@ def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: # we need to use torch.multiprocessing to run each task in a separate process, # and then combine the results + # Third Party import torch.multiprocessing as mp num_processes = args["num_gpus"] @@ -166,9 +169,9 @@ def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: p.join() # extract the result which is not None - assert len([res for res in results.values() if res is not None]) == 1, ( - "we expect exactly 1 process to return a results dict properly" - ) + assert ( + len([res for res in results.values() if res is not None]) == 1 + ), "we expect exactly 1 process to return a results dict properly" results_dict = [res for res in results.values() if res is not None][0] return results_dict @@ -234,9 +237,9 @@ def parse_bbh(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.BBH.value, "acc_norm" ) - assert len(parsed_scores["subtasks"]) == 24, ( - "there should be 24 subtasks of bbh run" - ) + assert ( + len(parsed_scores["subtasks"]) == 24 + ), "there should be 24 subtasks of bbh run" return parsed_scores @@ -287,9 +290,9 @@ def parse_ifeval(result_dict: t.Dict[str, t.Any]) -> ParsedScores: scores.append(value) target_metrics.remove(metric) - assert len(scores) == 2, ( - f"there should only be 2 values extracted in ifeval, got: {len(scores)}" - ) + assert ( + len(scores) == 2 + ), f"there should only be 2 values extracted in ifeval, got: {len(scores)}" return { "score": sum(scores) / 2, } @@ -313,9 +316,9 @@ def parse_gpqa(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.GPQA.value, "acc_norm" ) - assert len(parsed_scores["subtasks"]) == 3, ( - f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" - ) + assert ( + len(parsed_scores["subtasks"]) == 3 + ), f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" return parsed_scores @@ -326,9 +329,9 @@ def parse_math_hard(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.MATH_HARD.value, "exact_match" ) - assert len(parsed_scores["subtasks"]) == 7, ( - f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" - ) + assert ( + len(parsed_scores["subtasks"]) == 7 + ), f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" return parsed_scores @@ -363,9 +366,9 @@ def get_scores_from_result_dicts( # this is just a sanity check step benchmarks_already_covered = set(parsed_scores.keys()) overlapping_benchmarks = benchmarks_already_covered & benchmarks_to_parse - assert len(benchmarks_already_covered & benchmarks_to_parse) == 0, ( - f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" - ) + assert ( + len(benchmarks_already_covered & benchmarks_to_parse) == 0 + ), f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" # now actually add them for benchmark in benchmarks_to_parse: @@ -579,5 +582,6 @@ def run( results["overall_score"] = calculate_overall_leaderboard_score(results) self._results = results - self.save_to_file(output_file) + if output_file: + self.save_to_file(output_file) return results From e2b41bde4c93923e90b187d8d206d4f2d884a21b Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:27:15 +0000 Subject: [PATCH 03/10] fix saving and add test script Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- scripts/test_leaderboard.py | 10 ++++++++ src/instructlab/eval/leaderboard.py | 39 +++++++++++++++-------------- 2 files changed, 30 insertions(+), 19 deletions(-) create mode 100644 scripts/test_leaderboard.py diff --git a/scripts/test_leaderboard.py b/scripts/test_leaderboard.py new file mode 100644 index 00000000..564e8baf --- /dev/null +++ b/scripts/test_leaderboard.py @@ -0,0 +1,10 @@ +# First Party +from instructlab.eval.leaderboard import LeaderboardV2Evaluator + +if __name__ == "__main__": + evaluator = LeaderboardV2Evaluator( + model_path="ibm-granite/granite-3.1-8b-instruct", + ) + results = evaluator.run() + print("got results from leaderboard v2") + print(json.dumps(results, indent=2)) diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index fbd85327..01554b5a 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -169,9 +169,9 @@ def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: p.join() # extract the result which is not None - assert ( - len([res for res in results.values() if res is not None]) == 1 - ), "we expect exactly 1 process to return a results dict properly" + assert len([res for res in results.values() if res is not None]) == 1, ( + "we expect exactly 1 process to return a results dict properly" + ) results_dict = [res for res in results.values() if res is not None][0] return results_dict @@ -237,9 +237,9 @@ def parse_bbh(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.BBH.value, "acc_norm" ) - assert ( - len(parsed_scores["subtasks"]) == 24 - ), "there should be 24 subtasks of bbh run" + assert len(parsed_scores["subtasks"]) == 24, ( + "there should be 24 subtasks of bbh run" + ) return parsed_scores @@ -290,9 +290,9 @@ def parse_ifeval(result_dict: t.Dict[str, t.Any]) -> ParsedScores: scores.append(value) target_metrics.remove(metric) - assert ( - len(scores) == 2 - ), f"there should only be 2 values extracted in ifeval, got: {len(scores)}" + assert len(scores) == 2, ( + f"there should only be 2 values extracted in ifeval, got: {len(scores)}" + ) return { "score": sum(scores) / 2, } @@ -316,9 +316,9 @@ def parse_gpqa(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.GPQA.value, "acc_norm" ) - assert ( - len(parsed_scores["subtasks"]) == 3 - ), f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" + assert len(parsed_scores["subtasks"]) == 3, ( + f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" + ) return parsed_scores @@ -329,9 +329,9 @@ def parse_math_hard(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.MATH_HARD.value, "exact_match" ) - assert ( - len(parsed_scores["subtasks"]) == 7 - ), f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" + assert len(parsed_scores["subtasks"]) == 7, ( + f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" + ) return parsed_scores @@ -366,9 +366,9 @@ def get_scores_from_result_dicts( # this is just a sanity check step benchmarks_already_covered = set(parsed_scores.keys()) overlapping_benchmarks = benchmarks_already_covered & benchmarks_to_parse - assert ( - len(benchmarks_already_covered & benchmarks_to_parse) == 0 - ), f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" + assert len(benchmarks_already_covered & benchmarks_to_parse) == 0, ( + f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" + ) # now actually add them for benchmark in benchmarks_to_parse: @@ -506,7 +506,8 @@ def save_to_file(self, output_file: str = None): # create the directory if it doesn't exist output_dir = os.path.dirname(output_file) - os.makedirs(output_dir, exist_ok=True) + if output_dir: + os.makedirs(output_dir, exist_ok=True) with open(output_file, "w") as f: json.dump(self._results, f, indent=2) From b43f697e1269061479174fb8c3dc77ca4721b732 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 20 Mar 2025 02:47:04 +0000 Subject: [PATCH 04/10] fix mypy errors Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- src/instructlab/eval/leaderboard.py | 93 +++++++++++++++++++---------- 1 file changed, 63 insertions(+), 30 deletions(-) diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 01554b5a..73ca7748 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -77,7 +77,6 @@ class TaskGrouping(t.TypedDict): def evaluate_with_vllm(args: LeaderboardArgs) -> t.Dict[str, t.Any]: - os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "true" results = simple_evaluate( tasks=args["tasks"], model="vllm", @@ -93,12 +92,12 @@ def evaluate_with_vllm(args: LeaderboardArgs) -> t.Dict[str, t.Any]: apply_chat_template=True, fewshot_as_multiturn=True, batch_size="auto", + confirm_run_unsafe_code=True, ) return results def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): - os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "true" os.environ["RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_RANK"] = str(rank) @@ -122,17 +121,15 @@ def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): batch_size="auto", device=f"cuda:{device.index}", cache_requests=True, + confirm_run_unsafe_code=True, ) - print(f"Rank {rank} got results: {type(results)}, putting them in the bucket") result_queue.put((rank, results)) - print(f"Rank {rank} done putting results in the bucket") # clear torch memory gc.collect() torch.cuda.empty_cache() - print(f"Rank {rank} destroying process group") dist.destroy_process_group() @@ -159,10 +156,8 @@ def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: results = {} for _ in range(num_processes): - print(f"[master] getting results from the bucket") rank, result = result_queue.get() results[rank] = result - print(f"[master] got results from rank {rank}") # Wait for all processes to complete for p in processes: @@ -209,7 +204,7 @@ def parse_multitask_results( The end result is an unweighted average of all the subtasks, as well as a per-subtask breakdown. """ - parsed_scores = {"score": 0.0, "subtasks": {}} + parsed_scores: ParsedScores = {"score": 0.0, "subtasks": {}} subtask_scores = {} target_subtasks = result_dict["group_subtasks"].get(benchmark) if not target_subtasks: @@ -335,7 +330,7 @@ def parse_math_hard(result_dict: t.Dict[str, t.Any]) -> ParsedScores: return parsed_scores -def get_parser(subtask: str) -> t.Callable[[t.Dict, str, str], ParsedScores]: +def get_parser(subtask: str) -> t.Callable[[t.Dict[str, t.Any]], ParsedScores]: parser_map = { LeaderboardV2Tasks.BBH.value: parse_bbh, LeaderboardV2Tasks.GPQA.value: parse_gpqa, @@ -349,13 +344,45 @@ def get_parser(subtask: str) -> t.Callable[[t.Dict, str, str], ParsedScores]: ] # this will either parse and map into the correct section, or error +def build_leaderboard_v2_result( + parsed_scores: t.Dict[str, ParsedScores], +) -> LeaderboardV2EvalResult: + """ + Build the leaderboard v2 result from the parsed scores. + """ + # now let's build the overall score + leaderboard_result: LeaderboardV2EvalResult = { + "overall_score": calculate_overall_leaderboard_score(parsed_scores), + } + + # explicitly set the score for each subtask in order to satisfy mypy + if "leaderboard_bbh" in parsed_scores: + leaderboard_result["leaderboard_bbh"] = parsed_scores["leaderboard_bbh"] + if "leaderboard_gpqa" in parsed_scores: + leaderboard_result["leaderboard_gpqa"] = parsed_scores["leaderboard_gpqa"] + if "leaderboard_ifeval" in parsed_scores: + leaderboard_result["leaderboard_ifeval"] = parsed_scores["leaderboard_ifeval"] + if "leaderboard_math_hard" in parsed_scores: + leaderboard_result["leaderboard_math_hard"] = parsed_scores[ + "leaderboard_math_hard" + ] + if "leaderboard_mmlu_pro" in parsed_scores: + leaderboard_result["leaderboard_mmlu_pro"] = parsed_scores[ + "leaderboard_mmlu_pro" + ] + if "leaderboard_musr" in parsed_scores: + leaderboard_result["leaderboard_musr"] = parsed_scores["leaderboard_musr"] + + return leaderboard_result + + def get_scores_from_result_dicts( - *result_dicts: t.List[t.Dict[str, t.Any]], -) -> t.Dict[str, ParsedScores]: + *result_dicts: t.Dict[str, t.Any], +) -> LeaderboardV2EvalResult: """ Parse out the scores of all the subtasks of leaderboard and return. """ - parsed_scores = {} + parsed_scores: t.Dict[str, ParsedScores] = {} for result_dict in result_dicts: benchmarks_we_got = set(result_dict["results"].keys()) benchmarks_we_care_about = set( @@ -375,7 +402,7 @@ def get_scores_from_result_dicts( parse_benchmark_fn = get_parser(benchmark) parsed_scores[benchmark] = parse_benchmark_fn(result_dict) - return parsed_scores + return build_leaderboard_v2_result(parsed_scores) def validate_output_path(output_file: str) -> None: @@ -453,9 +480,9 @@ class LeaderboardV2Evaluator(Evaluator): def __init__( self, model_path: str, - tasks: t.List[str] = None, - num_gpus: int = None, - output_file: str = None, + tasks: t.Optional[t.List[str]] = None, + num_gpus: t.Optional[int] = None, + output_file: t.Optional[str] = None, ): self.model_path = model_path if not cuda.is_available(): @@ -469,11 +496,13 @@ def __init__( # validate output file self.output_file = output_file - self._results = None - self._lm_eval_results = [] # TODO: make it merge everything back into a single result + self._results: t.Optional[LeaderboardV2EvalResult] = None + self._lm_eval_results: t.List[ + t.Dict[str, t.Any] + ] = [] # TODO: make it merge everything back into a single result @property - def results(self) -> LeaderboardV2EvalResult: + def results(self) -> t.Optional[LeaderboardV2EvalResult]: """ Returns the results of the most reccent leaderboard evaluation. @@ -492,7 +521,7 @@ def lm_eval_results(self) -> t.List[t.Dict[str, t.Any]]: """ return self._lm_eval_results - def save_to_file(self, output_file: str = None): + def save_to_file(self, output_file: t.Optional[str] = None) -> None: """ Saves the results to a file. @@ -513,10 +542,10 @@ def save_to_file(self, output_file: str = None): def run( self, - model_path: str | None = None, - tasks: t.List[str] = None, - num_gpus: int = None, - output_file: str = None, + model_path: t.Optional[str] = None, + tasks: t.Optional[t.List[str]] = None, + num_gpus: t.Optional[int] = None, + output_file: t.Optional[str] = None, ) -> LeaderboardV2EvalResult: """ Run the Open LLM Leaderboard v2 evaluation. @@ -538,6 +567,9 @@ def run( num_gpus = self.num_gpus if not num_gpus else num_gpus output_file = self.output_file if not output_file else output_file + if not tasks: + tasks = LEADERBOARD_V2_MCQ_TASKS + LEADERBOARD_V2_GENERATIVE_TASKS + # validation logic # no need to validate model path -- the inference libraries will either be able to # load it, or they won't @@ -562,25 +594,26 @@ def run( self._lm_eval_results = [] vllm_results, hf_results = None, None if vllm_tasks := grouped_tasks["vllm"]: - args: LeaderboardArgs = { + args_vllm: LeaderboardArgs = { "model_path": model_path, "num_gpus": num_gpus, "tasks": vllm_tasks, } - vllm_results = evaluate_with_vllm(args) + vllm_results = evaluate_with_vllm(args_vllm) self._lm_eval_results.append(vllm_results) if hf_tasks := grouped_tasks["huggingface"]: - args: LeaderboardArgs = { + args_hf: LeaderboardArgs = { "model_path": model_path, "num_gpus": num_gpus, "tasks": hf_tasks, } - hf_results = evaluate_with_hf(args) + hf_results = evaluate_with_hf(args_hf) self._lm_eval_results.append(hf_results) # convert the output of lm-eval into something that's already parsed - results = get_scores_from_result_dicts(*self._lm_eval_results) - results["overall_score"] = calculate_overall_leaderboard_score(results) + results: LeaderboardV2EvalResult = get_scores_from_result_dicts( + *self._lm_eval_results + ) self._results = results if output_file: From bd9567230eaf66d5991c3a464f0e73d3a2ced7d7 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 20 Mar 2025 03:06:38 +0000 Subject: [PATCH 05/10] enable users to override the default vLLM + HF settings, as well as options for the `simple_evaluate` function Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- src/instructlab/eval/leaderboard.py | 176 ++++++++++++++++++++++++---- 1 file changed, 152 insertions(+), 24 deletions(-) diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 73ca7748..2452a73f 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -5,6 +5,7 @@ import json import os import typing as t +from copy import deepcopy # Third Party from accelerate import Accelerator @@ -46,12 +47,17 @@ class LeaderboardV2Tasks(StrEnum): BBH = "leaderboard_bbh" -class LeaderboardArgs(t.TypedDict): +class LeaderboardArgsRequired(t.TypedDict): model_path: str num_gpus: int tasks: t.List[str] +class LeaderboardArgs(LeaderboardArgsRequired, total=False): + eval_config: t.Dict[str, t.Any] + backend_config: t.Dict[str, t.Any] + + class TaskGrouping(t.TypedDict): """ Class used to group the tasks by their optimal runtime. @@ -61,6 +67,30 @@ class TaskGrouping(t.TypedDict): vllm: t.List[str] +# Default configuration parameters for evaluation +DEFAULT_EVAL_CONFIG = { + "batch_size": "auto", + "apply_chat_template": True, + "fewshot_as_multiturn": True, + "confirm_run_unsafe_code": True, + "max_model_len": 32768, + "system_instruction": None, +} + +# Default backend-specific configuration parameters +DEFAULT_VLLM_CONFIG = { + "dtype": "float16", + "gpu_memory_utilization": 0.8, + "disable_custom_all_reduce": True, + "enforce_eager": False, +} + +DEFAULT_HF_CONFIG = { + "dtype": "float16", + "trust_remote_code": True, + "cache_requests": True, +} + # generative tasks go here LEADERBOARD_V2_GENERATIVE_TASKS = [ LeaderboardV2Tasks.MATH_HARD.value, @@ -77,22 +107,36 @@ class TaskGrouping(t.TypedDict): def evaluate_with_vllm(args: LeaderboardArgs) -> t.Dict[str, t.Any]: + # Start with default configurations + eval_config = deepcopy(DEFAULT_EVAL_CONFIG) + backend_config = deepcopy(DEFAULT_VLLM_CONFIG) + + # Override with user-provided configurations + if "eval_config" in args and args["eval_config"]: + eval_config.update(args["eval_config"]) + if "backend_config" in args and args["backend_config"]: + backend_config.update(args["backend_config"]) + + # Prepare model_args + model_args = { + "pretrained": args["model_path"], + "data_parallel_size": args["num_gpus"], + **backend_config, + } + + # Set max_model_len if provided in eval_config + if "max_model_len" in eval_config: + model_args["max_model_len"] = eval_config.pop("max_model_len") + + # Extract system_instruction if provided + system_instruction = eval_config.pop("system_instruction", None) + results = simple_evaluate( tasks=args["tasks"], model="vllm", - model_args={ - "pretrained": args["model_path"], - "dtype": "float16", - "data_parallel_size": args["num_gpus"], - "gpu_memory_utilization": 0.8, - "max_model_len": 32768, - "disable_custom_all_reduce": True, - "enforce_eager": False, - }, - apply_chat_template=True, - fewshot_as_multiturn=True, - batch_size="auto", - confirm_run_unsafe_code=True, + model_args=model_args, + system_instruction=system_instruction, + **eval_config, ) return results @@ -108,20 +152,33 @@ def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): device = accelerator.device assert device.type == "cuda", f"device is not a cuda device: {device}" + # Start with default configurations + eval_config = deepcopy(DEFAULT_EVAL_CONFIG) + backend_config = deepcopy(DEFAULT_HF_CONFIG) + + # Override with user-provided configurations + if "eval_config" in args and args["eval_config"]: + eval_config.update(args["eval_config"]) + if "backend_config" in args and args["backend_config"]: + backend_config.update(args["backend_config"]) + + # Prepare model_args + model_args = {"pretrained": args["model_path"], **backend_config} + + # Set max_model_len if provided in eval_config + if "max_model_len" in eval_config: + model_args["max_model_len"] = eval_config.pop("max_model_len") + + # Extract system_instruction if provided + system_instruction = eval_config.pop("system_instruction", None) + results = simple_evaluate( model="hf", - model_args={ - "pretrained": args["model_path"], - "dtype": "float16", - "trust_remote_code": True, - }, + model_args=model_args, tasks=args["tasks"], - apply_chat_template=True, - fewshot_as_multiturn=True, - batch_size="auto", device=f"cuda:{device.index}", - cache_requests=True, - confirm_run_unsafe_code=True, + system_instruction=system_instruction, + **eval_config, ) result_queue.put((rank, results)) @@ -483,7 +540,40 @@ def __init__( tasks: t.Optional[t.List[str]] = None, num_gpus: t.Optional[int] = None, output_file: t.Optional[str] = None, + eval_config: t.Optional[t.Dict[str, t.Any]] = None, + vllm_config: t.Optional[t.Dict[str, t.Any]] = None, + hf_config: t.Optional[t.Dict[str, t.Any]] = None, ): + """ + Initialize the evaluator. + + Args: + model_path: Path to the model to evaluate. + tasks: List of tasks to evaluate on. + num_gpus: Number of GPUs to use. + output_file: Path to save results to. + eval_config: Configuration for general evaluation parameters that apply to both backends. + Default values (can be overridden): + - batch_size: "auto" - Batch size for evaluation, or "auto" for automatic batching + - apply_chat_template: True - Whether to apply chat template formatting + - fewshot_as_multiturn: True - Whether to format few-shot examples as multi-turn conversations + - confirm_run_unsafe_code: True - Whether to run potentially unsafe code without confirmation + - max_model_len: 32768 - Maximum sequence length for the model + - system_instruction: None - Optional system instruction to prepend to prompts + vllm_config: Configuration for vLLM-specific parameters. + Default values (can be overridden): + - dtype: "float16" - Data type for model weights + - gpu_memory_utilization: 0.8 - Fraction of GPU memory to use + - disable_custom_all_reduce: True - Whether to disable custom all-reduce implementation + - enforce_eager: False - Whether to enforce eager execution + And any other vLLM parameters supported by simple_evaluate. + hf_config: Configuration for HuggingFace-specific parameters. + Default values (can be overridden): + - dtype: "float16" - Data type for model weights + - trust_remote_code: True - Whether to trust remote code in model loading + - cache_requests: True - Whether to cache requests + And any other HuggingFace parameters supported by simple_evaluate. + """ self.model_path = model_path if not cuda.is_available(): raise ValueError( @@ -494,6 +584,11 @@ def __init__( self.num_gpus = num_gpus self.tasks = tasks + # Store evaluation configurations + self.eval_config = eval_config or {} + self.vllm_config = vllm_config or {} + self.hf_config = hf_config or {} + # validate output file self.output_file = output_file self._results: t.Optional[LeaderboardV2EvalResult] = None @@ -546,6 +641,9 @@ def run( tasks: t.Optional[t.List[str]] = None, num_gpus: t.Optional[int] = None, output_file: t.Optional[str] = None, + eval_config: t.Optional[t.Dict[str, t.Any]] = None, + vllm_config: t.Optional[t.Dict[str, t.Any]] = None, + hf_config: t.Optional[t.Dict[str, t.Any]] = None, ) -> LeaderboardV2EvalResult: """ Run the Open LLM Leaderboard v2 evaluation. @@ -558,6 +656,27 @@ def run( tasks: The tasks to evaluate. num_gpus: The number of GPUs to use. output_file: The path to the file to save the results to. + eval_config: Configuration for general evaluation parameters that apply to both backends. + Default values (can be overridden): + - batch_size: "auto" - Batch size for evaluation, or "auto" for automatic batching + - apply_chat_template: True - Whether to apply chat template formatting + - fewshot_as_multiturn: True - Whether to format few-shot examples as multi-turn conversations + - confirm_run_unsafe_code: True - Whether to run potentially unsafe code without confirmation + - max_model_len: 32768 - Maximum sequence length for the model + - system_instruction: None - Optional system instruction to prepend to prompts + vllm_config: Configuration for vLLM-specific parameters. + Default values (can be overridden): + - dtype: "float16" - Data type for model weights + - gpu_memory_utilization: 0.8 - Fraction of GPU memory to use + - disable_custom_all_reduce: True - Whether to disable custom all-reduce implementation + - enforce_eager: False - Whether to enforce eager execution + And any other vLLM parameters supported by simple_evaluate. + hf_config: Configuration for HuggingFace-specific parameters. + Default values (can be overridden): + - dtype: "float16" - Data type for model weights + - trust_remote_code: True - Whether to trust remote code in model loading + - cache_requests: True - Whether to cache requests + And any other HuggingFace parameters supported by simple_evaluate. Returns: LeaderboardV2EvalResult: A dict containing the overall leaderboard score and the breakdown per subtask. @@ -567,6 +686,11 @@ def run( num_gpus = self.num_gpus if not num_gpus else num_gpus output_file = self.output_file if not output_file else output_file + # Merge configurations with instance configurations, with run-time configs taking precedence + final_eval_config = {**self.eval_config, **(eval_config or {})} + final_vllm_config = {**self.vllm_config, **(vllm_config or {})} + final_hf_config = {**self.hf_config, **(hf_config or {})} + if not tasks: tasks = LEADERBOARD_V2_MCQ_TASKS + LEADERBOARD_V2_GENERATIVE_TASKS @@ -598,6 +722,8 @@ def run( "model_path": model_path, "num_gpus": num_gpus, "tasks": vllm_tasks, + "eval_config": final_eval_config, + "backend_config": final_vllm_config, } vllm_results = evaluate_with_vllm(args_vllm) self._lm_eval_results.append(vllm_results) @@ -606,6 +732,8 @@ def run( "model_path": model_path, "num_gpus": num_gpus, "tasks": hf_tasks, + "eval_config": final_eval_config, + "backend_config": final_hf_config, } hf_results = evaluate_with_hf(args_hf) self._lm_eval_results.append(hf_results) From a257e92892f16b638649e9bfa6d53e2fb8ab7f46 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 20 Mar 2025 03:11:08 +0000 Subject: [PATCH 06/10] make cache_requests be a eval_config Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- src/instructlab/eval/leaderboard.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 2452a73f..3a35af07 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -75,6 +75,7 @@ class TaskGrouping(t.TypedDict): "confirm_run_unsafe_code": True, "max_model_len": 32768, "system_instruction": None, + "cache_requests": False, } # Default backend-specific configuration parameters @@ -88,7 +89,6 @@ class TaskGrouping(t.TypedDict): DEFAULT_HF_CONFIG = { "dtype": "float16", "trust_remote_code": True, - "cache_requests": True, } # generative tasks go here @@ -560,6 +560,7 @@ def __init__( - confirm_run_unsafe_code: True - Whether to run potentially unsafe code without confirmation - max_model_len: 32768 - Maximum sequence length for the model - system_instruction: None - Optional system instruction to prepend to prompts + - cache_requests: False - Whether to cache requests for the dataset vllm_config: Configuration for vLLM-specific parameters. Default values (can be overridden): - dtype: "float16" - Data type for model weights @@ -571,7 +572,6 @@ def __init__( Default values (can be overridden): - dtype: "float16" - Data type for model weights - trust_remote_code: True - Whether to trust remote code in model loading - - cache_requests: True - Whether to cache requests And any other HuggingFace parameters supported by simple_evaluate. """ self.model_path = model_path @@ -675,7 +675,6 @@ def run( Default values (can be overridden): - dtype: "float16" - Data type for model weights - trust_remote_code: True - Whether to trust remote code in model loading - - cache_requests: True - Whether to cache requests And any other HuggingFace parameters supported by simple_evaluate. Returns: From aa573d94350453f7c8767ca92c7d8c22018febed Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 20 Mar 2025 03:56:01 +0000 Subject: [PATCH 07/10] enable leaderboard to run with a remote openai provider Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- src/instructlab/eval/leaderboard.py | 212 ++++++++++++++++++++-------- 1 file changed, 153 insertions(+), 59 deletions(-) diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 3a35af07..6b02df91 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -65,6 +65,7 @@ class TaskGrouping(t.TypedDict): huggingface: t.List[str] vllm: t.List[str] + openai: t.List[str] # Default configuration parameters for evaluation @@ -73,7 +74,6 @@ class TaskGrouping(t.TypedDict): "apply_chat_template": True, "fewshot_as_multiturn": True, "confirm_run_unsafe_code": True, - "max_model_len": 32768, "system_instruction": None, "cache_requests": False, } @@ -84,11 +84,20 @@ class TaskGrouping(t.TypedDict): "gpu_memory_utilization": 0.8, "disable_custom_all_reduce": True, "enforce_eager": False, + "max_model_len": 32768, } DEFAULT_HF_CONFIG = { "dtype": "float16", "trust_remote_code": True, + "max_length": 32768, +} + +# 1. Add OpenAI configuration defaults +DEFAULT_OPENAI_CONFIG = { + "max_tokens": 768, + "temperature": 0.0, + "seed": 1337, } # generative tasks go here @@ -124,10 +133,6 @@ def evaluate_with_vllm(args: LeaderboardArgs) -> t.Dict[str, t.Any]: **backend_config, } - # Set max_model_len if provided in eval_config - if "max_model_len" in eval_config: - model_args["max_model_len"] = eval_config.pop("max_model_len") - # Extract system_instruction if provided system_instruction = eval_config.pop("system_instruction", None) @@ -165,10 +170,6 @@ def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): # Prepare model_args model_args = {"pretrained": args["model_path"], **backend_config} - # Set max_model_len if provided in eval_config - if "max_model_len" in eval_config: - model_args["max_model_len"] = eval_config.pop("max_model_len") - # Extract system_instruction if provided system_instruction = eval_config.pop("system_instruction", None) @@ -504,14 +505,24 @@ def validate_leaderboard_v2_tasks(tasks: t.List[str]): ) -def get_task_groupings(tasks: t.List[str]) -> TaskGrouping: +def get_task_groupings( + tasks: t.List[str], api_endpoint: t.Optional[str] = None +) -> TaskGrouping: """ Given a list of tasks, bucket them per their optimal runtime. + When an API endpoint is provided, all tasks are routed to OpenAI. """ + if api_endpoint: + # When using an API endpoint, route all tasks to OpenAI + return {"vllm": [], "huggingface": [], "openai": tasks} + + # Default behavior when no API endpoint is provided task_grouping: TaskGrouping = { "vllm": [task for task in tasks if task in LEADERBOARD_V2_GENERATIVE_TASKS], "huggingface": [task for task in tasks if task in LEADERBOARD_V2_MCQ_TASKS], + "openai": [], } + overlapping_tasks = set(task_grouping["vllm"]) & set(task_grouping["huggingface"]) assert not overlapping_tasks return task_grouping @@ -543,51 +554,63 @@ def __init__( eval_config: t.Optional[t.Dict[str, t.Any]] = None, vllm_config: t.Optional[t.Dict[str, t.Any]] = None, hf_config: t.Optional[t.Dict[str, t.Any]] = None, + openai_config: t.Optional[t.Dict[str, t.Any]] = None, + api_endpoint: t.Optional[str] = None, ): """ Initialize the evaluator. Args: - model_path: Path to the model to evaluate. + model_path: Path to the model to evaluate or model name for OpenAI API. tasks: List of tasks to evaluate on. - num_gpus: Number of GPUs to use. + num_gpus: Number of GPUs to use (ignored when using API endpoint). output_file: Path to save results to. - eval_config: Configuration for general evaluation parameters that apply to both backends. + eval_config: Configuration for general evaluation parameters that apply to all backends. Default values (can be overridden): - batch_size: "auto" - Batch size for evaluation, or "auto" for automatic batching - apply_chat_template: True - Whether to apply chat template formatting - fewshot_as_multiturn: True - Whether to format few-shot examples as multi-turn conversations - confirm_run_unsafe_code: True - Whether to run potentially unsafe code without confirmation - - max_model_len: 32768 - Maximum sequence length for the model - system_instruction: None - Optional system instruction to prepend to prompts - - cache_requests: False - Whether to cache requests for the dataset + - cache_requests: False - Whether to cache requests vllm_config: Configuration for vLLM-specific parameters. Default values (can be overridden): - dtype: "float16" - Data type for model weights - gpu_memory_utilization: 0.8 - Fraction of GPU memory to use - disable_custom_all_reduce: True - Whether to disable custom all-reduce implementation - enforce_eager: False - Whether to enforce eager execution + - max_model_len: 32768 - Maximum sequence length for the model And any other vLLM parameters supported by simple_evaluate. hf_config: Configuration for HuggingFace-specific parameters. Default values (can be overridden): - dtype: "float16" - Data type for model weights - trust_remote_code: True - Whether to trust remote code in model loading + - max_length: 32768 - Maximum sequence length for the model And any other HuggingFace parameters supported by simple_evaluate. + openai_config: Configuration for OpenAI-specific parameters. + Default values (can be overridden): + - max_tokens: 768 - Maximum tokens to generate + - temperature: 0.0 - Temperature for sampling + - seed: 1337 - Seed for reproducibility + api_endpoint: Optional OpenAI-compatible API endpoint. + When provided, tasks are evaluated using the OpenAI API instead of local models. """ self.model_path = model_path - if not cuda.is_available(): + if not api_endpoint and not cuda.is_available(): raise ValueError( - "Running without CUDA is currently unsupported. Contributions are welcome." + "Running without CUDA is currently unsupported unless using an API endpoint." ) # set whatever we need here self.num_gpus = num_gpus self.tasks = tasks + self.api_endpoint = api_endpoint # Store evaluation configurations self.eval_config = eval_config or {} self.vllm_config = vllm_config or {} self.hf_config = hf_config or {} + self.openai_config = openai_config or {} # validate output file self.output_file = output_file @@ -644,38 +667,49 @@ def run( eval_config: t.Optional[t.Dict[str, t.Any]] = None, vllm_config: t.Optional[t.Dict[str, t.Any]] = None, hf_config: t.Optional[t.Dict[str, t.Any]] = None, + openai_config: t.Optional[t.Dict[str, t.Any]] = None, + api_endpoint: t.Optional[str] = None, ) -> LeaderboardV2EvalResult: """ Run the Open LLM Leaderboard v2 evaluation. - This function will use both HF transformers and inline vLLM to run the evaluation. - It will then parse the results and save them to a file. + This function will use the appropriate backend based on the provided parameters: + - With api_endpoint: Uses the OpenAI API for all tasks + - Without api_endpoint: Uses both HF transformers and vLLM for optimal performance Args: - model_path: The path to the model to evaluate. + model_path: The path to the model to evaluate or model name for API. tasks: The tasks to evaluate. - num_gpus: The number of GPUs to use. + num_gpus: The number of GPUs to use (ignored when using API). output_file: The path to the file to save the results to. - eval_config: Configuration for general evaluation parameters that apply to both backends. + eval_config: Configuration for general evaluation parameters that apply to all backends. Default values (can be overridden): - batch_size: "auto" - Batch size for evaluation, or "auto" for automatic batching - apply_chat_template: True - Whether to apply chat template formatting - fewshot_as_multiturn: True - Whether to format few-shot examples as multi-turn conversations - confirm_run_unsafe_code: True - Whether to run potentially unsafe code without confirmation - - max_model_len: 32768 - Maximum sequence length for the model - system_instruction: None - Optional system instruction to prepend to prompts + - cache_requests: False - Whether to cache requests vllm_config: Configuration for vLLM-specific parameters. Default values (can be overridden): - dtype: "float16" - Data type for model weights - gpu_memory_utilization: 0.8 - Fraction of GPU memory to use - disable_custom_all_reduce: True - Whether to disable custom all-reduce implementation - enforce_eager: False - Whether to enforce eager execution + - max_model_len: 32768 - Maximum sequence length for the model And any other vLLM parameters supported by simple_evaluate. hf_config: Configuration for HuggingFace-specific parameters. Default values (can be overridden): - dtype: "float16" - Data type for model weights - trust_remote_code: True - Whether to trust remote code in model loading + - max_length: 32768 - Maximum sequence length for the model And any other HuggingFace parameters supported by simple_evaluate. + openai_config: Configuration for OpenAI-specific parameters. + Default values (can be overridden): + - max_tokens: 768 - Maximum tokens to generate + - temperature: 0.0 - Temperature for sampling + - seed: 1337 - Seed for reproducibility + api_endpoint: Optional OpenAI-compatible API endpoint. Returns: LeaderboardV2EvalResult: A dict containing the overall leaderboard score and the breakdown per subtask. @@ -684,60 +718,76 @@ def run( tasks = self.tasks if not tasks else tasks num_gpus = self.num_gpus if not num_gpus else num_gpus output_file = self.output_file if not output_file else output_file + api_endpoint = self.api_endpoint if api_endpoint is None else api_endpoint # Merge configurations with instance configurations, with run-time configs taking precedence final_eval_config = {**self.eval_config, **(eval_config or {})} final_vllm_config = {**self.vllm_config, **(vllm_config or {})} final_hf_config = {**self.hf_config, **(hf_config or {})} + final_openai_config = {**self.openai_config, **(openai_config or {})} + + # If API endpoint is provided, add it to the OpenAI config + if api_endpoint and "base_url" not in final_openai_config: + final_openai_config["base_url"] = api_endpoint if not tasks: tasks = LEADERBOARD_V2_MCQ_TASKS + LEADERBOARD_V2_GENERATIVE_TASKS # validation logic - # no need to validate model path -- the inference libraries will either be able to - # load it, or they won't - validate_leaderboard_v2_tasks(tasks) - if not num_gpus: - num_gpus = cuda.device_count() - if num_gpus <= 0 or num_gpus > cuda.device_count(): - raise ValueError( - f"invalid value for num_gpus, must be between 1 and {cuda.device_count()}; got: {num_gpus}" - ) + + # Only validate GPU requirements when not using an API endpoint + if not api_endpoint: + if not num_gpus: + num_gpus = cuda.device_count() + if num_gpus <= 0 or num_gpus > cuda.device_count(): + raise ValueError( + f"invalid value for num_gpus, must be between 1 and {cuda.device_count()}; got: {num_gpus}" + ) + if output_file: validate_output_path(output_file) - # now we just have to run the task group in their most appropriate runtime - # this is important because certain tasks like MCQ are better-suited to be - # excuted in raw transformers due to the lack of KV-Cache overhead, - # whereas generative tasks are better suited for vLLM due to their need for - # accessing previous tokens - - grouped_tasks = get_task_groupings(tasks) + # Group tasks by optimal runtime + grouped_tasks = get_task_groupings(tasks, api_endpoint) self._lm_eval_results = [] - vllm_results, hf_results = None, None - if vllm_tasks := grouped_tasks["vllm"]: - args_vllm: LeaderboardArgs = { - "model_path": model_path, - "num_gpus": num_gpus, - "tasks": vllm_tasks, - "eval_config": final_eval_config, - "backend_config": final_vllm_config, - } - vllm_results = evaluate_with_vllm(args_vllm) - self._lm_eval_results.append(vllm_results) - if hf_tasks := grouped_tasks["huggingface"]: - args_hf: LeaderboardArgs = { + + # Execute tasks using the appropriate backends + if openai_tasks := grouped_tasks["openai"]: + args_openai: LeaderboardArgs = { "model_path": model_path, - "num_gpus": num_gpus, - "tasks": hf_tasks, + "num_gpus": 1, # Not used for API calls but required by the type + "tasks": openai_tasks, "eval_config": final_eval_config, - "backend_config": final_hf_config, + "backend_config": final_openai_config, } - hf_results = evaluate_with_hf(args_hf) - self._lm_eval_results.append(hf_results) - - # convert the output of lm-eval into something that's already parsed + openai_results = evaluate_with_openai(args_openai) + self._lm_eval_results.append(openai_results) + else: + # Only run local evaluation if not using OpenAI API + if vllm_tasks := grouped_tasks["vllm"]: + args_vllm: LeaderboardArgs = { + "model_path": model_path, + "num_gpus": num_gpus, + "tasks": vllm_tasks, + "eval_config": final_eval_config, + "backend_config": final_vllm_config, + } + vllm_results = evaluate_with_vllm(args_vllm) + self._lm_eval_results.append(vllm_results) + + if hf_tasks := grouped_tasks["huggingface"]: + args_hf: LeaderboardArgs = { + "model_path": model_path, + "num_gpus": num_gpus, + "tasks": hf_tasks, + "eval_config": final_eval_config, + "backend_config": final_hf_config, + } + hf_results = evaluate_with_hf(args_hf) + self._lm_eval_results.append(hf_results) + + # Convert the output of lm-eval into something that's already parsed results: LeaderboardV2EvalResult = get_scores_from_result_dicts( *self._lm_eval_results ) @@ -746,3 +796,47 @@ def run( if output_file: self.save_to_file(output_file) return results + + +def evaluate_with_openai(args: LeaderboardArgs) -> t.Dict[str, t.Any]: + # Start with default configurations + eval_config = deepcopy(DEFAULT_EVAL_CONFIG) + backend_config = deepcopy(DEFAULT_OPENAI_CONFIG) + + # Override with user-provided configurations + if "eval_config" in args and args["eval_config"]: + eval_config.update(args["eval_config"]) + if "backend_config" in args and args["backend_config"]: + backend_config.update(args["backend_config"]) + + # Extract base_url and api_key from backend_config if provided + base_url = backend_config.pop("base_url", None) + api_key = backend_config.pop("api_key", None) + + # Build model_args for lm-eval's OpenAI client + model_args = { + "model": args["model_path"], # model name as recognized by the API + } + + # Add base_url if provided + if base_url: + model_args["base_url"] = base_url + + # Add API key if provided + if api_key: + model_args["api_key"] = api_key + + # Add any remaining backend config options + model_args.update(backend_config) + + # Extract system_instruction if provided + system_instruction = eval_config.pop("system_instruction", None) + + results = simple_evaluate( + tasks=args["tasks"], + model="openai", + model_args=model_args, + system_instruction=system_instruction, + **eval_config, + ) + return results From ce38464d503b4e82a7e37661a3a0e8da96086e27 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Thu, 20 Mar 2025 04:08:06 +0000 Subject: [PATCH 08/10] make the leaderboard dependencies into an optional target under instructlab-eval[leaderboard] Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- README.md | 10 ++++++++++ pyproject.toml | 1 + requirements.txt | 5 ++--- scripts/test_leaderboard.py | 12 +++++++++++- src/instructlab/eval/leaderboard.py | 3 +++ 5 files changed, 27 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e72b41ac..9a9900d4 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,16 @@ the phase. At the end of each phase, we evaluate all the checkpoints in order to Once training is complete, and we have picked the best checkpoint from the output of the final phase, we can run full-scale evaluation suite which runs MT-Bench, MMLU, MT-Bench Branch and MMLU Branch. +### Leaderboard Evaluation + +For cases when you want to run the full Open LLM Leaderboard v2 evaluation suite, we provide an optional dependency package for the leaderboard tasks. This includes additional benchmarks like GPQA, IFEVAL, BBH, MMLU-PRO, MUSR, and MATH-HARD. + +To install the optional leaderboard dependencies, use: + +```bash +pip install instructlab-eval[leaderboard] +``` + ## Methods of Evaluation Below are more in-depth explanations of the suite of benchmarks we are using as methods for evaluation of models. diff --git a/pyproject.toml b/pyproject.toml index 44a7de52..4fb00fa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ package-dir = {"" = "src"} [tool.setuptools.dynamic] dependencies = {file = ["requirements.txt"]} +optional-dependencies = {leaderboard = {file = ["requirements-leaderboard.txt"]}} [tool.setuptools.packages.find] where = ["src"] diff --git a/requirements.txt b/requirements.txt index 07839a7f..9e2002fb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,7 @@ transformers accelerate pandas pandas-stubs -# All optional dependencies like this can be found in lm-eval: -# https://github.com/EleutherAI/lm-evaluation-harness/blob/main/pyproject.toml -lm-eval[math,ifeval,sentencepiece,vllm]>=0.4.4 +# Base lm-eval dependency +lm-eval>=0.4.4 httpx ragas>=0.2.11 diff --git a/scripts/test_leaderboard.py b/scripts/test_leaderboard.py index 564e8baf..e5e2b770 100644 --- a/scripts/test_leaderboard.py +++ b/scripts/test_leaderboard.py @@ -1,9 +1,19 @@ +#!/usr/bin/env python +# SPDX-License-Identifier: Apache-2.0 + +# NOTE: This script requires the leaderboard optional dependencies. +# Install with: pip install instructlab-eval[leaderboard] + # First Party +import json from instructlab.eval.leaderboard import LeaderboardV2Evaluator if __name__ == "__main__": evaluator = LeaderboardV2Evaluator( - model_path="ibm-granite/granite-3.1-8b-instruct", + model_path="ibm-granite/granite-3.1-8b-base", + eval_config={ + "apply_chat_template": False, + }, ) results = evaluator.run() print("got results from leaderboard v2") diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 6b02df91..37c58477 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -541,6 +541,9 @@ def calculate_overall_leaderboard_score(results: t.Dict[str, ParsedScores]) -> f class LeaderboardV2Evaluator(Evaluator): """ Evaluator for Open Leaderboard v2. + + NOTE: This evaluator requires the optional leaderboard dependencies. + Install with: pip install instructlab-eval[leaderboard] """ name = "leaderboard_v2" From cd47eaacc425205248bbd4978a7794d4f817e6ed Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 31 Mar 2025 15:04:39 +0000 Subject: [PATCH 09/10] push up evaluation script Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- scripts/evaluate_best_checkpoint.py | 81 +++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 scripts/evaluate_best_checkpoint.py diff --git a/scripts/evaluate_best_checkpoint.py b/scripts/evaluate_best_checkpoint.py new file mode 100644 index 00000000..f8128f85 --- /dev/null +++ b/scripts/evaluate_best_checkpoint.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python3 + +""" +Example usage: +python scripts/evaluate_best_checkpoint.py \ + /path/to/checkpoint_dir \ + --output-file /path/to/output_file +""" + +import json +import typer +from pathlib import Path +from typing import Optional + + +app = typer.Typer() + + +@app.command() +def main( + input_dir: Path = typer.Argument(..., help="Input directory to process"), + output_file: Optional[Path] = typer.Option(None, help="Optional output file path"), +): + """ + Process files in the input directory and optionally save results to an output file. + """ + if not input_dir.exists(): + typer.echo(f"Error: Input directory '{input_dir}' does not exist") + raise typer.Exit(1) + + if not input_dir.is_dir(): + typer.echo(f"Error: '{input_dir}' is not a directory") + raise typer.Exit(1) + + checkpoint_dirs = list(input_dir.glob("hf_format/samples_*")) + typer.echo(f"Found {len(checkpoint_dirs)} samples files") + + if not checkpoint_dirs: + typer.echo( + f"No checkpoint directories found in the input directory: {input_dir}" + ) + raise typer.Exit(1) + + typer.echo("importing LeaderboardV2Evaluator, this may take a while...") + from instructlab.eval.leaderboard import LeaderboardV2Evaluator + + checkpoint_results = {} + for checkpoint in checkpoint_dirs: + typer.echo(f"Processing checkpoint: {checkpoint}") + ckpt_output_file = checkpoint / "leaderboard_results.json" + evaluator = LeaderboardV2Evaluator( + model_path=str(checkpoint), output_file=ckpt_output_file + ) + result = evaluator.run() + checkpoint_results[checkpoint.name] = result + typer.echo(f"Checkpoint {checkpoint.name} results: {result['score']}") + + # Sort checkpoints by score + sorted_checkpoints = sorted( + checkpoint_results.items(), key=lambda x: x[1]["score"], reverse=True + ) + typer.echo("Sorted checkpoints by score:") + for checkpoint_name, result in sorted_checkpoints: + typer.echo(f"{'=' * 100}") + typer.echo(json.dumps(result, indent=2)) + + typer.echo(f"{'=' * 100}") + typer.echo(f"Best checkpoint: {sorted_checkpoints[0][0]}") + + if output_file: + typer.echo(f"Output will be saved to: {output_file}") + with open(output_file, "w") as f: + json.dump(checkpoint_results, f, indent=2) + + # Add your processing logic here + + typer.echo("Processing complete!") + + +if __name__ == "__main__": + app() From 66fb8bb519a39ee33001d4951ec73fb4db9c5561 Mon Sep 17 00:00:00 2001 From: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> Date: Mon, 31 Mar 2025 16:31:01 +0000 Subject: [PATCH 10/10] add requirements file for leaderboard Signed-off-by: Oleg Silkin <97077423+RobotSail@users.noreply.github.com> --- requirements-leaderboard.txt | 10 +++ scripts/evaluate_best_checkpoint.py | 13 ++-- scripts/test_leaderboard.py | 4 +- src/instructlab/eval/leaderboard.py | 110 ++++++++++++++++------------ tests/test_project.py | 2 + tox.ini | 4 +- 6 files changed, 88 insertions(+), 55 deletions(-) create mode 100644 requirements-leaderboard.txt diff --git a/requirements-leaderboard.txt b/requirements-leaderboard.txt new file mode 100644 index 00000000..df2e19d0 --- /dev/null +++ b/requirements-leaderboard.txt @@ -0,0 +1,10 @@ +lm-eval[ifeval,vllm,math,sentencepiece]>=0.4.4 + +# vLLM 0.8.3 + torch 2.6.0 doesn't work when running vLLM on granite-3.1-8b-instruct +vllm<=0.7.3 +torch<=2.5.1 + +# XXX(osilkin): We use StrEnum in leaderboard, but Python3.10 doesn't have it as part of +# the standard library, so we have to install it from the older library. +strenum>=0.4.15; python_version < '3.11' +typing-extensions>=4.0.0; python_version < '3.11' diff --git a/scripts/evaluate_best_checkpoint.py b/scripts/evaluate_best_checkpoint.py index f8128f85..ff1aab3f 100644 --- a/scripts/evaluate_best_checkpoint.py +++ b/scripts/evaluate_best_checkpoint.py @@ -7,11 +7,13 @@ --output-file /path/to/output_file """ -import json -import typer +# Standard from pathlib import Path from typing import Optional +import json +# Third Party +import typer app = typer.Typer() @@ -42,6 +44,7 @@ def main( raise typer.Exit(1) typer.echo("importing LeaderboardV2Evaluator, this may take a while...") + # First Party from instructlab.eval.leaderboard import LeaderboardV2Evaluator checkpoint_results = {} @@ -49,15 +52,15 @@ def main( typer.echo(f"Processing checkpoint: {checkpoint}") ckpt_output_file = checkpoint / "leaderboard_results.json" evaluator = LeaderboardV2Evaluator( - model_path=str(checkpoint), output_file=ckpt_output_file + model_path=str(checkpoint), output_file=ckpt_output_file, num_gpus=8 ) result = evaluator.run() checkpoint_results[checkpoint.name] = result - typer.echo(f"Checkpoint {checkpoint.name} results: {result['score']}") + typer.echo(f"Checkpoint {checkpoint.name} results: {result['overall_score']}") # Sort checkpoints by score sorted_checkpoints = sorted( - checkpoint_results.items(), key=lambda x: x[1]["score"], reverse=True + checkpoint_results.items(), key=lambda x: x[1]["overall_score"], reverse=True ) typer.echo("Sorted checkpoints by score:") for checkpoint_name, result in sorted_checkpoints: diff --git a/scripts/test_leaderboard.py b/scripts/test_leaderboard.py index e5e2b770..2020bb6d 100644 --- a/scripts/test_leaderboard.py +++ b/scripts/test_leaderboard.py @@ -4,8 +4,10 @@ # NOTE: This script requires the leaderboard optional dependencies. # Install with: pip install instructlab-eval[leaderboard] -# First Party +# Standard import json + +# First Party from instructlab.eval.leaderboard import LeaderboardV2Evaluator if __name__ == "__main__": diff --git a/src/instructlab/eval/leaderboard.py b/src/instructlab/eval/leaderboard.py index 37c58477..ff2145ae 100644 --- a/src/instructlab/eval/leaderboard.py +++ b/src/instructlab/eval/leaderboard.py @@ -1,11 +1,10 @@ # Standard -from enum import StrEnum +from copy import deepcopy from pathlib import Path import gc import json import os import typing as t -from copy import deepcopy # Third Party from accelerate import Accelerator @@ -18,6 +17,22 @@ # Local from .evaluator import Evaluator +# Since StrEnum wasn't part of the STL until Python3.11, we must do this +try: + # Standard + from enum import StrEnum +except ImportError: + # Third Party + from strenum import StrEnum # type: ignore[no-redef] + +# And do the same thing to bring in NotRequired from typing +try: + # Standard + from typing import NotRequired +except ImportError: + # Third Party + from typing_extensions import NotRequired + class ParsedScores(t.TypedDict): """ @@ -25,17 +40,17 @@ class ParsedScores(t.TypedDict): """ score: float - subtasks: t.NotRequired[t.Dict[str, float]] + subtasks: NotRequired[t.Dict[str, float]] class LeaderboardV2EvalResult(t.TypedDict): overall_score: float - leaderboard_gpqa: t.NotRequired[ParsedScores] - leaderboard_ifeval: t.NotRequired[ParsedScores] - leaderboard_bbh: t.NotRequired[ParsedScores] - leaderboard_mmlu_pro: t.NotRequired[ParsedScores] - leaderboard_musr: t.NotRequired[ParsedScores] - leaderboard_math_hard: t.NotRequired[ParsedScores] + leaderboard_gpqa: NotRequired[ParsedScores] + leaderboard_ifeval: NotRequired[ParsedScores] + leaderboard_bbh: NotRequired[ParsedScores] + leaderboard_mmlu_pro: NotRequired[ParsedScores] + leaderboard_musr: NotRequired[ParsedScores] + leaderboard_math_hard: NotRequired[ParsedScores] class LeaderboardV2Tasks(StrEnum): @@ -94,7 +109,7 @@ class TaskGrouping(t.TypedDict): } # 1. Add OpenAI configuration defaults -DEFAULT_OPENAI_CONFIG = { +DEFAULT_OPENAI_CONFIG: t.Dict[str, t.Any] = { "max_tokens": 768, "temperature": 0.0, "seed": 1337, @@ -194,9 +209,6 @@ def worker(rank, world_size, args: LeaderboardArgs, result_queue: mp.Queue): def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: # we need to use torch.multiprocessing to run each task in a separate process, # and then combine the results - # Third Party - import torch.multiprocessing as mp - num_processes = args["num_gpus"] # Create the context and queue within the same context @@ -222,9 +234,9 @@ def evaluate_with_hf(args: LeaderboardArgs) -> t.Dict[str, t.Any]: p.join() # extract the result which is not None - assert len([res for res in results.values() if res is not None]) == 1, ( - "we expect exactly 1 process to return a results dict properly" - ) + assert ( + len([res for res in results.values() if res is not None]) == 1 + ), "we expect exactly 1 process to return a results dict properly" results_dict = [res for res in results.values() if res is not None][0] return results_dict @@ -290,9 +302,9 @@ def parse_bbh(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.BBH.value, "acc_norm" ) - assert len(parsed_scores["subtasks"]) == 24, ( - "there should be 24 subtasks of bbh run" - ) + assert ( + len(parsed_scores["subtasks"]) == 24 + ), "there should be 24 subtasks of bbh run" return parsed_scores @@ -343,9 +355,9 @@ def parse_ifeval(result_dict: t.Dict[str, t.Any]) -> ParsedScores: scores.append(value) target_metrics.remove(metric) - assert len(scores) == 2, ( - f"there should only be 2 values extracted in ifeval, got: {len(scores)}" - ) + assert ( + len(scores) == 2 + ), f"there should only be 2 values extracted in ifeval, got: {len(scores)}" return { "score": sum(scores) / 2, } @@ -369,9 +381,9 @@ def parse_gpqa(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.GPQA.value, "acc_norm" ) - assert len(parsed_scores["subtasks"]) == 3, ( - f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" - ) + assert ( + len(parsed_scores["subtasks"]) == 3 + ), f"Expected 3 gpqa scores, got {len(parsed_scores['subtasks'])}" return parsed_scores @@ -382,9 +394,9 @@ def parse_math_hard(result_dict: t.Dict[str, t.Any]) -> ParsedScores: parsed_scores = parse_multitask_results( result_dict, LeaderboardV2Tasks.MATH_HARD.value, "exact_match" ) - assert len(parsed_scores["subtasks"]) == 7, ( - f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" - ) + assert ( + len(parsed_scores["subtasks"]) == 7 + ), f"leaderboard_math_hard should have 7 subtasks, found: {len(parsed_scores['subtasks'])}" return parsed_scores @@ -451,9 +463,9 @@ def get_scores_from_result_dicts( # this is just a sanity check step benchmarks_already_covered = set(parsed_scores.keys()) overlapping_benchmarks = benchmarks_already_covered & benchmarks_to_parse - assert len(benchmarks_already_covered & benchmarks_to_parse) == 0, ( - f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" - ) + assert ( + len(benchmarks_already_covered & benchmarks_to_parse) == 0 + ), f"expected no overlapping benchmarks but found the following to overlap: {list(overlapping_benchmarks)}" # now actually add them for benchmark in benchmarks_to_parse: @@ -486,12 +498,15 @@ def validate_output_path(output_file: str) -> None: # Test if we can write to the file by opening it in append mode # We don't actually write anything - output_path.open("a").close() + with output_path.open("a", encoding="utf-8") as _: + pass - except PermissionError: - raise ValueError(f"Permission denied: Cannot write to {output_file}") - except OSError as e: - raise ValueError(f"Invalid output path: {output_file}. Error: {str(e)}") + except PermissionError as pe: + raise ValueError(f"Permission denied: Cannot write to {output_file}") from pe + except OSError as ose: + raise ValueError( + f"Invalid output path: {output_file}. Error: {str(ose)}" + ) from ose def validate_leaderboard_v2_tasks(tasks: t.List[str]): @@ -658,7 +673,7 @@ def save_to_file(self, output_file: t.Optional[str] = None) -> None: output_dir = os.path.dirname(output_file) if output_dir: os.makedirs(output_dir, exist_ok=True) - with open(output_file, "w") as f: + with open(output_file, "w", encoding="utf-8") as f: json.dump(self._results, f, indent=2) def run( @@ -739,15 +754,6 @@ def run( # validation logic validate_leaderboard_v2_tasks(tasks) - # Only validate GPU requirements when not using an API endpoint - if not api_endpoint: - if not num_gpus: - num_gpus = cuda.device_count() - if num_gpus <= 0 or num_gpus > cuda.device_count(): - raise ValueError( - f"invalid value for num_gpus, must be between 1 and {cuda.device_count()}; got: {num_gpus}" - ) - if output_file: validate_output_path(output_file) @@ -767,6 +773,14 @@ def run( openai_results = evaluate_with_openai(args_openai) self._lm_eval_results.append(openai_results) else: + # Only validate GPU requirements when not using an API endpoint + if not num_gpus: + num_gpus = cuda.device_count() + if num_gpus <= 0 or num_gpus > cuda.device_count(): + raise ValueError( + f"invalid value for num_gpus, must be between 1 and {cuda.device_count()}; got: {num_gpus}" + ) + # Only run local evaluation if not using OpenAI API if vllm_tasks := grouped_tasks["vllm"]: args_vllm: LeaderboardArgs = { @@ -823,11 +837,11 @@ def evaluate_with_openai(args: LeaderboardArgs) -> t.Dict[str, t.Any]: # Add base_url if provided if base_url: - model_args["base_url"] = base_url + model_args.update({"base_url": base_url}) # Add API key if provided if api_key: - model_args["api_key"] = api_key + model_args.update({"api_key": api_key}) # Add any remaining backend config options model_args.update(backend_config) diff --git a/tests/test_project.py b/tests/test_project.py index 13c4dbc8..46f863eb 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -4,6 +4,7 @@ # First Party from instructlab.eval.evaluator import Evaluator +from instructlab.eval.leaderboard import LeaderboardV2Evaluator from instructlab.eval.mmlu import MMLUBranchEvaluator, MMLUEvaluator from instructlab.eval.mt_bench import MTBenchBranchEvaluator, MTBenchEvaluator @@ -14,6 +15,7 @@ def test_evaluator_eps(): "mmlu_branch": MMLUBranchEvaluator, "mt_bench": MTBenchEvaluator, "mt_bench_branch": MTBenchBranchEvaluator, + "leaderboard_v2": LeaderboardV2Evaluator, } eps = entry_points(group="instructlab.eval.evaluator") found = {} diff --git a/tox.ini b/tox.ini index 5d41cb67..8adeebda 100644 --- a/tox.ini +++ b/tox.ini @@ -19,7 +19,9 @@ setenv = package = wheel wheel_build_env = pkg # equivalent to `pip install instructlab[cpu]` -extras = cpu +extras = + cpu + leaderboard deps = pytest pytest-asyncio