-
Notifications
You must be signed in to change notification settings - Fork 1
DeepFinance Enhancements #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bac05b5
ba41164
c7ca8c7
7f2b017
9dd3c42
757f8a1
079e4bd
bcce8f0
4662d63
de81c1d
248acc4
9d651fd
7475ecc
b95d491
f20ab91
ea87d4b
3082bca
ef44b63
0889483
db7114c
5a25550
623b7d9
0aaab86
04f4959
d0ff68b
1c356d7
37dcbcc
529ae7e
f4eb231
1e07515
08ba184
3d55692
a478827
88be3e4
fb41962
a1f909b
8d2e5d7
3c85960
9b541c5
06fda5f
63cc682
c9b87ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,7 +40,7 @@ def save_trajectory_as_json(ctx_trackers, global_steps, prefix="train"): | |
| # Define save directory and file path | ||
| traj_save_dir = os.path.join( | ||
| os.environ.get("BEST_LOGGER_PATH", "launcher_record"), | ||
| "ctx_trackers", | ||
| "trajectory", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| prefix, | ||
| f"step_{global_steps}" | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -125,6 +125,7 @@ def compute_tool_metrics(tool_stats_list: List[Dict[str, Any]], prefix: str = "" | |
| if calls > 0: | ||
| error_rate = errors / calls * 100 | ||
| metrics[f"{prefix}tool_error/{tool_name}/error_rate"] = round(error_rate, 2) | ||
| metrics[f"{prefix}tool_error/{tool_name}/calls"] = calls | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
| return metrics | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -96,13 +96,15 @@ def pty_wrapper_final(human_cmd, dir, env_dict): | |
| pty_wrapper(["/bin/bash", "-c", human_cmd], dir, env_dict) | ||
|
|
||
|
|
||
| def pty_launch(service_name: str, success_std_string="Starting server on"): | ||
| def pty_launch(service_name: str, success_std_string="Starting server on", prefix: str=""): | ||
| from ajet.utils.smart_daemon import LaunchCommandWhenAbsent | ||
|
|
||
| service_path = os.environ.get(f"{service_name.upper()}_PATH") | ||
| service_script = os.environ.get(f"{service_name.upper()}_SCRIPT") | ||
| if service_path is None or service_script is None: | ||
| raise ValueError(f"Environment variables for {service_name} not properly set.") | ||
| if prefix != "": | ||
| service_name = prefix + "_" + service_name | ||
|
Comment on lines
+106
to
+107
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| companion = LaunchCommandWhenAbsent( | ||
| full_argument_list=[service_script], | ||
| dir=service_path, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -373,8 +373,12 @@ def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowO | |
| fused_reward, contributions = self._fuse_grader_scores(grader_scores, rm_raw) | ||
|
|
||
| # 6. 计算惩罚项(保留原有的 tool_calls 惩罚逻辑) | ||
| tool_calls = metadata.get("tool_stats", {}).get("total_calls", 0) | ||
| # 从 log_metrics 中提取 tool_stats(deep_finance.py 将其放在 log_metrics 而非 metadata) | ||
| tool_stats = workflow_output.log_metrics.get("tool_stats", {}) | ||
| tool_calls = tool_stats.get("total_calls", 0) | ||
|
Comment on lines
+377
to
+378
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| penalty = self._compute_penalty(tool_calls) | ||
| if penalty < 0: | ||
| print(f"⚠️ Penalty applied: penalty={penalty}, tool_calls={tool_stats}") | ||
|
Comment on lines
+380
to
+381
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # 7. 汇总 | ||
| final_reward = fused_reward + step_reward + penalty | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,7 @@ ajet: | |
| rollout: | ||
| # ✨✨✨✨ 编写并选择Agent | ||
| user_workflow: tutorial.example_deep_finance.deep_finance->ExampleDeepResearchProtocol | ||
| force_disable_toolcalls: True | ||
| force_disable_toolcalls: False | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| enable_oversample: False | ||
| tensor_model_parallel_size: 8 | ||
| num_repeat: {{NUM_REPEAT}} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Passing the `prefix` argument to `pty_launch` for the `deepfinance` service correctly integrates the new prefixing functionality. This ensures that DeepFinance instances can be uniquely identified when launched.