From f824bc96eb2fed23b948b93ca0bdc6551425521b Mon Sep 17 00:00:00 2001 From: optimass Date: Fri, 10 Jan 2025 21:43:43 +0000 Subject: [PATCH 1/2] small API change - passing exp_root to study.run() --- src/agentlab/experiments/study.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index c04f1a74..f158295f 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -298,12 +298,13 @@ def run( strict_reproducibility=False, n_relaunch=3, relaunch_errors=True, + exp_root=RESULTS_DIR, ): self.set_reproducibility_info( strict_reproducibility=strict_reproducibility, comment=self.comment ) - self.save() + self.save(exp_root=exp_root) n_exp = len(self.exp_args_list) last_error_count = None From 53a395cf2811bdd19ae79cdd78942500e30a1851 Mon Sep 17 00:00:00 2001 From: optimass Date: Fri, 10 Jan 2025 21:59:06 +0000 Subject: [PATCH 2/2] same for SequetialStudy --- src/agentlab/experiments/study.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/agentlab/experiments/study.py b/src/agentlab/experiments/study.py index f158295f..195aea09 100644 --- a/src/agentlab/experiments/study.py +++ b/src/agentlab/experiments/study.py @@ -462,7 +462,14 @@ def find_incomplete(self, include_errors=True): for study in self.studies: study.find_incomplete(include_errors=include_errors) - def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_relaunch=3): + def run( + self, + n_jobs=1, + parallel_backend="ray", + strict_reproducibility=False, + n_relaunch=3, + exp_root=RESULTS_DIR, + ): # This sequence of of making directories is important to make sure objects are materialized # properly before saving. Otherwise relaunch may not work properly. @@ -470,7 +477,7 @@ def run(self, n_jobs=1, parallel_backend="ray", strict_reproducibility=False, n_ for study in self.studies: study.make_dir(exp_root=self.dir) - self.save() + self.save(exp_root=exp_root) self._run(n_jobs, parallel_backend, strict_reproducibility, n_relaunch) _, summary_df, _ = self.get_results() logger.info("\n" + str(summary_df))