diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py
index 3959759d5..355305c81 100644
--- a/verifiers/envs/environment.py
+++ b/verifiers/envs/environment.py
@@ -36,6 +36,7 @@
     Messages,
     MessageType,
     ModelResponse,
+    OnGroupComplete,
     RolloutInput,
     RolloutTiming,
     SamplingArgs,
@@ -663,6 +664,7 @@ async def generate(
         save_results: bool = False,
         save_every: int = -1,
         use_tqdm: bool = True,
+        on_group_complete: OnGroupComplete | None = None,
     ) -> GenerateOutputs:
         """
         Generate rollouts for a set of inputs by group.
@@ -725,6 +727,7 @@ async def generate(
         )

         groups_completed = 0
+        total_groups = len(group_list)
         all_states: list[State] = []
         try:
             for coro in asyncio.as_completed(group_tasks.keys()):
@@ -735,7 +738,11 @@
                 if pbar is not None:
                     pbar.update(1)

-                # save intermediate results
+                if on_group_complete is not None:
+                    await on_group_complete(
+                        group_states, groups_completed, total_groups
+                    )
+
                 if (
                     save_results
                     and save_every > 0
@@ -842,6 +849,7 @@ async def evaluate(
         state_columns: list[str] | None = None,
         save_results: bool = False,
         save_every: int = -1,
+        on_group_complete: OnGroupComplete | None = None,
         **kwargs,
     ) -> GenerateOutputs:
         """
@@ -860,6 +868,7 @@
             state_columns=state_columns,
             save_results=save_results,
             save_every=save_every,
+            on_group_complete=on_group_complete,
             **kwargs,
         )

diff --git a/verifiers/types.py b/verifiers/types.py
index ab75bb91e..deac2c67d 100644
--- a/verifiers/types.py
+++ b/verifiers/types.py
@@ -49,6 +49,8 @@
 GroupRewardFunc = Callable[..., list[float] | Awaitable[list[float]]]
 RewardFunc = IndividualRewardFunc | GroupRewardFunc

+OnGroupComplete = Callable[["list[State]", int, int], Awaitable[None]]
+

 class TrajectoryStepTokens(TypedDict):
     prompt_ids: list[int]
diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py
index d8e1a9dc3..9dc100a4d 100644
--- a/verifiers/utils/eval_utils.py
+++ b/verifiers/utils/eval_utils.py
@@ -11,7 +11,13 @@
 from datasets.utils import logging as ds_logging

 import verifiers as vf
-from verifiers.types import Endpoints, EvalConfig, GenerateMetadata, GenerateOutputs
+from verifiers.types import (
+    Endpoints,
+    EvalConfig,
+    GenerateMetadata,
+    GenerateOutputs,
+    OnGroupComplete,
+)
 from verifiers.utils.client_utils import setup_client
 from verifiers.utils.logging_utils import print_prompt_completions_sample
 from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls
@@ -98,8 +104,10 @@ def print_results(results: GenerateOutputs, num_samples: int = 1):
     print(out)


-async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
-    # set up AsyncOpenAI client with high limits to prevent timeouts
+async def run_evaluation(
+    config: EvalConfig,
+    on_group_complete: OnGroupComplete | None = None,
+) -> GenerateOutputs:
     client = setup_client(
         config.client_config,
     )
@@ -107,10 +115,8 @@
         f"Initialized AsyncOpenAI client with base_url: {config.client_config.api_base_url}"
     )

-    # load environment
     vf_env = vf.load_environment(env_id=config.env_id, **config.env_args)

-    # run evaluation
     results_path = get_eval_results_path(config)
     logger.info(f"Starting evaluation with model: {config.model}")
     logger.info(
@@ -130,6 +136,7 @@
         state_columns=config.state_columns,
         save_results=config.save_results,
         save_every=config.save_every,
+        on_group_complete=on_group_complete,
     )
     end_time = time.time()
     logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")
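Not part of the patch: a minimal usage sketch of the new hook, assuming `State` is importable from `verifiers.types` (the `"list[State]"` forward reference above suggests it lives there) and that an `EvalConfig` is constructed elsewhere. `log_group_progress` and `main` are illustrative names, not code from this change.

```python
import asyncio

from verifiers.types import State
from verifiers.utils.eval_utils import run_evaluation


async def log_group_progress(
    group_states: list[State], groups_completed: int, total_groups: int
) -> None:
    # Matches the OnGroupComplete signature: awaited once per finished rollout
    # group with that group's states and the running progress counters, mirroring
    # the `await on_group_complete(group_states, groups_completed, total_groups)`
    # call site added in generate().
    print(
        f"{groups_completed}/{total_groups} groups done "
        f"({len(group_states)} states in this group)"
    )


async def main(config) -> None:
    # `config` is an EvalConfig built elsewhere; its fields are outside this diff.
    return await run_evaluation(config, on_group_complete=log_group_progress)


# asyncio.run(main(config)) once `config` has been constructed.
```

Since the callback is awaited inline in the `as_completed` loop, a slow hook delays handling (progress bar updates, intermediate saves) of subsequently completed groups, although the group rollouts themselves keep running as already-scheduled tasks.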