verifiers/envs/environment.py (10 additions, 1 deletion)

@@ -36,6 +36,7 @@
     Messages,
     MessageType,
     ModelResponse,
+    OnGroupComplete,
     RolloutInput,
     RolloutTiming,
     SamplingArgs,

@@ -663,6 +664,7 @@ async def generate(
         save_results: bool = False,
         save_every: int = -1,
         use_tqdm: bool = True,
+        on_group_complete: OnGroupComplete | None = None,
     ) -> GenerateOutputs:
         """
         Generate rollouts for a set of inputs by group.

@@ -725,6 +727,7 @@
         )

         groups_completed = 0
+        total_groups = len(group_list)
         all_states: list[State] = []
         try:
             for coro in asyncio.as_completed(group_tasks.keys()):

@@ -735,7 +738,11 @@
                 if pbar is not None:
                     pbar.update(1)

-                # save intermediate results
+                if on_group_complete is not None:
+                    await on_group_complete(
+                        group_states, groups_completed, total_groups
+                    )
+
                 if (
                     save_results
                     and save_every > 0

@@ -842,6 +849,7 @@ async def evaluate(
         state_columns: list[str] | None = None,
         save_results: bool = False,
         save_every: int = -1,
+        on_group_complete: OnGroupComplete | None = None,
         **kwargs,
     ) -> GenerateOutputs:
         """

@@ -860,6 +868,7 @@
             state_columns=state_columns,
             save_results=save_results,
             save_every=save_every,
+            on_group_complete=on_group_complete,
             **kwargs,
         )
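Net effect: generate() (and evaluate(), which forwards the argument) now accepts an optional awaitable hook that fires after each group of rollouts resolves, e.g. for live progress reporting or streaming partial results. A minimal sketch of a consumer; `env`, `client`, and `inputs` are assumed to exist, and the call site is illustrative rather than taken from this PR:

from verifiers.types import State

async def report_progress(
    group_states: list[State], completed: int, total: int
) -> None:
    # Fires once per finished group with that group's states and running counts.
    print(f"groups done: {completed}/{total} ({len(group_states)} states in this group)")

# Hypothetical call site, inside an async context:
results = await env.generate(
    inputs,
    client=client,
    on_group_complete=report_progress,
)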
verifiers/types.py (2 additions, 0 deletions)

@@ -49,6 +49,8 @@
 GroupRewardFunc = Callable[..., list[float] | Awaitable[list[float]]]
 RewardFunc = IndividualRewardFunc | GroupRewardFunc

+OnGroupComplete = Callable[["list[State]", int, int], Awaitable[None]]
+

 class TrajectoryStepTokens(TypedDict):
     prompt_ids: list[int]
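Any async function of this shape satisfies the alias; the quoted "list[State]" is a string forward reference, presumably because State is defined later in the module. A sketch of a conforming callback, with hypothetical names:

from verifiers.types import OnGroupComplete, State

async def on_done(states: "list[State]", completed: int, total: int) -> None:
    # Returns nothing; the caller awaits this purely for its side effects.
    ...

hook: OnGroupComplete = on_done  # assignment type-checks against the alias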
verifiers/utils/eval_utils.py (12 additions, 5 deletions)

@@ -11,7 +11,13 @@
 from datasets.utils import logging as ds_logging

 import verifiers as vf
-from verifiers.types import Endpoints, EvalConfig, GenerateMetadata, GenerateOutputs
+from verifiers.types import (
+    Endpoints,
+    EvalConfig,
+    GenerateMetadata,
+    GenerateOutputs,
+    OnGroupComplete,
+)
 from verifiers.utils.client_utils import setup_client
 from verifiers.utils.logging_utils import print_prompt_completions_sample
 from verifiers.utils.message_utils import messages_to_printable, sanitize_tool_calls

@@ -98,19 +104,19 @@ def print_results(results: GenerateOutputs, num_samples: int = 1):
     print(out)


-async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
-    # set up AsyncOpenAI client with high limits to prevent timeouts
+async def run_evaluation(
+    config: EvalConfig,
+    on_group_complete: OnGroupComplete | None = None,
+) -> GenerateOutputs:
     client = setup_client(
         config.client_config,
     )
     logger.debug(
         f"Initialized AsyncOpenAI client with base_url: {config.client_config.api_base_url}"
     )

-    # load environment
     vf_env = vf.load_environment(env_id=config.env_id, **config.env_args)

-    # run evaluation
     results_path = get_eval_results_path(config)
     logger.info(f"Starting evaluation with model: {config.model}")
     logger.info(

@@ -130,6 +136,7 @@ async def run_evaluation(config: EvalConfig) -> GenerateOutputs:
         state_columns=config.state_columns,
         save_results=config.save_results,
         save_every=config.save_every,
+        on_group_complete=on_group_complete,
     )
     end_time = time.time()
     logger.info(f"Evaluation completed in {end_time - start_time:.2f} seconds")
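End to end, the same callback can now be threaded through run_evaluation. A sketch of driving an evaluation with live progress, assuming `config` is an EvalConfig built elsewhere (its construction is not shown in this diff):

import asyncio

from verifiers.types import State
from verifiers.utils.eval_utils import run_evaluation

async def log_progress(states: list[State], completed: int, total: int) -> None:
    print(f"[{completed}/{total}] finished a group of {len(states)} rollouts")

# `config: EvalConfig` is assumed to exist; run_evaluation sets up its own client.
results = asyncio.run(run_evaluation(config, on_group_complete=log_progress))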