11"""Helper functions for the usage of the library."""
22
33
4- from typing import TYPE_CHECKING , Callable , List , Literal
4+ from typing import TYPE_CHECKING , Callable , List , Literal , Optional
55
66from promptolution .tasks .judge_tasks import JudgeTask
77from promptolution .tasks .reward_tasks import RewardTask
4545logger = get_logger (__name__ )
4646
4747
48- def run_experiment (df : pd .DataFrame , config : "ExperimentConfig" ):
48+ def run_experiment (df : pd .DataFrame , config : "ExperimentConfig" ) -> pd . DataFrame :
4949 """Run a full experiment based on the provided configuration.
5050
5151 Args:
@@ -79,7 +79,7 @@ def run_optimization(df: pd.DataFrame, config: "ExperimentConfig") -> List[str]:
7979 llm = get_llm (config = config )
8080 predictor = get_predictor (llm , config = config )
8181
82- config .task_description = config .task_description + " " + predictor .extraction_description
82+ config .task_description = ( config .task_description or "" ) + " " + ( predictor .extraction_description or "" )
8383 if config .optimizer == "capo" and (config .eval_strategy is None or "block" not in config .eval_strategy ):
8484 logger .warning ("📌 CAPO requires block evaluation strategy. Setting it to 'sequential_block'." )
8585 config .eval_strategy = "sequential_block"
@@ -126,7 +126,7 @@ def run_evaluation(df: pd.DataFrame, config: "ExperimentConfig", prompts: List[s
126126 return df
127127
128128
129- def get_llm (model_id : str = None , config : "ExperimentConfig" = None ) -> "BaseLLM" :
129+ def get_llm (model_id : Optional [ str ] = None , config : Optional [ "ExperimentConfig" ] = None ) -> "BaseLLM" :
130130 """Factory function to create and return a language model instance based on the provided model_id.
131131
132132 This function supports three types of language models:
@@ -144,16 +144,18 @@ def get_llm(model_id: str = None, config: "ExperimentConfig" = None) -> "BaseLLM
144144 Returns:
145145 An instance of LocalLLM, VLLM, or APILLM based on the model_id.
146146 """
147- if model_id is None :
148- model_id = config .model_id
149- if "local" in model_id :
150- model_id = "-" .join (model_id .split ("-" )[1 :])
151- return LocalLLM (model_id , config )
152- if "vllm" in model_id :
153- model_id = "-" .join (model_id .split ("-" )[1 :])
154- return VLLM (model_id , config = config )
147+ final_model_id = model_id or (config .model_id if config else None )
148+ if not final_model_id :
149+ raise ValueError ("model_id must be provided either directly or through config." )
155150
156- return APILLM (model_id = model_id , config = config )
151+ if "local" in final_model_id :
152+ model_name = "-" .join (final_model_id .split ("-" )[1 :])
153+ return LocalLLM (model_name , config = config )
154+ if "vllm" in final_model_id :
155+ model_name = "-" .join (final_model_id .split ("-" )[1 :])
156+ return VLLM (model_name , config = config )
157+
158+ return APILLM (model_id = final_model_id , config = config )
157159
158160
159161def get_task (
@@ -174,16 +176,19 @@ def get_task(
174176 Returns:
175177 BaseTask: An instance of a task class based on the provided DataFrame and configuration.
176178 """
177- if task_type is None :
178- task_type = config .task_type
179+ final_task_type = task_type or (config .task_type if config else None )
179180
180- if task_type == "reward" :
181+ if final_task_type == "reward" :
182+ if reward_function is None :
183+ reward_function = config .reward_function if config else None
184+ assert reward_function is not None , "Reward function must be provided for reward tasks."
181185 return RewardTask (
182186 df = df ,
183187 reward_function = reward_function ,
184188 config = config ,
185189 )
186- elif task_type == "judge" :
190+ elif final_task_type == "judge" :
191+ assert judge_llm is not None , "Judge LLM must be provided for judge tasks."
187192 return JudgeTask (df , judge_llm = judge_llm , config = config )
188193
189194 return ClassificationTask (df , config = config )
@@ -193,10 +198,9 @@ def get_optimizer(
193198 predictor : "BasePredictor" ,
194199 meta_llm : "BaseLLM" ,
195200 task : "BaseTask" ,
196- optimizer : OptimizerType = None ,
197- meta_prompt : str = None ,
198- task_description : str = None ,
199- config : "ExperimentConfig" = None ,
201+ optimizer : Optional [OptimizerType ] = None ,
202+ task_description : Optional [str ] = None ,
203+ config : Optional ["ExperimentConfig" ] = None ,
200204) -> "BaseOptimizer" :
201205 """Creates and returns an optimizer instance based on provided parameters.
202206
@@ -215,22 +219,18 @@ def get_optimizer(
215219 Raises:
216220 ValueError: If an unknown optimizer type is specified
217221 """
218- if optimizer is None :
219- optimizer = config .optimizer
220- if task_description is None :
221- task_description = config .task_description
222- if meta_prompt is None and hasattr (config , "meta_prompt" ):
223- meta_prompt = config .meta_prompt
224-
225- if config .optimizer == "capo" :
222+ final_optimizer = optimizer or (config .optimizer if config else None )
223+ final_task_description = task_description or (config .task_description if config else None )
224+
225+ if final_optimizer == "capo" :
226226 crossover_template = (
227- CAPO_CROSSOVER_TEMPLATE .replace ("<task_desc>" , task_description )
228- if task_description
227+ CAPO_CROSSOVER_TEMPLATE .replace ("<task_desc>" , final_task_description )
228+ if final_task_description
229229 else CAPO_CROSSOVER_TEMPLATE
230230 )
231231 mutation_template = (
232- CAPO_MUTATION_TEMPLATE .replace ("<task_desc>" , task_description )
233- if task_description
232+ CAPO_MUTATION_TEMPLATE .replace ("<task_desc>" , final_task_description )
233+ if final_task_description
234234 else CAPO_MUTATION_TEMPLATE
235235 )
236236
@@ -243,27 +243,29 @@ def get_optimizer(
243243 config = config ,
244244 )
245245
246- if config . optimizer == "evopromptde" :
246+ if final_optimizer == "evopromptde" :
247247 template = (
248- EVOPROMPT_DE_TEMPLATE_TD .replace ("<task_desc>" , task_description )
249- if task_description
248+ EVOPROMPT_DE_TEMPLATE_TD .replace ("<task_desc>" , final_task_description )
249+ if final_task_description
250250 else EVOPROMPT_DE_TEMPLATE
251251 )
252252 return EvoPromptDE (predictor = predictor , meta_llm = meta_llm , task = task , prompt_template = template , config = config )
253253
254- if config . optimizer == "evopromptga" :
254+ if final_optimizer == "evopromptga" :
255255 template = (
256- EVOPROMPT_GA_TEMPLATE_TD .replace ("<task_desc>" , task_description )
257- if task_description
256+ EVOPROMPT_GA_TEMPLATE_TD .replace ("<task_desc>" , final_task_description )
257+ if final_task_description
258258 else EVOPROMPT_GA_TEMPLATE
259259 )
260260 return EvoPromptGA (predictor = predictor , meta_llm = meta_llm , task = task , prompt_template = template , config = config )
261261
262- if config .optimizer == "opro" :
263- template = OPRO_TEMPLATE_TD .replace ("<task_desc>" , task_description ) if task_description else OPRO_TEMPLATE
262+ if final_optimizer == "opro" :
263+ template = (
264+ OPRO_TEMPLATE_TD .replace ("<task_desc>" , final_task_description ) if final_task_description else OPRO_TEMPLATE
265+ )
264266 return OPRO (predictor = predictor , meta_llm = meta_llm , task = task , prompt_template = template , config = config )
265267
266- raise ValueError (f"Unknown optimizer: { config . optimizer } " )
268+ raise ValueError (f"Unknown optimizer: { final_optimizer } " )
267269
268270
269271def get_exemplar_selector (
0 commit comments