
Commit 51e574f

added mypy to pre-commit and improved typing (#54)
* added mypy to pre-commit and improved typing
* go green
* reset notebook
* fix security vulnerabilities

1 parent 91585f8 commit 51e574f


44 files changed: +661, -563 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -11,3 +11,4 @@ results/
 poetry.lock
 CLAUDE.md
 **/CLAUDE.local.md
+.mypy_cache/

.pre-commit-config.yaml

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,16 @@ repos:
     rev: 5.12.0
     hooks:
       - id: isort
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.8.0
+    hooks:
+      - id: mypy
+        files: ^promptolution/
+        additional_dependencies:
+          - types-requests
+          - pandas-stubs
+          - numpy
+        args: [--explicit-package-bases, --config-file=pyproject.toml]
   - repo: https://github.com/pycqa/pydocstyle
     rev: 6.3.0
     hooks:
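The new hook runs mypy only on files under `promptolution/`, pulls in the listed stub packages as `additional_dependencies`, and reads its settings from `pyproject.toml`. Most of the typing changes in the files below address mypy's no-implicit-Optional rule (the default since mypy 0.990): a parameter annotated with a concrete type may no longer default to `None`. A minimal sketch of that rule, with a hypothetical `ExperimentConfig` stand-in rather than the library's class:

```python
# Minimal sketch of the no_implicit_optional rule the new mypy hook enforces.
# ExperimentConfig is a hypothetical stand-in, not promptolution's class.
from typing import Optional


class ExperimentConfig:
    ...


def before(config: ExperimentConfig = None):  # flagged by mypy: default None requires Optional
    return config


def after(config: Optional[ExperimentConfig] = None) -> None:  # accepted
    return None
```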

promptolution/exemplar_selectors/base_exemplar_selector.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@

 from abc import ABC, abstractmethod

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 if TYPE_CHECKING:  # pragma: no cover
     from promptolution.predictors.base_predictor import BasePredictor
@@ -18,7 +18,7 @@ class BaseExemplarSelector(ABC):
     that all exemplar selectors should implement.
     """

-    def __init__(self, task: "BaseTask", predictor: "BasePredictor", config: "ExperimentConfig" = None):
+    def __init__(self, task: "BaseTask", predictor: "BasePredictor", config: Optional["ExperimentConfig"] = None):
         """Initialize the BaseExemplarSelector.

         Args:

promptolution/exemplar_selectors/random_search_selector.py

Lines changed: 7 additions & 6 deletions
@@ -10,28 +10,29 @@ class RandomSearchSelector(BaseExemplarSelector):
     evaluates their performance, and selects the best performing set.
     """

-    def select_exemplars(self, prompt, n_examples: int = 5, n_trials: int = 5):
+    def select_exemplars(self, prompt: str, n_trials: int = 5) -> str:
         """Select exemplars using a random search strategy.

         This method generates multiple sets of random examples, evaluates their performance
         when combined with the original prompt, and returns the best performing set.

         Args:
             prompt (str): The input prompt to base the exemplar selection on.
-            n_examples (int, optional): The number of exemplars to select in each trial. Defaults to 5.
             n_trials (int, optional): The number of random trials to perform. Defaults to 5.

         Returns:
             str: The best performing prompt, which includes the original prompt and the selected exemplars.
         """
-        best_score = 0
+        best_score = 0.0
         best_prompt = prompt

         for _ in range(n_trials):
-            _, seq = self.task.evaluate(prompt, self.predictor, n_samples=n_examples, subsample=True, return_seq=True)
-            prompt_with_examples = "\n\n".join([prompt] + seq) + "\n\n"
+            _, seq = self.task.evaluate(
+                prompt, self.predictor, eval_strategy="subsample", return_seq=True, return_agg_scores=False
+            )
+            prompt_with_examples = "\n\n".join([prompt] + [seq[0][0]]) + "\n\n"
             # evaluate prompts as few shot prompt
-            score = self.task.evaluate(prompt_with_examples, self.predictor, subsample=True)
+            score = self.task.evaluate(prompt_with_examples, self.predictor, eval_strategy="subsample")[0]
             if score > best_score:
                 best_score = score
                 best_prompt = prompt_with_examples
promptolution/exemplar_selectors/random_selector.py

Lines changed: 17 additions & 9 deletions
@@ -1,6 +1,8 @@
 """Random exemplar selector."""

-from typing import TYPE_CHECKING
+import numpy as np
+
+from typing import TYPE_CHECKING, List, Optional

 from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector

@@ -18,8 +20,12 @@ class RandomSelector(BaseExemplarSelector):
     """

     def __init__(
-        self, task: "BaseTask", predictor: "BasePredictor", desired_score: int = 1, config: "ExperimentConfig" = None
-    ):
+        self,
+        task: "BaseTask",
+        predictor: "BasePredictor",
+        desired_score: int = 1,
+        config: Optional["ExperimentConfig"] = None,
+    ) -> None:
         """Initialize the RandomSelector.

         Args:
@@ -44,11 +50,13 @@ def select_exemplars(self, prompt: str, n_examples: int = 5) -> str:
         Returns:
             str: A new prompt that includes the original prompt and the selected exemplars.
         """
-        examples = []
+        examples: List[str] = []
         while len(examples) < n_examples:
-            score, seq = self.task.evaluate(prompt, self.predictor, n_samples=1, return_seq=True)
+            scores, seqs = self.task.evaluate(
+                prompt, self.predictor, eval_strategy="subsample", return_seq=True, return_agg_scores=False
+            )
+            score = np.mean(scores)
+            seq = seqs[0][0]
             if score == self.desired_score:
-                examples.append(seq[0])
-            prompt = "\n\n".join([prompt] + examples) + "\n\n"
-
-        return prompt
+                examples.append(seq)
+        return "\n\n".join([prompt] + examples) + "\n\n"

promptolution/helpers.py

Lines changed: 44 additions & 42 deletions
@@ -1,7 +1,7 @@
 """Helper functions for the usage of the libary."""


-from typing import TYPE_CHECKING, Callable, List, Literal
+from typing import TYPE_CHECKING, Callable, List, Literal, Optional

 from promptolution.tasks.judge_tasks import JudgeTask
 from promptolution.tasks.reward_tasks import RewardTask
@@ -45,7 +45,7 @@
 logger = get_logger(__name__)


-def run_experiment(df: pd.DataFrame, config: "ExperimentConfig"):
+def run_experiment(df: pd.DataFrame, config: "ExperimentConfig") -> pd.DataFrame:
     """Run a full experiment based on the provided configuration.

     Args:
@@ -79,7 +79,7 @@ def run_optimization(df: pd.DataFrame, config: "ExperimentConfig") -> List[str]:
     llm = get_llm(config=config)
     predictor = get_predictor(llm, config=config)

-    config.task_description = config.task_description + " " + predictor.extraction_description
+    config.task_description = (config.task_description or "") + " " + (predictor.extraction_description or "")
     if config.optimizer == "capo" and (config.eval_strategy is None or "block" not in config.eval_strategy):
         logger.warning("📌 CAPO requires block evaluation strategy. Setting it to 'sequential_block'.")
         config.eval_strategy = "sequential_block"
@@ -126,7 +126,7 @@ def run_evaluation(df: pd.DataFrame, config: "ExperimentConfig", prompts: List[s
     return df


-def get_llm(model_id: str = None, config: "ExperimentConfig" = None) -> "BaseLLM":
+def get_llm(model_id: Optional[str] = None, config: Optional["ExperimentConfig"] = None) -> "BaseLLM":
     """Factory function to create and return a language model instance based on the provided model_id.

     This function supports three types of language models:
@@ -144,16 +144,18 @@ def get_llm(model_id: str = None, config: "ExperimentConfig" = None) -> "BaseLLM
     Returns:
         An instance of LocalLLM, or APILLM based on the model_id.
     """
-    if model_id is None:
-        model_id = config.model_id
-    if "local" in model_id:
-        model_id = "-".join(model_id.split("-")[1:])
-        return LocalLLM(model_id, config)
-    if "vllm" in model_id:
-        model_id = "-".join(model_id.split("-")[1:])
-        return VLLM(model_id, config=config)
+    final_model_id = model_id or (config.model_id if config else None)
+    if not final_model_id:
+        raise ValueError("model_id must be provided either directly or through config.")

-    return APILLM(model_id=model_id, config=config)
+    if "local" in final_model_id:
+        model_name = "-".join(final_model_id.split("-")[1:])
+        return LocalLLM(model_name, config=config)
+    if "vllm" in final_model_id:
+        model_name = "-".join(final_model_id.split("-")[1:])
+        return VLLM(model_name, config=config)
+
+    return APILLM(model_id=final_model_id, config=config)


 def get_task(
@@ -174,16 +176,19 @@ def get_task(
     Returns:
         BaseTask: An instance of a task class based on the provided DataFrame and configuration.
     """
-    if task_type is None:
-        task_type = config.task_type
+    final_task_type = task_type or (config.task_type if config else None)

-    if task_type == "reward":
+    if final_task_type == "reward":
+        if reward_function is None:
+            reward_function = config.reward_function if config else None
+        assert reward_function is not None, "Reward function must be provided for reward tasks."
         return RewardTask(
             df=df,
             reward_function=reward_function,
             config=config,
         )
-    elif task_type == "judge":
+    elif final_task_type == "judge":
+        assert judge_llm is not None, "Judge LLM must be provided for judge tasks."
         return JudgeTask(df, judge_llm=judge_llm, config=config)

     return ClassificationTask(df, config=config)
@@ -193,10 +198,9 @@ def get_optimizer(
     predictor: "BasePredictor",
     meta_llm: "BaseLLM",
     task: "BaseTask",
-    optimizer: OptimizerType = None,
-    meta_prompt: str = None,
-    task_description: str = None,
-    config: "ExperimentConfig" = None,
+    optimizer: Optional[OptimizerType] = None,
+    task_description: Optional[str] = None,
+    config: Optional["ExperimentConfig"] = None,
 ) -> "BaseOptimizer":
     """Creates and returns an optimizer instance based on provided parameters.

@@ -215,22 +219,18 @@ def get_optimizer(
     Raises:
         ValueError: If an unknown optimizer type is specified
     """
-    if optimizer is None:
-        optimizer = config.optimizer
-    if task_description is None:
-        task_description = config.task_description
-    if meta_prompt is None and hasattr(config, "meta_prompt"):
-        meta_prompt = config.meta_prompt
-
-    if config.optimizer == "capo":
+    final_optimizer = optimizer or (config.optimizer if config else None)
+    final_task_description = task_description or (config.task_description if config else None)
+
+    if final_optimizer == "capo":
         crossover_template = (
-            CAPO_CROSSOVER_TEMPLATE.replace("<task_desc>", task_description)
-            if task_description
+            CAPO_CROSSOVER_TEMPLATE.replace("<task_desc>", final_task_description)
+            if final_task_description
             else CAPO_CROSSOVER_TEMPLATE
         )
         mutation_template = (
-            CAPO_MUTATION_TEMPLATE.replace("<task_desc>", task_description)
-            if task_description
+            CAPO_MUTATION_TEMPLATE.replace("<task_desc>", final_task_description)
+            if final_task_description
            else CAPO_MUTATION_TEMPLATE
         )

@@ -243,27 +243,29 @@ def get_optimizer(
             config=config,
         )

-    if config.optimizer == "evopromptde":
+    if final_optimizer == "evopromptde":
         template = (
-            EVOPROMPT_DE_TEMPLATE_TD.replace("<task_desc>", task_description)
-            if task_description
+            EVOPROMPT_DE_TEMPLATE_TD.replace("<task_desc>", final_task_description)
+            if final_task_description
             else EVOPROMPT_DE_TEMPLATE
         )
         return EvoPromptDE(predictor=predictor, meta_llm=meta_llm, task=task, prompt_template=template, config=config)

-    if config.optimizer == "evopromptga":
+    if final_optimizer == "evopromptga":
         template = (
-            EVOPROMPT_GA_TEMPLATE_TD.replace("<task_desc>", task_description)
-            if task_description
+            EVOPROMPT_GA_TEMPLATE_TD.replace("<task_desc>", final_task_description)
+            if final_task_description
             else EVOPROMPT_GA_TEMPLATE
         )
         return EvoPromptGA(predictor=predictor, meta_llm=meta_llm, task=task, prompt_template=template, config=config)

-    if config.optimizer == "opro":
-        template = OPRO_TEMPLATE_TD.replace("<task_desc>", task_description) if task_description else OPRO_TEMPLATE
+    if final_optimizer == "opro":
+        template = (
+            OPRO_TEMPLATE_TD.replace("<task_desc>", final_task_description) if final_task_description else OPRO_TEMPLATE
+        )
         return OPRO(predictor=predictor, meta_llm=meta_llm, task=task, prompt_template=template, config=config)

-    raise ValueError(f"Unknown optimizer: {config.optimizer}")
+    raise ValueError(f"Unknown optimizer: {final_optimizer}")


 def get_exemplar_selector(
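The `get_llm` rewrite keeps the existing routing convention: a `local-` or `vllm-` prefix on `model_id` selects the backend and is stripped before the model name is passed on, and everything else goes to `APILLM`. A standalone sketch of just that string routing, with the backend classes replaced by labels and the model ids invented for illustration:

```python
# Sketch of the prefix routing in get_llm; backends are plain strings here and the
# model ids are made up. The stripping mirrors "-".join(model_id.split("-")[1:]).
from typing import Optional


def resolve_backend(model_id: Optional[str] = None, config_model_id: Optional[str] = None) -> str:
    final_model_id = model_id or config_model_id
    if not final_model_id:
        raise ValueError("model_id must be provided either directly or through config.")
    if "local" in final_model_id:
        return f"LocalLLM('{'-'.join(final_model_id.split('-')[1:])}')"
    if "vllm" in final_model_id:
        return f"VLLM('{'-'.join(final_model_id.split('-')[1:])}')"
    return f"APILLM('{final_model_id}')"


print(resolve_backend("local-meta-llama-3-8b-instruct"))  # LocalLLM('meta-llama-3-8b-instruct')
print(resolve_backend("vllm-qwen2-7b-instruct"))          # VLLM('qwen2-7b-instruct')
print(resolve_backend(None, "gpt-4o-mini"))               # APILLM('gpt-4o-mini')
```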

promptolution/llms/api_llm.py

Lines changed: 40 additions & 15 deletions
@@ -1,17 +1,17 @@
 """Module to interface with various language models through their respective APIs."""

-
 try:
     import asyncio

     from openai import AsyncOpenAI
+    from openai.types.chat import ChatCompletion, ChatCompletionMessageParam

     import_successful = True
 except ImportError:
     import_successful = False


-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List, Optional

 from promptolution.llms.base_llm import BaseLLM

@@ -23,9 +23,21 @@
 logger = get_logger(__name__)


-async def _invoke_model(prompt, system_prompt, max_tokens, model_id, client, semaphore, max_retries=20, retry_delay=5):
+async def _invoke_model(
+    prompt: str,
+    system_prompt: str,
+    max_tokens: int,
+    model_id: str,
+    client: AsyncOpenAI,
+    semaphore: asyncio.Semaphore,
+    max_retries: int = 20,
+    retry_delay: float = 5,
+) -> ChatCompletion:
     async with semaphore:
-        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
+        messages: List[ChatCompletionMessageParam] = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": prompt},
+        ]

         for attempt in range(max_retries + 1):  # +1 for the initial attempt
             try:
@@ -46,7 +58,8 @@ async def _invoke_model(prompt, system_prompt, max_tokens, model_id, client, sem
                 else:
                     # Log the final failure and re-raise the exception
                     logger.error(f"❌ API call failed after {max_retries + 1} attempts: {str(e)}")
-                    raise
+                    raise  # Re-raise the exception after all retries fail
+    raise RuntimeError("Failed to get response after multiple retries.")


 class APILLM(BaseLLM):
@@ -65,13 +78,13 @@ class APILLM(BaseLLM):

     def __init__(
         self,
-        api_url: str = None,
-        model_id: str = None,
-        api_key: str = None,
-        max_concurrent_calls=50,
-        max_tokens=512,
-        config: "ExperimentConfig" = None,
-    ):
+        api_url: Optional[str] = None,
+        model_id: Optional[str] = None,
+        api_key: Optional[str] = None,
+        max_concurrent_calls: int = 50,
+        max_tokens: int = 512,
+        config: Optional["ExperimentConfig"] = None,
+    ) -> None:
         """Initialize the APILLM with a specific model and API configuration.

         Args:
@@ -103,14 +116,26 @@ def __init__(

     def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
         # Setup for async execution in sync context
-        loop = asyncio.get_event_loop()
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:  # 'get_running_loop' raises a RuntimeError if there is no running loop
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
         responses = loop.run_until_complete(self._get_response_async(prompts, system_prompts))
         return responses

     async def _get_response_async(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
+        assert self.model_id is not None, "model_id must be set"
         tasks = [
             _invoke_model(prompt, system_prompt, self.max_tokens, self.model_id, self.client, self.semaphore)
             for prompt, system_prompt in zip(prompts, system_prompts)
         ]
-        responses = await asyncio.gather(*tasks)
-        return [response.choices[0].message.content for response in responses]
+        messages = await asyncio.gather(*tasks)
+        responses = []
+        for message in messages:
+            response = message.choices[0].message.content
+            if response is None:
+                raise ValueError("Received None response from the API.")
+            responses.append(response)
+        return responses
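The `_get_response` change replaces `asyncio.get_event_loop()`, whose implicit loop creation is deprecated in recent Python versions, with an explicit ask-for-a-running-loop-then-create pattern. A self-contained sketch of that pattern outside the class; the coroutine is a hypothetical stand-in, not the library's API call:

```python
# Sketch of the loop-acquisition pattern from _get_response; fake_batch_call is a
# hypothetical stand-in for the real async OpenAI round-trip.
import asyncio
from typing import List


async def fake_batch_call(prompts: List[str]) -> List[str]:
    await asyncio.sleep(0)  # stand-in for network latency
    return [f"response to: {p}" for p in prompts]


def get_responses_sync(prompts: List[str]) -> List[str]:
    try:
        loop = asyncio.get_running_loop()  # succeeds only inside a running loop
    except RuntimeError:                   # raised when no loop runs in this thread
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(fake_batch_call(prompts))


print(get_responses_sync(["Hello", "What is prompt optimization?"]))
```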
