diff --git a/aworld/agents/llm_agent.py b/aworld/agents/llm_agent.py
index 8a71b7f68..a43a7b61c 100644
--- a/aworld/agents/llm_agent.py
+++ b/aworld/agents/llm_agent.py
@@ -252,6 +252,45 @@ def messages_transform(self,
return sync_exec(self.async_messages_transform, image_urls=image_urls, observation=observation,
message=message, **kwargs)
+ def _is_amni_context(self, context: Context):
+ from aworld.core.context.amni import AmniContext
+ return isinstance(context, AmniContext)
+
+ def _build_memory_filters(self, context: Context, additional_filters: Dict[str, Any] = None) -> Dict[str, Any]:
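+        """Build memory query filters scoped by the agent's history_scope: user, session, or task."""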
+ filters = {
+ "agent_id": self.id()
+ }
+
+ # Decide which filter to add based on history_scope
+ agent_memory_config = self.memory_config
+ if self._is_amni_context(context):
+ agent_context_config = context.get_config().get_agent_context_config(self.id())
+ agent_memory_config = agent_context_config.to_memory_config()
+
+ query_scope = agent_memory_config.history_scope if agent_memory_config and agent_memory_config.history_scope else "task"
+ task = context.get_task()
+
+ if query_scope == "user":
+ # Pass user_id when query_scope is user
+ if hasattr(context, 'user_id') and context.user_id:
+ filters["user_id"] = context.user_id
+ elif hasattr(task, 'user_id') and task.user_id:
+ filters["user_id"] = task.user_id
+ elif query_scope == "session":
+ # Pass session_id when query_scope is session
+ if task and task.session_id:
+ filters["session_id"] = task.session_id
+ else: # query_scope == "task" or default
+ # Pass task_id when query_scope is task
+ if task and task.id:
+ filters["task_id"] = task.id
+
+ # Add additional filter conditions
+ if additional_filters:
+ filters.update(additional_filters)
+
+ return filters
+
def _clean_redundant_tool_call_messages(self, histories: List[MemoryItem]) -> None:
try:
for i in range(len(histories) - 1, -1, -1):
@@ -269,14 +308,8 @@ def postprocess_terminate_loop(self, message: Message):
logger.info(f"Agent {self.id()} postprocess_terminate_loop: {self.loop_step}")
super().postprocess_terminate_loop(message)
try:
- session_id = message.context.get_task().session_id
- task_id = message.context.get_task().id
- histories = self.memory.get_all(filters={
- "agent_id": self.id(),
- "session_id": session_id,
- "task_id": task_id,
- "memory_type": "message"
- })
+ filters = self._build_memory_filters(message.context, additional_filters={"memory_type": "message"})
+ histories = self.memory.get_all(filters=filters)
self._clean_redundant_tool_call_messages(histories)
except Exception:
logger.error(f"Agent {self.id()} postprocess_terminate_loop error: {traceback.format_exc()}")
@@ -304,14 +337,8 @@ async def async_messages_transform(self,
if self.system_prompt:
await self._add_message_to_memory(context=message.context, payload=content, message_type=MemoryType.SYSTEM)
- session_id = message.context.get_task().session_id
- task_id = message.context.get_task().id
- histories = self.memory.get_all(filters={
- "agent_id": self.id(),
- "session_id": session_id,
- "task_id": task_id,
- "memory_type": "message"
- })
+ filters = self._build_memory_filters(message.context, additional_filters={"memory_type": "message"})
+ histories = self.memory.get_all(filters=filters)
# append observation to memory
tool_result_added = False
@@ -333,11 +360,12 @@ async def async_messages_transform(self,
context=message.context)
# from memory get last n messages
- histories = self.memory.get_last_n(self.memory_config.history_rounds, filters={
- "agent_id": self.id(),
- "session_id": session_id,
- "task_id": task_id
- }, agent_memory_config=self.memory_config)
+ filters = self._build_memory_filters(message.context)
+ agent_memory_config = self.memory_config
+ if self._is_amni_context(message.context):
+ agent_context_config = message.context.get_config().get_agent_context_config(self.id())
+ agent_memory_config = agent_context_config.to_memory_config()
+ histories = self.memory.get_last_n(agent_memory_config.history_rounds, filters=filters, agent_memory_config=agent_memory_config)
if histories:
tool_calls_map = {}
last_tool_calls = []
@@ -841,12 +869,8 @@ async def _add_tool_result_token_ids_to_context(self, context: Context):
"""Add tool result token ids to context"""
if context.get_task().conf.get("run_mode") != TaskRunMode.INTERACTIVE:
return
- histories = self.memory.get_all(filters={
- "agent_id": self.id(),
- "session_id": context.get_task().session_id,
- "task_id": context.get_task().id,
- "memory_type": "message"
- })
+ filters = self._build_memory_filters(context, additional_filters={"memory_type": "message"})
+ histories = self.memory.get_all(filters=filters)
tool_openai_messages_after_last_assistant = []
found_assistant = False
tool_call_ids = []
diff --git a/aworld/config/conf.py b/aworld/config/conf.py
index 7787be541..a22c4674f 100644
--- a/aworld/config/conf.py
+++ b/aworld/config/conf.py
@@ -188,7 +188,8 @@ class AgentMemoryConfig(BaseConfig):
description="rounds of message msg; when the number of messages is greater than the history_rounds, the memory will be trimmed")
history_write_strategy: HistoryWriteStrategy = Field(default=HistoryWriteStrategy.EVENT_DRIVEN,
description="History write strategy: event_driven (through message system) or direct (direct call to handler)")
-
+    history_scope: Optional[str] = Field(default="task", description="History query scope: user, session, or task")
+
enable_summary: bool = Field(default=False,
description="enable_summary use llm to create summary short-term memory")
summary_model: Optional[str] = Field(default=None, description="short-term summary model")
diff --git a/aworld/core/context/amni/config.py b/aworld/core/context/amni/config.py
index d79c94176..037ff550f 100644
--- a/aworld/core/context/amni/config.py
+++ b/aworld/core/context/amni/config.py
@@ -90,6 +90,7 @@ class AgentContextConfig(BaseConfig):
description="rounds of message msg; when the number of messages is greater than the history_rounds, the memory will be trimmed")
history_write_strategy: HistoryWriteStrategy = Field(default=HistoryWriteStrategy.EVENT_DRIVEN,
description="History write strategy: event_driven (through message system) or direct (direct call to handler)")
+    history_scope: Optional[str] = Field(default="task", description="History query scope: user, session, or task")
# Context Reduce - Compress
enable_summary: bool = Field(default=False,
@@ -118,6 +119,7 @@ def to_memory_config(self) -> AgentMemoryConfig:
return AgentMemoryConfig(
history_rounds=self.history_rounds,
history_write_strategy=self.history_write_strategy,
+ history_scope=self.history_scope,
enable_summary=self.enable_summary,
summary_rounds=self.summary_rounds,
summary_context_length=self.summary_context_length,
diff --git a/aworld/core/tool/base.py b/aworld/core/tool/base.py
index e21b44bbe..8d72f1abb 100644
--- a/aworld/core/tool/base.py
+++ b/aworld/core/tool/base.py
@@ -501,6 +501,15 @@ async def post_step(self,
headers={"context": context})
+        # tool hooks: dispatch the POST_TOOL_CALL hook point with the step result
+        try:
+            events = []
+            async for event in run_hooks(context=message.context, hook_point=HookPoint.POST_TOOL_CALL, hook_from=result.caller, payload=step_res):
+                events.append(event)
+        except Exception:
+            logger.debug(traceback.format_exc())
+
         return result
async def _exec_tool_callback(self, step_res: Tuple[Observation, float, bool, bool, Dict[str, Any]],
action: List[ActionModel],
message: Message,
diff --git a/aworld/dataset/trajectory_strategy.py b/aworld/dataset/trajectory_strategy.py
index 2d0ca2856..e7a5ccda2 100644
--- a/aworld/dataset/trajectory_strategy.py
+++ b/aworld/dataset/trajectory_strategy.py
@@ -312,7 +312,7 @@ async def generate_trajectory_for_memory(self, swarm: Swarm, context: Context):
}, agent_memory_config=swarm.cur_agent[0].memory_config)
# Convert memory items to OpenAI message format
- result = {}
+ result = []
for i, item in enumerate(memory_items):
# Check if item has to_openai_message method
if hasattr(item, 'to_openai_message'):
@@ -320,10 +320,10 @@ async def generate_trajectory_for_memory(self, swarm: Swarm, context: Context):
# Add usage to the message if it exists in metadata
if hasattr(item, 'metadata') and item.metadata and 'usage' in item.metadata:
message['usage'] = item.metadata['usage']
- result[i] = message
+ result.append(message)
else:
# If item doesn't have to_openai_message, return the item as is
- result[i] = item
+ result.append(item)
return result
diff --git a/aworld/evaluations/scorers/llm_as_judge.py b/aworld/evaluations/scorers/llm_as_judge.py
index f865aabb9..3264baadd 100644
--- a/aworld/evaluations/scorers/llm_as_judge.py
+++ b/aworld/evaluations/scorers/llm_as_judge.py
@@ -60,7 +60,7 @@ def build_judge_prompt(self, index: int, input: EvalDataCase[EvalCaseDataType],
raise NotImplementedError("build_judge_prompt must be implemented in subclasses")
@abc.abstractmethod
- def build_judge_data(self, index: int, input: EvalDataCase[EvalCaseDataType], output: dict) -> str:
+    def build_judge_data(self, index: int, input: EvalDataCase[EvalCaseDataType], output: dict) -> Union[str, dict]:
"""Builds the input for the judge agent task.
Args:
@@ -72,7 +72,7 @@ def build_judge_data(self, index: int, input: EvalDataCase[EvalCaseDataType], ou
str: The input string for the judge agent task.
Example:
             [Question]: {input.case_data.get('question', '')}
[Correct_Answer]: {input.case_data.get('answer', '')}
[Response]: {output.get('answer', '')}
"""
@@ -106,6 +106,8 @@ async def score(self, index: int, input: EvalDataCase[EvalCaseDataType], output:
agent_prompt=self.build_judge_prompt(index=index, input=input, output=output))
task_input = self.build_judge_data(index=index, input=input, output=output)
+ if not task_input:
+ return ScorerResult(scorer_name=self.name, metric_results={})
response = await exec_agent(task_input, agent=score_agent, context=Context())
metric_results = self.convert_judge_response_to_score(response.answer)
if metric_results:
@@ -130,4 +132,4 @@ def _build_judge_system_prompt(self) -> str:
"""
return '''
You are a judge model that evaluates the quality of the response.
- '''
+ '''
\ No newline at end of file
diff --git a/aworld/evaluations/scorers/metrics.py b/aworld/evaluations/scorers/metrics.py
index e9596b418..a596a6f91 100644
--- a/aworld/evaluations/scorers/metrics.py
+++ b/aworld/evaluations/scorers/metrics.py
@@ -2,4 +2,5 @@ class MetricNames:
LABEL_DISTRIBUTION = 'label_distribution'
SUMMARIZE_QUALITY = 'summarize_quality'
ANSWER_ACCURACY = 'answer_accuracy'
- PREDICT_TIME_COST_MS = 'predict_time_cost_ms'
\ No newline at end of file
+ PREDICT_TIME_COST_MS = 'predict_time_cost_ms'
+ FLIGHT_JUDGE = 'flight_judge'
\ No newline at end of file
diff --git a/aworld/output/workspace.py b/aworld/output/workspace.py
index 94d781d08..98321afa9 100644
--- a/aworld/output/workspace.py
+++ b/aworld/output/workspace.py
@@ -495,6 +495,23 @@ def save(self) -> None:
self.repository.save_index(workspace_data)
self._rebuild_artifact_id_index()
+ def get_raw_file_content_by_artifact_id(self, artifact_id: str) -> str:
+ """
+        Get the raw content of the latest artifact with the given artifact_id.
+
+        Args:
+            artifact_id: artifact_id
+
+        Returns:
+            Raw unescaped content of the latest matching artifact, or "" if none exists
+        """
+ artifact_data = self.repository.retrieve_latest_artifact(artifact_id)
+ if not artifact_data:
+ return ""
+ artifact = Artifact.from_dict(artifact_data)
+ return artifact.content
+
def get_file_content_by_artifact_id(self, artifact_id: str) -> str:
"""
Get concatenated content of all artifacts with the same filename.
diff --git a/aworld/runners/handler/memory.py b/aworld/runners/handler/memory.py
index 891d0b70a..fc332fd8c 100644
--- a/aworld/runners/handler/memory.py
+++ b/aworld/runners/handler/memory.py
@@ -1,22 +1,21 @@
# aworld/runners/handler/output.py
import copy
-import json
import time
import traceback
from datetime import datetime
-from typing import AsyncGenerator, Any
+from typing import Any
from aworld.agents.llm_agent import Agent
from aworld.config import ConfigDict
+from aworld.core.common import ActionResult
from aworld.core.context.base import Context
+from aworld.core.event.base import Message, Constants, MemoryEventMessage, MemoryEventType
+from aworld.logs.util import logger
from aworld.memory.main import MemoryFactory
from aworld.memory.models import MemoryToolMessage, MessageMetadata, MemoryHumanMessage, MemorySystemMessage, \
MemoryAIMessage
from aworld.runners import HandlerFactory
from aworld.runners.handler.base import DefaultHandler
-from aworld.core.common import TaskItem, ActionResult
-from aworld.core.event.base import Message, Constants, TopicType, MemoryEventMessage, MemoryEventType
-from aworld.logs.util import logger
from aworld.runners.hook.hook_factory import HookFactory
@@ -182,7 +181,7 @@ async def _add_llm_response_to_memory(self, agent: Agent, llm_response, context:
"""Add LLM response to memory"""
# Get start time from context (if exists)
start_time = context.context_info.get("llm_call_start_time")
-
+
ai_message = MemoryAIMessage(
content=llm_response.content,
tool_calls=llm_response.tool_calls,
@@ -197,13 +196,13 @@ async def _add_llm_response_to_memory(self, agent: Agent, llm_response, context:
}
)
)
-
+
# If start time exists in context, update it
if start_time:
ai_message.start_time = start_time
# Record message end time
ai_message.end_time = None
-
+
agent_memory_config = agent.memory_config
if self._is_amni_context(context):
agent_memory_config = context.get_config().get_agent_memory_config(agent.id())
@@ -283,10 +282,10 @@ async def _do_add_tool_result_to_memory(self, agent: 'Agent', tool_call_id: str,
tool_use_summary = None
if isinstance(tool_result, ActionResult):
tool_use_summary = tool_result.metadata.get("tool_use_summary")
-
+
# Get start time from context (if exists)
start_time = context.context_info.get(f"tool_call_start_time_{tool_call_id}")
-
+
tool_message = MemoryToolMessage(
content=tool_result.content if hasattr(tool_result, 'content') else tool_result,
tool_call_id=tool_call_id,
@@ -301,14 +300,14 @@ async def _do_add_tool_result_to_memory(self, agent: 'Agent', tool_call_id: str,
ext_info={"tool_name": tool_result.tool_name, "action_name": tool_result.action_name}
)
)
-
+
# If start time exists in context, update it
if start_time:
tool_message.start_time = start_time
-
+
# Record message end time
tool_message.end_time = None
-
+
await memory.add(tool_message, agent_memory_config=agent.memory_config)
def _is_amni_context(self, context: Context):
@@ -318,7 +317,7 @@ def _is_amni_context(self, context: Context):
@staticmethod
async def handle_memory_message_directly(memory_msg: MemoryEventMessage, context: Context):
"""Handle memory message directly without going through message system
-
+
Args:
memory_msg: Memory event message
context: Context object
@@ -329,7 +328,7 @@ class SimpleRunner:
def __init__(self, task):
self.task = task
self.start_time = 0
-
+
task = context.get_task()
simple_runner = SimpleRunner(task)
handler = DefaultMemoryHandler(simple_runner)
diff --git a/aworld/runners/hook/hooks.py b/aworld/runners/hook/hooks.py
index 5a74d0f3a..567a2d664 100644
--- a/aworld/runners/hook/hooks.py
+++ b/aworld/runners/hook/hooks.py
@@ -13,6 +13,7 @@ class HookPoint:
ERROR = "error"
PRE_LLM_CALL = "pre_llm_call"
POST_LLM_CALL = "post_llm_call"
OUTPUT_PROCESS = "output_process"
PRE_TOOL_CALL = "pre_tool_call"
POST_TOOL_CALL = "post_tool_call"
diff --git a/train/adapter/common.py b/train/adapter/common.py
index 95c001a5e..a653f6c05 100644
--- a/train/adapter/common.py
+++ b/train/adapter/common.py
@@ -81,6 +81,25 @@ async def encode_messages(tokenizer: AutoTokenizer,
chat_list.append(messages[i])
i += 1
continue
+
+        # summary message: tokenize it together with the accumulated turns and keep those tokens unmasked (mask=1)
+ if messages[i].get("memory_type") == "summary":
+ chat_list.append(messages[i])
+ cur_response_ids = await loop.run_in_executor(
+ None,
+ lambda: tokenizer.apply_chat_template(
+ chat_list,
+ add_generation_prompt=False,
+ tokenize=True,
+ chat_template=chat_template
+ ),
+ )
+ chat_list = []
+ response_ids += cur_response_ids
+ response_mask += [1] * len(cur_response_ids)
+ i += 1
+ continue
+
# initial chat completion
if messages[i].get("role") == "user":
if i == 0 or messages[i - 1].get("role") == "system":
@@ -141,6 +160,7 @@ async def encode_messages(tokenizer: AutoTokenizer,
chat_template=chat_template
),
)
+ # append tool_response
while i < len(messages) and messages[i].get("role") == "tool":
chat_list.append(messages[i])
i += 1
@@ -153,10 +173,48 @@ async def encode_messages(tokenizer: AutoTokenizer,
chat_template=chat_template
),
)
+            # append the last tool message's usage; usage tokens are unmasked so the model is trained on them
+            usage_token = []
+            if i < len(messages) and messages[i - 1].get("role") == "tool":
+ tool_message = dict(messages[i - 1])
+                # Append the usage info to the end of the tool message content
+ if "usage" in tool_message:
+ usage_chat_template = """
+{%- for message in messages -%}
+ {%- if message.role == "tool" -%}
+ {%- if message.usage is defined -%}
+ {%- if message.usage is mapping -%}
+ {{- '' ~ (message.usage | tojson | safe) ~ '\n' -}}
+ {%- else -%}
+ {{- '' ~ (message.usage | string) ~ '\n' -}}
+ {%- endif -%}
+ {%- endif -%}
+ {%- else -%}
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}
+ {%- endif -%}
+{%- endfor -%}
+{%- if add_generation_prompt -%}
+ {{- '<|im_start|>assistant\n' -}}
+{%- endif -%}
+ """
+ usage_token = await loop.run_in_executor(
+ None,
+ lambda: tokenizer.apply_chat_template(
+ [tool_message],
+ add_generation_prompt=False,
+ tokenize=True,
+ chat_template=usage_chat_template
+ ),
+ )
+
+ print("长度", len(response_ids), ' ', len(token_assistant), ' ', len(token_assistant_tool), ' ', len(usage_token))
+ print("response_ids ", tokenizer.decode(token_assistant_tool))
+ print('usage_ids ', tokenizer.decode(usage_token))
tool_response_ids = token_assistant_tool[len(token_assistant):]
chat_list = []
- response_ids += tool_response_ids
- response_mask += [0] * len(tool_response_ids)
+ response_ids += tool_response_ids + usage_token
+            # if usage tokens were added, mask tool response tokens with 0 and usage tokens with 1
+ response_mask += [0] * (len(token_assistant_tool) - len(token_assistant))
+ response_mask += [1] * (len(usage_token))
except Exception as e:
raise Exception(f"Failed to convert messages to agentloop_output: {messages}. {traceback.format_exc()}")
diff --git a/train/adapter/verl/agent_template.py b/train/adapter/verl/agent_template.py
index 761b433a0..b6c78d0e4 100644
--- a/train/adapter/verl/agent_template.py
+++ b/train/adapter/verl/agent_template.py
@@ -8,8 +8,11 @@
from aworld.agents.llm_agent import Agent
from aworld.config import AgentConfig, ConfigDict
from aworld.core.agent.swarm import Swarm
+from aworld.core.context.base import Context
from aworld.logs.util import logger
from {parser_module} import {parser_name}
+from aworld.config import BaseConfig, ConfigDict, load_config, TaskConfig
+from aworld.core.context.amni import AmniContextConfig, AgentContextConfig, ApplicationContext
{agent_import_str}
{tool_aggregate_func_import_str}
@@ -17,6 +20,14 @@
class VerlAgentLoop(AworldAgentLoop):
+ async def build_context(self, input: Any) -> Context:
+ return await ApplicationContext.from_input(task_input=input, context_config=AmniContextConfig({context_config}))
+
+ async def build_task_config(self) -> TaskConfig:
+ return TaskConfig(
+ {task_config}
+ )
+
async def build_agents(self) -> Union[Agent, Swarm]:
conf = AgentConfig(
llm_config=ConfigDict(
diff --git a/train/adapter/verl/aworld_agent_loop.py b/train/adapter/verl/aworld_agent_loop.py
index 23a8bbedb..e0c2479b7 100644
--- a/train/adapter/verl/aworld_agent_loop.py
+++ b/train/adapter/verl/aworld_agent_loop.py
@@ -9,19 +9,20 @@
import uuid
from typing import Any, List, Dict, Union, Sequence
+from train.adapter.verl.utils import build_task
+from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput
+
from aworld.agents.llm_agent import Agent
+from aworld.config import TaskConfig
from aworld.config.agent_loader import _load_yaml
from aworld.core.agent.swarm import Swarm
from aworld.core.task import TaskResponse, Task
-from aworld.runner import Runners
+from aworld.dataset.trajectory_strategy import MemoryTrajectoryStrategy
from aworld.logs.util import logger
-
-from verl.experimental.agent_loop.agent_loop import AgentLoopBase, AgentLoopOutput, AgentLoopMetrics
-
+from aworld.runner import Runners
from aworld.trace.base import Span
from aworld.trace.span_cosumer import register_span_consumer, SpanConsumer
from train.adapter.common import encode_messages, turns_num
-from train.adapter.verl.verl_provider import VerlProvider
@register_span_consumer()
@@ -80,17 +81,26 @@ async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutpu
elapsed_time = end_time - start_time
logger.warning(f"######## trajectory finish, time costs {elapsed_time:.2f} s ########\n")
- logger.warning(f"######## res[-1]['exp_data']: {res[-1]['exp_data']} ########\n")
- logger.warning(f"######## res[-1]['exp_data']['actions']: {res[-1]['exp_data']['actions']} ########\n")
- logger.warning(f"######## res[-1]['exp_data']['messages']: {res[-1]['exp_data']['messages']} ########\n")
-
# build agent loop output
- output = await self.convert_agent_output(trajectory=res)
+ task_config = await self.build_task_config()
+ if task_config.trajectory_strategy == MemoryTrajectoryStrategy:
+ output = await self.convert_memory_trajectory_agent_output(trajectory=res)
+ else:
+ output = await self.convert_agent_output(trajectory=res)
if hasattr(result, 'id'):
output.extra_fields['task_id'] = result.id
return output
+ async def build_context_config(self):
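+        """Return the AmniContextConfig used to build the task context; None falls back to the default config in build_task."""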
+ return None
+
+ async def build_task_config(self) -> TaskConfig:
+ return TaskConfig(
+ stream=False,
+ exit_on_failure=True
+ )
+
async def run_agents(self, input: Any, agent: Union[Agent, Swarm]):
async def run(task: Task):
@@ -105,7 +115,9 @@ async def run(task: Task):
if isinstance(input, dict):
input = input.get("content", "")
- task = Task(id=str(uuid.uuid4()), input=input, timeout=1200, agent=agent)
+ task = await build_task(user_input=input, target=agent, timeout=1200,
+ context_config=await self.build_context_config(),
+ task_config=await self.build_task_config())
resp = TaskResponse(id=task.id, trajectory=[{
"exp_meta": {
"task_id": "timeout_default",
@@ -191,6 +203,11 @@ async def get_agent_tool_config(self, config_path: str) -> Dict[str, Any]:
def get_num_turns(self, trajectory: List[Dict[str, Any]]):
return len(trajectory)
+ async def convert_memory_trajectory_agent_output(self, trajectory: List[Any], chat_template: str = None) -> AgentLoopOutput:
+ logger.warning(f"######## res: {trajectory} ########\n")
+
+ return await self.to_agent_loop_output(trajectory, chat_template=chat_template)
+
async def convert_agent_output(self, trajectory: List[Dict[str, Any]]) -> AgentLoopOutput:
"""Convert trajectory to AgentLoopOutput.
@@ -204,6 +221,10 @@ async def convert_agent_output(self, trajectory: List[Dict[str, Any]]) -> AgentL
if not trajectory:
raise Exception("Trajectory is empty")
+ logger.warning(f"######## res[-1]['exp_data']: {trajectory[-1]['exp_data']} ########\n")
+ logger.warning(f"######## res[-1]['exp_data']['actions']: {trajectory[-1]['exp_data']['actions']} ########\n")
+ logger.warning(f"######## res[-1]['exp_data']['messages']: {trajectory[-1]['exp_data']['messages']} ########\n")
+
num_turns = self.get_num_turns(trajectory)
messages = trajectory[-1].get("exp_data", {}).get("messages", [])
if not messages:
@@ -251,7 +272,7 @@ async def convert_agent_output(self, trajectory: List[Dict[str, Any]]) -> AgentL
output = await self.to_agent_loop_output(messages=messages)
return output
- async def to_agent_loop_output(self, messages: List[Dict[str, Any]]) -> AgentLoopOutput:
+ async def to_agent_loop_output(self, messages: List[Dict[str, Any]], chat_template = None) -> AgentLoopOutput:
"""Convert messages to AgentLoopOutput.
Args:
@@ -266,7 +287,8 @@ async def to_agent_loop_output(self, messages: List[Dict[str, Any]]) -> AgentLoo
prompt_ids, response_ids, response_mask = await encode_messages(self.tokenizer,
messages,
response_length=response_length,
- tools=self.agent.tools)
+ tools=self.agent.tools,
+ chat_template=chat_template)
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids,
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/gaia.py b/train/adapter/verl/utils.py
similarity index 68%
rename from train/examples/train_gaia_with_aworld_verl/rollout/gaia.py
rename to train/adapter/verl/utils.py
index 8010dcec8..4f3a106e5 100644
--- a/train/examples/train_gaia_with_aworld_verl/rollout/gaia.py
+++ b/train/adapter/verl/utils.py
@@ -9,28 +9,23 @@
from aworld.config.conf import HistoryWriteStrategy
from aworld.core.agent.swarm import Swarm
from aworld.core.context.amni import TaskInput, ApplicationContext
-from aworld.core.context.amni.config import get_default_config, init_middlewares, AgentContextConfig
+from aworld.core.context.amni.config import get_default_config, init_middlewares, AgentContextConfig, AmniContextConfig
+from aworld.core.context.base import Context
from aworld.core.task import Task
from aworld.dataset.trajectory_strategy import MemoryTrajectoryStrategy
from aworld.logs.util import logger
# from train.adapter.verl.aworld_agent_loop import AworldAgentLoop
from aworld.memory.main import AWORLD_MEMORY_EXTRACT_NEW_SUMMARY
-# Import from prompts module directly to avoid circular import
+# Import from prompts module inside functions to avoid circular import
# (rollout/__init__.py imports this file at the top)
-from train.examples.train_gaia_with_aworld_verl.rollout.prompts import (
- GAIA_SYSTEM_PROMPT,
- episode_memory_summary_rule,
- working_memory_summary_rule,
- working_memory_summary_schema,
- tool_memory_summary_rule,
- tool_memory_summary_schema,
- episode_memory_summary_schema,
-)
def is_summary():
return os.getenv("GAIA_AGENT_CONTEXT", 'common') == 'amni'
-def build_gaia_agent(llm_model_name, llm_base_url, llm_api_key, mcp_config, server_manager = None, tokenizer = None):
+def build_context_aware_agent(llm_model_name, llm_base_url, llm_api_key, mcp_config, llm_provider = "openai", server_manager = None, tokenizer = None):
+ # Import here to avoid circular import
+ from train.examples.train_gaia_with_aworld_verl.rollout.prompts import GAIA_SYSTEM_PROMPT
# init middlewares
init_middlewares()
@@ -41,7 +36,7 @@ def build_gaia_agent(llm_model_name, llm_base_url, llm_api_key, mcp_config, serv
llm_model_name=llm_model_name,
llm_base_url=llm_base_url,
llm_api_key=llm_api_key,
- llm_provider="openai",
+ llm_provider=llm_provider,
llm_temperature=1.0,
top_k=20,
timeout=7200,
@@ -67,9 +62,24 @@ def build_gaia_agent(llm_model_name, llm_base_url, llm_api_key, mcp_config, serv
mcp_servers=list(server_name for server_name in mcp_config.get("mcpServers", {}).keys())
)
+def build_context_aware_task_config() -> TaskConfig:
+ return TaskConfig(
+ stream=False,
+ exit_on_failure=True,
+ trajectory_strategy=MemoryTrajectoryStrategy
+ )
-
-async def build_gaia_task(user_input: str, target: [Agent, Swarm], timeout, session_id: str = None, task_id: str = None):
+def build_context_config() -> AmniContextConfig:
+ # Import here to avoid circular import
+ from train.examples.train_gaia_with_aworld_verl.rollout.prompts import (
+ episode_memory_summary_rule,
+ working_memory_summary_rule,
+ working_memory_summary_schema,
+ tool_memory_summary_rule,
+ tool_memory_summary_schema,
+ episode_memory_summary_schema,
+ )
+
# 1. init middlewares
init_middlewares()
@@ -79,6 +89,7 @@ async def build_gaia_task(user_input: str, target: [Agent, Swarm], timeout, sess
if is_summary():
context_config.agent_config = AgentContextConfig(
history_rounds= 100,
+ history_write_strategy= HistoryWriteStrategy.DIRECT,
enable_summary= True,
summary_rounds= 30,
summary_context_length= 40960,
@@ -96,6 +107,22 @@ async def build_gaia_task(user_input: str, target: [Agent, Swarm], timeout, sess
],
)
+ # debug_mode
+ debug_mode = os.getenv("CONTEXT_DEBUG_MODE", "false").lower() in ("true", "1", "yes")
+ context_config.debug_mode = debug_mode
+
+ return context_config
+
+async def build_task(user_input: str, target: [Agent, Swarm], timeout, context_config: AmniContextConfig = None, task_config: TaskConfig = None, session_id: str = None, task_id: str = None):
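+    """Build an AWorld Task for the given agent or swarm, wiring in the Amni context and task config."""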
+ if not context_config:
+ context_config = build_context_config()
+ if not task_config:
+ task_config = TaskConfig(
+ stream=False,
+ exit_on_failure=True,
+ trajectory_strategy=MemoryTrajectoryStrategy
+ )
+
# 3. build context
if not session_id:
session_id = f"session_{datetime.now().strftime('%Y%m%d%H%M%S')}"
@@ -120,15 +147,11 @@ async def build_gaia_task(user_input: str, target: [Agent, Swarm], timeout, sess
id=context.task_id,
user_id=context.user_id,
session_id=context.session_id,
- input=context.task_input,
+ input=task_input.task_content,
endless_threshold=5,
swarm=swarm,
context=context,
- conf=TaskConfig(
- stream=False,
- exit_on_failure=True,
- trajectory_strategy=MemoryTrajectoryStrategy
- ),
+ conf=task_config,
timeout=timeout
)
else:
@@ -137,15 +160,11 @@ async def build_gaia_task(user_input: str, target: [Agent, Swarm], timeout, sess
id=context.task_id,
user_id=context.user_id,
session_id=context.session_id,
- input=context.task_input,
+ input=task_input.task_content,
endless_threshold=5,
agent=target,
context=context,
- conf=TaskConfig(
- stream=False,
- exit_on_failure=True,
- trajectory_strategy=MemoryTrajectoryStrategy
- ),
+ conf=task_config,
timeout=timeout
)
diff --git a/train/adapter/verl/verl_trainer.py b/train/adapter/verl/verl_trainer.py
index d9d1985e8..c076f7e97 100644
--- a/train/adapter/verl/verl_trainer.py
+++ b/train/adapter/verl/verl_trainer.py
@@ -112,7 +112,7 @@ def check_reward(self, reward_func: Union[str, Callable[..., float]] = None) ->
logger.info(f"View reward function in file: {reward_file_path}, name is: {self.reward_file_path}")
return reward_file_path, reward_func.__name__
- def check_agent(self, agent: Union[str, Agent]) -> str:
+ def check_agent(self, agent: Union[str, Agent], context_config, task_config) -> str:
"""Check single agent instance, and create agent loop dynamically.
NOTE: Single-agent only now, Swarm to be added in the future.
@@ -179,6 +179,8 @@ def check_agent(self, agent: Union[str, Agent]) -> str:
black_tool_actions=agent.black_tool_actions,
skill_configs=agent.skill_configs,
event_handler_name=agent.event_handler_name,
+ context_config=context_config,
+ task_config=task_config,
tool_aggregate_func_import_str=func_str,
tools_aggregate_func=func_name,
parser_module=type(agent.model_output_parser).__module__,
@@ -194,8 +196,8 @@ def check_agent(self, agent: Union[str, Agent]) -> str:
# VeRL agent config file
module = module.replace(os.getcwd(), '').replace('/', '.')
module = module[1:] if module[0] == '.' else module
- con = f"""- name: {agent.name()}
- _target_: {module}.VerlAgentLoop
+ con = f"""- name: gaia_agent
+ _target_: train.examples.train_gaia_with_aworld_verl.rollout.verl_contextaware_agent_loop.VerlAgentLoop
"""
agent_yaml = f"{self.run_path}/agent.yaml"
diff --git a/train/examples/train_gaia_with_aworld_verl/context_train.py b/train/examples/train_gaia_with_aworld_verl/context_train.py
new file mode 100644
index 000000000..35900c31a
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/context_train.py
@@ -0,0 +1,54 @@
+# coding: utf-8
+# Copyright (c) 2025 inclusionAI.
+import asyncio
+import os
+
+from dotenv import load_dotenv
+
+from aworld.config import TaskConfig
+from train.adapter.verl.utils import build_context_config
+from train.examples.train_gaia_with_aworld_verl.mcp_tools import build_mcp_config
+from train.examples.train_gaia_with_aworld_verl.reward import gaia_reward_func
+from train.examples.train_gaia_with_aworld_verl.rollout import build_context_aware_agent
+
+
+async def main():
+ # config module divided into environmental variables and training configurations
+ success = load_dotenv('.env')
+ custom_train_config = 'train/examples/train_gaia_with_aworld_verl/grpo_trainer.yaml'
+
+ from train.trainer.agent_trainer import AgentTrainer
+ from aworld.dataset.trajectory_strategy import MemoryTrajectoryStrategy
+ agent = build_context_aware_agent(llm_model_name=os.getenv("LLM_MODEL_NAME"),
+ llm_base_url=os.getenv("LLM_BASE_URL"),
+ llm_api_key=os.getenv("LLM_API_KEY"),
+ llm_provider="verl",
+ mcp_config=await build_mcp_config())
+    context_config = build_context_config()
+ task_config = TaskConfig(
+ stream=False,
+ exit_on_failure=True,
+ trajectory_strategy=MemoryTrajectoryStrategy
+ )
+
+ # dataset module contains train and test dataset
+ train_dataset = f'train/examples/train_gaia_with_aworld_verl/gaia_data/sample_train.parquet'
+ test_dataset = f'train/examples/train_gaia_with_aworld_verl/gaia_data/sample_test.parquet'
+ abs_train_dataset = os.path.abspath(train_dataset)
+ abs_test_dataset = os.path.abspath(test_dataset)
+
+ # reward module contains reward function or reward function code file path
+ reward_func = gaia_reward_func
+
+ trainer = AgentTrainer(agent=agent,
+ context_config=context_config,
+ task_config=task_config,
+ config=custom_train_config,
+ reward_func=reward_func,
+ train_dataset=abs_train_dataset,
+ test_dataset=abs_test_dataset)
+ trainer.train()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/train/examples/train_gaia_with_aworld_verl/gaia_data/DeepSearch_decrypted_9.parquet b/train/examples/train_gaia_with_aworld_verl/gaia_data/DeepSearch_decrypted_9.parquet
new file mode 100644
index 000000000..a184e515f
Binary files /dev/null and b/train/examples/train_gaia_with_aworld_verl/gaia_data/DeepSearch_decrypted_9.parquet differ
diff --git a/train/examples/train_gaia_with_aworld_verl/grpo_trainer.yaml b/train/examples/train_gaia_with_aworld_verl/grpo_trainer.yaml
index 73e5a6a91..26ccd6f7e 100644
--- a/train/examples/train_gaia_with_aworld_verl/grpo_trainer.yaml
+++ b/train/examples/train_gaia_with_aworld_verl/grpo_trainer.yaml
@@ -86,7 +86,7 @@ actor_rollout_ref:
temperature: 1.0
n: 1
trace:
- backend: mlflow
+      backend:  # leave empty; the mlflow backend may raise errors
# config for the algorithm
algorithm:
diff --git a/train/examples/train_gaia_with_aworld_verl/log_processor/analyze_state_manager.py b/train/examples/train_gaia_with_aworld_verl/log_processor/analyze_state_manager.py
index f9b3eb2cd..b60235b7a 100644
--- a/train/examples/train_gaia_with_aworld_verl/log_processor/analyze_state_manager.py
+++ b/train/examples/train_gaia_with_aworld_verl/log_processor/analyze_state_manager.py
@@ -324,6 +324,19 @@ def get_tree_depth(node_id: str, visited: set = None) -> int:
if width > total_duration * 0.02: # Width must be at least 2% of total duration
display_label = busi_type_labels.get(node.busi_type, node.busi_type)
label = f"{display_label}:{duration:.3f}"
+
+ # If it's an AGENT, add agent_name or agent_id on a new line
+ if node.busi_type == 'AGENT':
+ agent_info = None
+ if node.metadata and isinstance(node.metadata, dict):
+ # Try to get agent_name from metadata
+ agent_info = node.metadata.get('agent_name') or node.metadata.get('name')
+ # If no agent_name in metadata, use busi_id as agent_id
+ if not agent_info:
+ agent_info = node.busi_id
+ if agent_info:
+                            label = f"{label}\n{agent_info}"
+
# if len(label) > 20:
# label = label[:17] + "..."
annotations.append(dict(
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/ip_pool.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/ip_pool.py
index 1ce8576f7..ea7dbbc54 100644
--- a/train/examples/train_gaia_with_aworld_verl/mcp_tools/ip_pool.py
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/ip_pool.py
@@ -7,12 +7,20 @@
# Record used IP addresses
_used_proxies = set()
+# Map task_id to real_out_ip for releasing IP after task completion
+_task_ip_mapping = {}
# Maximum retry count to avoid infinite loop
_MAX_RETRIES = 100
-async def get_proxy_server():
+async def get_proxy_server(task_id: str = None):
+    """
+    Get a proxy server IP address.
+
+    Args:
+        task_id: Optional task id; when provided, the proxy is tracked so it can be released later.
+
+    Returns:
+        Proxy server string in format "ip:port", or None if failed.
+    """
api = f"{os.getenv('IP_POOL_PROXY')}/get_cn_proxy?interval=0&protocol=HTTP"
-
+
for attempt in range(_MAX_RETRIES):
try:
response = requests.get(api)
@@ -20,20 +28,41 @@ async def get_proxy_server():
p = j["result"]["data"]
real_out_ip = p['real_out_ip']
proxy = f"{p['proxy_public_ip']}:{p['proxy_port']}"
-
+
# Check for duplicates (filter by real_out_ip)
if real_out_ip in _used_proxies:
logger.warning(f"Duplicate real_out_ip detected: {real_out_ip} (proxy: {proxy}), retrying... (attempt {attempt + 1}/{_MAX_RETRIES})")
continue
-
+
# Record new IP (record by real_out_ip)
_used_proxies.add(real_out_ip)
+
+            # track the real_out_ip by task_id so the proxy can be released after the task completes
+            if task_id:
+                _task_ip_mapping[task_id] = real_out_ip
logger.info(f"Got new proxy: {proxy} (real_out_ip: {real_out_ip})")
+
return proxy
except:
logger.error(f"Get proxy server error: {traceback.format_exc()}")
return None
-
+
# If maximum retry count reached without getting a new IP
logger.error(f"Failed to get a new proxy after {_MAX_RETRIES} attempts, all proxies seem to be duplicates")
- return None
\ No newline at end of file
+ return None
+
+
+def release_proxy_by_task_id(task_id: str):
+ """
+ Release the IP address used by a task, making it available for reuse.
+
+ Args:
+ task_id: The task ID that was used when getting the proxy.
+ """
+ if task_id in _task_ip_mapping:
+ real_out_ip = _task_ip_mapping.pop(task_id)
+ if real_out_ip in _used_proxies:
+ _used_proxies.remove(real_out_ip)
+ logger.info(f"Released proxy (real_out_ip: {real_out_ip}) for task_id: {task_id}")
+ else:
+ logger.warning(f"real_out_ip {real_out_ip} not found in _used_proxies when releasing for task_id: {task_id}")
+ else:
+ logger.warning(f"task_id {task_id} not found in _task_ip_mapping, nothing to release")
\ No newline at end of file
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/mcp_config.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/mcp_config.py
index 9774b9c93..ce30f476a 100644
--- a/train/examples/train_gaia_with_aworld_verl/mcp_tools/mcp_config.py
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/mcp_config.py
@@ -103,6 +103,17 @@ async def build_local_mcp_config():
async def build_distributed_mcp_config():
return {
"mcpServers": {
+ "qwen_file_parser": {
+ "command": "python",
+ "args": [
+ "-m",
+ "train.examples.train_gaia_with_aworld_verl.mcp_tools.qwen.qwen_file_parser"
+ ],
+ "env": {
+ "ALIBABA_CLOUD_ACCESS_KEY_ID": os.environ['ALIBABA_CLOUD_ACCESS_KEY_ID'],
+ "ALIBABA_CLOUD_ACCESS_KEY_SECRET": os.environ['ALIBABA_CLOUD_ACCESS_KEY_SECRET'],
+ }
+ },
"virtualpc-mcp-server": {
"type": "streamable-http",
"url": "http://mcp.aworldagents.com/vpc/mcp",
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/__init__.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/__init__.py
new file mode 100644
index 000000000..f0e49993c
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/__init__.py
@@ -0,0 +1,3 @@
+from .qwen_file_parser import *
+
+__all__ = ['parse_file_by_idp']
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/base.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/base.py
new file mode 100644
index 000000000..807b7b42c
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/base.py
@@ -0,0 +1,27 @@
+# qwen_agent tools base
+from typing import Any, Dict, Optional, Union
+import json
+
+
+class BaseTool:
+ """Base class for all tools"""
+
+ def __init__(self, cfg: Optional[Dict] = None):
+ self.cfg = cfg or {}
+
+ def _verify_json_format_args(self, params: Union[str, dict]) -> dict:
+ """Verify and convert parameters to dict format"""
+ if isinstance(params, str):
+ try:
+ return json.loads(params)
+ except json.JSONDecodeError:
+ raise ValueError(f"Invalid JSON format: {params}")
+ return params
+
+
+def register_tool(tool_name: str):
+ """Decorator to register a tool"""
+ def decorator(func):
+ func._tool_name = tool_name
+ return func
+ return decorator
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/idp.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/idp.py
new file mode 100644
index 000000000..dc06dced7
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/idp.py
@@ -0,0 +1,111 @@
+import logging
+import os
+import json
+
+from aworld.logs.util import logger
+
+# Try to import IDP dependencies, fallback to mock implementation if not available
+try:
+ from alibabacloud_docmind_api20220711.client import Client as docmind_api20220711Client
+ from alibabacloud_tea_openapi import models as open_api_models
+ from alibabacloud_docmind_api20220711 import models as docmind_api20220711_models
+ from alibabacloud_tea_util.client import Client as UtilClient
+ from alibabacloud_tea_util import models as util_models
+ from alibabacloud_credentials.client import Client as CredClient
+ IDP_AVAILABLE = True
+ logger.info("IDP_AVAILABLE=True")
+except ImportError:
+ IDP_AVAILABLE = False
+ logger.info("Warning: IDP dependencies not available. IDP functionality will be disabled.")
+
+class IDP():
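+    """Thin client wrapper around the Alibaba Cloud DocMind (docmind-api) document parsing service."""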
+ def __init__(self):
+ if not IDP_AVAILABLE:
+ logger.info("IDP not available - dependencies missing")
+ self.client = None
+ return
+ config = open_api_models.Config(
+ access_key_id=os.environ['ALIBABA_CLOUD_ACCESS_KEY_ID'],
+ access_key_secret=os.environ['ALIBABA_CLOUD_ACCESS_KEY_SECRET']
+ )
+ config.endpoint = f'docmind-api.cn-hangzhou.aliyuncs.com'
+ self.client = docmind_api20220711Client(config)
+
+ def file_submit_with_url(self, file_url):
+ if not IDP_AVAILABLE or not self.client:
+ logger.info('IDP not available - skipping URL submission')
+ return None
+
+        logger.info(f'parsing with document url {file_url}')
+ file_name = os.path.basename(file_url)
+ request = docmind_api20220711_models.SubmitDocParserJobAdvanceRequest(
+ file_url=file_url,
+ file_name=file_name,
+ reveal_markdown=True,
+ )
+ runtime = util_models.RuntimeOptions()
+ result_dict = None
+ try:
+ response = self.client.submit_doc_parser_job_advance(request,runtime)
+ result_dict = response.body.data.id
+ except Exception as error:
+ UtilClient.assert_as_string(error.message)
+
+ return result_dict
+
+
+ def file_submit_with_path(self, file_path):
+ if not IDP_AVAILABLE or not self.client:
+ logger.info('IDP not available - skipping path submission')
+ return None
+
+ logger.info(f'parsing with document local path {file_path}')
+ file_name = os.path.basename(file_path)
+ request = docmind_api20220711_models.SubmitDocParserJobAdvanceRequest(
+ file_url_object=open(file_path, "rb"),
+ file_name=file_name,
+ )
+ runtime = util_models.RuntimeOptions()
+ result_dict = None
+ try:
+ response = self.client.submit_doc_parser_job_advance(request, runtime)
+ result_dict = response.body.data.id
+ except Exception as error:
+ logger.info(error)
+ UtilClient.assert_as_string(error.message)
+
+ return result_dict
+
+ def file_parser_query(self,fid):
+ if not IDP_AVAILABLE or not self.client:
+ logger.info('IDP not available - skipping query')
+ return None, 'unavailable'
+
+ request = docmind_api20220711_models.QueryDocParserStatusRequest(
+ id=fid
+ )
+ try:
+ response = self.client.query_doc_parser_status(request)
+ NumberOfSuccessfulParsing = response.body.data
+ except Exception as e:
+ logger.info(e)
+            return None, 'failed'
+ status_parse = response.body.data.status
+ NumberOfSuccessfulParsing = NumberOfSuccessfulParsing.__dict__
+ responses = dict()
+ for i in range(0, NumberOfSuccessfulParsing["number_of_successful_parsing"], 3000):
+ request = docmind_api20220711_models.GetDocParserResultRequest(
+ id=fid,
+ layout_step_size=3000,
+ layout_num=i
+ )
+ try:
+ response = self.client.get_doc_parser_result(request)
+ result = response.body.data
+ if not responses:
+ responses = result
+ else:
+ responses['layouts'].extend(result['layouts'])
+ except Exception as error:
+ return None,status_parse
+ return responses,status_parse
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/qwen_file_parser.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/qwen_file_parser.py
new file mode 100644
index 000000000..b29dd8bf7
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/qwen_file_parser.py
@@ -0,0 +1,638 @@
+from mcp.server.fastmcp import FastMCP
+
+from aworld.logs.util import logger
+
+mcp = FastMCP("qwen_file_parser")
+
+import re
+import sys
+import traceback
+from collections import Counter
+
+from dotenv import load_dotenv
+
+import json
+import os
+import time
+import zipfile
+import math
+
+from typing import Any, Dict, List, Optional, Union
+import xml.etree.ElementTree as ET
+from pandas import Timestamp
+from datetime import datetime
+from pandas.api.types import is_datetime64_any_dtype
+
+import pandas as pd
+from tabulate import tabulate
+from .settings import DEFAULT_WORKSPACE, DEFAULT_MAX_INPUT_TOKENS
+from .base import BaseTool
+from .storage import KeyNotExistsError, Storage
+from .utils import (get_file_type, hash_sha256, is_http_url, get_basename_from_url,
+ sanitize_chrome_file_path, save_url_to_local_work_dir)
+from .utils import count_tokens, tokenizer
+from .idp import IDP
+# Configuration constants
+PARSER_SUPPORTED_FILE_TYPES = ['pdf', 'docx', 'pptx', 'txt', 'html', 'csv', 'tsv', 'xlsx', 'xls', 'doc', 'zip', 'mp4', 'mov', 'mkv', 'webm', 'mp3', 'wav']
+def str_to_bool(value):
+ """Convert string to boolean, handling common true/false representations"""
+ if isinstance(value, bool):
+ return value
+ return str(value).lower() in ('true', '1', 'yes', 'on')
+USE_IDP = str_to_bool(os.getenv("USE_IDP", "True"))
+IDP_TIMEOUT = 150000
+ENABLE_CSI = False
+PARAGRAPH_SPLIT_SYMBOL = '\n'
+
+
+class CustomJSONEncoder(json.JSONEncoder):
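+    """JSON encoder that serializes datetime and pandas Timestamp objects as ISO-8601 strings."""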
+ def default(self, obj):
+ if isinstance(obj, (datetime, Timestamp)):
+ return obj.isoformat()
+ return super().default(obj)
+
+
+class FileParserError(Exception):
+ """Custom exception for document parsing errors"""
+
+ def __init__(self, message: str, code: str = '400', exception: Optional[Exception] = None):
+ super().__init__(message)
+ self.code = code
+ self.exception = exception
+
+
+@mcp.tool(
+ description="Parse files using IDP (Intelligent Document Processing) service for supported formats"
+)
+def parse_file_by_idp(file_path: str = None, file_url: str = None) -> List[dict]:
+ idp = IDP()
+ try:
+ logger.info(f"parse_file_by_idp|start|{file_path}")
+ fid = idp.file_submit_with_url(file_url) if file_url else idp.file_submit_with_path(file_path)
+ if not fid:
+ return []
+
+ for _ in range(10):
+ result, status = idp.file_parser_query(fid)
+ if status == 'success':
+ return process_idp_result(result)
+ time.sleep(10)
+
+ logger.error("IDP parsing timeout")
+ return []
+ except Exception as e:
+ logger.error(f"IDP processing failed: {str(e)} {traceback.format_exc()}")
+ return []
+
+
+def process_idp_result(result: dict) -> List[dict]:
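+    """Group DocMind layout blocks into per-page dicts of markdown text keyed by page_num."""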
+ pages = []
+ current_page = None
+ logger.info("result: ", result)
+
+ for layout in result.get('layouts', []):
+ page_num = layout.get('pageNum', 0)
+ content = layout.get('markdownContent', '')
+
+ if current_page and current_page['page_num'] == page_num:
+ current_page['content'].append({'text': content})
+ else:
+ current_page = {'page_num': page_num, 'content': [{'text': content}]}
+ pages.append(current_page)
+
+ return pages
+
+
+def clean_text(text: str) -> str:
+ cleaners = [
+ lambda x: re.sub(r'\n+', '\n', x),
+ lambda x: x.replace("Add to Qwen's Reading List", ''),
+ lambda x: re.sub(r'-{6,}', '-----', x),
+ lambda x: x.strip()
+ ]
+ for cleaner in cleaners:
+ text = cleaner(text)
+ return text
+
+
+def get_plain_doc(doc: list):
+ paras = []
+ for page in doc:
+ for para in page['content']:
+ for k, v in para.items():
+ if k in ['text', 'table', 'image']:
+ paras.append(v)
+ return PARAGRAPH_SPLIT_SYMBOL.join(paras)
+
+
+def df_to_markdown(df: pd.DataFrame) -> str:
+ df = df.dropna(how='all').fillna('')
+ return tabulate(df, headers='keys', tablefmt='pipe', showindex=False)
+
+
+@mcp.tool(
+ description="Parse Word document (.docx/.doc) and return structured content with text and tables"
+)
+def parse_word(docx_path: str, extract_image: bool = False):
+ if extract_image:
+ raise ValueError('Currently, extracting images is not supported!')
+
+ from docx import Document
+ doc = Document(docx_path)
+
+ content = []
+ for para in doc.paragraphs:
+ content.append({'text': para.text})
+ for table in doc.tables:
+ tbl = []
+ for row in table.rows:
+ tbl.append('|' + '|'.join([cell.text for cell in row.cells]) + '|')
+ tbl = '\n'.join(tbl)
+ content.append({'table': tbl})
+ return [{'page_num': 1, 'content': content}]
+
+
+@mcp.tool(
+ description="Parse PowerPoint presentation (.pptx) and return structured content with text and tables"
+)
+def parse_ppt(path: str, extract_image: bool = False):
+ if extract_image:
+ raise ValueError('Currently, extracting images is not supported!')
+
+ from pptx import Presentation
+ from pptx.exc import PackageNotFoundError
+ try:
+ ppt = Presentation(path)
+ except PackageNotFoundError as ex:
+ logger.warning(ex)
+ return []
+ doc = []
+ for slide_number, slide in enumerate(ppt.slides):
+ page = {'page_num': slide_number + 1, 'content': []}
+
+ for shape in slide.shapes:
+ if not shape.has_text_frame and not shape.has_table:
+                continue
+
+ if shape.has_text_frame:
+ for paragraph in shape.text_frame.paragraphs:
+ paragraph_text = ''.join(run.text for run in paragraph.runs)
+ paragraph_text = clean_text(paragraph_text)
+ if paragraph_text.strip():
+ page['content'].append({'text': paragraph_text})
+
+ if shape.has_table:
+ tbl = []
+ for row_number, row in enumerate(shape.table.rows):
+ tbl.append('|' + '|'.join([cell.text for cell in row.cells]) + '|')
+ tbl = '\n'.join(tbl)
+ page['content'].append({'table': tbl})
+ doc.append(page)
+ return doc
+
+@mcp.tool(
+ description="Read and return content from PDF file with optional image extraction. return the parsed content. Cannot process https://URLs files."
+)
+def parse_pdf(pdf_path: str, extract_image: bool = False) -> List[dict]:
+ # Todo: header and footer
+ from pdfminer.high_level import extract_pages
+ from pdfminer.layout import LTImage, LTRect, LTTextContainer
+
+ doc = []
+ import pdfplumber
+ pdf = pdfplumber.open(pdf_path)
+ for i, page_layout in enumerate(extract_pages(pdf_path)):
+ page = {'page_num': page_layout.pageid, 'content': []}
+
+ elements = []
+ for element in page_layout:
+ elements.append(element)
+
+ # Init params for table
+ table_num = 0
+ tables = []
+
+ for element in elements:
+ if isinstance(element, LTRect):
+ if not tables:
+ tables = extract_tables(pdf, i)
+ if table_num < len(tables):
+ table_string = table_converter(tables[table_num])
+ table_num += 1
+ if table_string:
+ page['content'].append({'table': table_string, 'obj': element})
+ elif isinstance(element, LTTextContainer):
+ # Delete line breaks in the same paragraph
+ text = element.get_text()
+ # Todo: Further analysis using font
+ font = get_font(element)
+ if text.strip():
+ new_content_item = {'text': text, 'obj': element}
+ if font:
+ new_content_item['font-size'] = round(font[1])
+ # new_content_item['font-name'] = font[0]
+ page['content'].append(new_content_item)
+ elif extract_image and isinstance(element, LTImage):
+ # Todo: ocr
+ raise ValueError('Currently, extracting images is not supported!')
+ else:
+ pass
+
+ # merge elements
+ page['content'] = postprocess_page_content(page['content'])
+ doc.append(page)
+
+ return doc
+
+
+@mcp.tool(
+ description="Parse text file (.txt) and return structured content"
+)
+def parse_txt(path: str):
+ with open(path, 'r', encoding='utf-8') as f:
+ text = f.read()
+ paras = text.split(PARAGRAPH_SPLIT_SYMBOL)
+ content = []
+ for p in paras:
+ content.append({'text': p})
+ return [{'page_num': 1, 'content': content}]
+
+
+def get_font(element):
+ from pdfminer.layout import LTChar, LTTextContainer
+
+ fonts_list = []
+ for text_line in element:
+ if isinstance(text_line, LTTextContainer):
+ for character in text_line:
+ if isinstance(character, LTChar):
+ fonts_list.append((character.fontname, character.size))
+
+ fonts_list = list(set(fonts_list))
+ if fonts_list:
+ counter = Counter(fonts_list)
+ most_common_fonts = counter.most_common(1)[0][0]
+ return most_common_fonts
+ else:
+ return []
+
+
+def extract_tables(pdf, page_num):
+ table_page = pdf.pages[page_num]
+ tables = table_page.extract_tables()
+ return tables
+
+
+def table_converter(table):
+ table_string = ''
+ for row_num in range(len(table)):
+ row = table[row_num]
+ cleaned_row = [
+ item.replace('\n', ' ') if item is not None and '\n' in item else 'None' if item is None else item
+ for item in row
+ ]
+ table_string += ('|' + '|'.join(cleaned_row) + '|' + '\n')
+ table_string = table_string[:-1]
+ return table_string
+
+
+def postprocess_page_content(page_content: list) -> list:
+ # rm repetitive identification for table and text
+ # Some documents may repeatedly recognize LTRect and LTTextContainer
+ table_obj = [p['obj'] for p in page_content if 'table' in p]
+ tmp = []
+ for p in page_content:
+ repetitive = False
+ if 'text' in p:
+ for t in table_obj:
+ if t.bbox[0] <= p['obj'].bbox[0] and p['obj'].bbox[1] <= t.bbox[1] and t.bbox[2] <= p['obj'].bbox[
+ 2] and p['obj'].bbox[3] <= t.bbox[3]:
+ repetitive = True
+ break
+
+ if not repetitive:
+ tmp.append(p)
+ page_content = tmp
+
+ # merge paragraphs that have been separated by mistake
+ new_page_content = []
+ for p in page_content:
+ if new_page_content and 'text' in new_page_content[-1] and 'text' in p and abs(
+ p.get('font-size', 12) -
+ new_page_content[-1].get('font-size', 12)) < 2 and p['obj'].height < p.get('font-size', 12) + 1:
+ # Merge those lines belonging to a paragraph
+ new_page_content[-1]['text'] += f' {p["text"]}'
+ # new_page_content[-1]['font-name'] = p.get('font-name', '')
+ new_page_content[-1]['font-size'] = p.get('font-size', 12)
+ else:
+ p.pop('obj')
+ new_page_content.append(p)
+ for i in range(len(new_page_content)):
+ if 'text' in new_page_content[i]:
+ new_page_content[i]['text'] = clean_text(new_page_content[i]['text'])
+ return new_page_content
+
+
+def extract_xls_schema(file_path: str) -> Dict[str, Any]:
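+    """Summarize an Excel workbook: per-sheet columns, dtypes, and a 3-row sample, without loading all data."""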
+ xls = pd.ExcelFile(file_path)
+ schema = {
+ "sheets": [],
+ "n_sheets": len(xls.sheet_names)
+ }
+
+ for sheet_name in xls.sheet_names:
+ df = xls.parse(sheet_name, nrows=3) # Read first 3 rows
+
+ dtype_mapping = {
+ 'object': 'string',
+ 'datetime64[ns]': 'datetime',
+ 'timedelta64[ns]': 'timedelta'
+ }
+ dtypes = df.dtypes.astype(str).replace(dtype_mapping).to_dict()
+
+ sample_df = df.head(3).copy()
+ for col in sample_df.columns:
+ if is_datetime64_any_dtype(sample_df[col]):
+ sample_df[col] = sample_df[col].dt.strftime('%Y-%m-%dT%H:%M:%S')
+
+ sheet_info = {
+ "name": sheet_name,
+ "columns": df.columns.tolist(),
+ "dtypes": dtypes,
+ "sample_data": sample_df.to_dict(orient='list')
+ }
+ schema["sheets"].append(sheet_info)
+
+ return schema
+
+
+def extract_csv_schema(file_path: str) -> Dict[str, Any]:
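+    """Summarize a CSV file: columns, dtypes inferred from the first 100 rows, a 3-row sample, and an estimated total row count."""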
+ df_dtype = pd.read_csv(file_path, nrows=100)
+ df_sample = pd.read_csv(file_path, nrows=3)
+
+ return {
+ "columns": df_dtype.columns.tolist(),
+ "dtypes": df_dtype.dtypes.astype(str).to_dict(),
+ "sample_data": df_sample.to_dict(orient='list'),
+ "estimated_total_rows": _estimate_total_rows(file_path)
+ }
+
+
+def _estimate_total_rows(file_path) -> int:
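+    """Estimate the number of data rows by counting newlines in 1 MB chunks, minus one for the header row."""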
+ with open(file_path, 'rb') as f:
+ line_count = 0
+ chunk_size = 1024 * 1024
+ while chunk := f.read(chunk_size):
+ line_count += chunk.count(b'\n')
+ return line_count - 1
+
+
+@mcp.tool(
+ description="Parse tabular files (.csv, .tsv, .xlsx, .xls) and return structured data or schema"
+)
+def parse_tabular_file(file_path: str, **kwargs) -> List[dict]:
+ try:
+        df = pd.read_excel(file_path) if file_path.endswith(('.xlsx', '.xls')) else \
+            pd.read_csv(file_path, sep=kwargs.get('sep', ','))
+ if count_tokens(df_to_markdown(df)) > DEFAULT_MAX_INPUT_TOKENS:
+ schema = extract_xls_schema(file_path) if file_path.endswith(('.xlsx', '.xls')) else \
+ extract_csv_schema(file_path)
+ return [{'page_num': 1, 'content': [{'schema': schema}]}]
+ else:
+ return [{'page_num': 1, 'content': [{'table': df_to_markdown(df)}]}]
+ except Exception as e:
+ logger.error(f"Table parsing failed: {str(e)}")
+ return []
+
+
+@mcp.tool(
+    description="Extract files from a ZIP archive and return the extracted file paths"
+)
+def parse_zip(file_path: str, extract_dir: str) -> List[str]:
+ with zipfile.ZipFile(file_path, 'r') as zip_ref:
+ zip_ref.extractall(extract_dir)
+ return [os.path.join(extract_dir, f) for f in zip_ref.namelist()]
+
+
+@mcp.tool(
+ description="Parse HTML file and return structured content with text"
+)
+def parse_html(file_path: str) -> List[dict]:
+ from bs4 import BeautifulSoup
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ soup = BeautifulSoup(f, 'lxml')
+
+ content = [{'text': clean_text(p.get_text())}
+ for p in soup.find_all(['p', 'div']) if p.get_text().strip()]
+
+ return [{
+ 'page_num': 1,
+ 'content': content,
+ 'title': soup.title.string if soup.title else ''
+ }]
+
+
+def extract_xml_skeleton_markdown(xml_file):
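+ """Render the XML element hierarchy as a Markdown tree, keeping one child per unique tag plus attribute names."""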
+ tree = ET.parse(xml_file)
+ root = tree.getroot()
+ markdown_lines = []
+
+ def process_element(element, level=0, parent_path="", is_last=True, prefix=""):
+ if level > 0:
+ connector = "└── " if is_last else "├── "
+ markdown_lines.append(f"{prefix}{connector}**{element.tag}**")
+ else:
+ markdown_lines.append(f"## Root: {element.tag}")
+
+ if element.attrib:
+ attrs = [f"`{k}`" for k in element.attrib.keys()]
+ attr_line = f"{prefix}{' ' if level > 0 else ''}*Attributes:* {', '.join(attrs)}"
+ markdown_lines.append(attr_line)
+
+ if element.text and element.text.strip():
+ text_line = f"{prefix}{' ' if level > 0 else ''}*Has text content*"
+ markdown_lines.append(text_line)
+ seen_tags = set()
+ unique_children = []
+ for child in element:
+ if child.tag not in seen_tags:
+ seen_tags.add(child.tag)
+ unique_children.append(child)
+
+ for i, child in enumerate(unique_children):
+ is_last_child = (i == len(unique_children) - 1)
+ child_prefix = prefix + (" " if is_last else "│ ")
+ process_element(child, level + 1,
+ f"{parent_path}/{element.tag}" if parent_path else element.tag,
+ is_last_child, child_prefix)
+
+ process_element(root)
+ markdown_content = "\n".join(markdown_lines)
+ return markdown_content
+
+
+@mcp.tool(
+ description="Parse XML file and return structured content or schema"
+)
+def parse_xml(file_path: str) -> List[dict]:
+ with open(file_path, 'r', encoding='utf-8') as f:
+ text = f.read()
+ if count_tokens(text) > DEFAULT_MAX_INPUT_TOKENS:
+ schema = extract_xml_skeleton_markdown(file_path)
+ content = [{'schema': schema}]
+ else:
+ content = [{'text': text}]
+ return [{'page_num': 1, 'content': content}]
+
+
+def compress(results: list) -> list[str]:
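+ """Truncate each result so the combined token count stays within DEFAULT_MAX_INPUT_TOKENS."""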
+ compress_results = []
+ max_token = math.floor(DEFAULT_MAX_INPUT_TOKENS / len(results))
+ for result in results:
+ token_list = tokenizer.tokenize(result)
+ token_list = token_list[:min(len(token_list), max_token)]
+ compress_results.append(tokenizer.convert_tokens_to_string(token_list))
+ return compress_results
+
+
+# @register_tool('file_parser')
+class SingleFileParser(BaseTool):
+ name="file_parser"
+ description = f"File parsing tool, supports parsing data in {'/'.join(PARSER_SUPPORTED_FILE_TYPES)} formats, and returns the parsed markdown format data."
+ parameters = [{
+ 'name': 'url',
+ 'type': 'string',
+ 'description': 'The full path of the file to be parsed, which can be a local path or a downloadable http(s) link.',
+ 'required': True
+ }]
+
+ def __init__(self, cfg: Optional[Dict] = None):
+ super().__init__(cfg)
+ self.data_root = self.cfg.get('path', os.path.join(DEFAULT_WORKSPACE, 'tools', self.name))
+ self.db = Storage({'storage_root_path': self.data_root})
+ self.structured_doc = self.cfg.get('structured_doc', True)
+
+
+ self.parsers = {
+ 'pdf': parse_pdf,
+ 'docx': parse_word,
+ 'doc': parse_word,
+ 'pptx': parse_ppt,
+ 'txt': parse_txt,
+ 'jsonl': parse_txt,
+ 'jsonld': parse_txt,
+ 'pdb': parse_txt,
+ 'py': parse_txt,
+ 'html': parse_html,
+ 'xml': parse_xml,
+ 'csv': lambda p: parse_tabular_file(p, sep=','),
+ 'tsv': lambda p: parse_tabular_file(p, sep='\t'),
+ 'xlsx': parse_tabular_file,
+ 'xls': parse_tabular_file,
+ 'zip': self.parse_zip
+ }
+
+ def call(self, params: Union[str, dict], **kwargs) -> Union[str, list]:
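+ """Parse the file referenced by params['url'], serving the flattened result from cache when available."""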
+ params = self._verify_json_format_args(params)
+ file_path = self._prepare_file(params['url'])
+ try:
+ cached = self.db.get(f'{hash_sha256(file_path)}_ori')
+ return self._flatten_result(json.loads(cached))
+ except KeyNotExistsError:
+ return self._flatten_result(self._process_new_file(file_path))
+
+ def _prepare_file(self, path: str) -> str:
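+ """Download HTTP(S) URLs into the tool workspace; otherwise sanitize and return the local path."""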
+ if is_http_url(path):
+ download_dir = os.path.join(self.data_root, hash_sha256(path))
+ os.makedirs(download_dir, exist_ok=True)
+ return save_url_to_local_work_dir(path, download_dir)
+ return sanitize_chrome_file_path(path)
+
+ def _process_new_file(self, file_path: str) -> Union[str, list]:
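+ """Parse a file with IDP (when enabled and supported) or the matching local parser, count tokens per paragraph, and cache the result."""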
+ file_type = get_file_type(file_path)
+ idp_types = ['pdf', 'docx', 'pptx', 'xlsx', 'jpg', 'png', 'mp3']
+ logger.info(f'Start parsing {file_path}...')
+ logger.info(f'File type {file_type}...')
+ logger.info(f"structured_doc {self.cfg.get('structured_doc')}...")
+
+ if file_type not in idp_types:
+ file_type = get_basename_from_url(file_path).split('.')[-1].lower()
+
+ try:
+ if USE_IDP and file_type in idp_types:
+ try:
+ results = parse_file_by_idp(file_path=file_path)
+ except Exception as e:
+ results = self.parsers[file_type](file_path)
+ else:
+ results = self.parsers[file_type](file_path)
+ tokens = 0
+ for page in results:
+ for para in page['content']:
+ if 'schema' in para:
+ para['token'] = count_tokens(json.dumps(para['schema']))
+ else:
+ para['token'] = count_tokens(para.get('text', para.get('table')))
+ tokens += para['token']
+
+ if not results or not tokens:
+ logger.error(f"Parsing failed: No information was parsed")
+ raise FileParserError("Document parsing failed")
+ else:
+ self._cache_result(file_path, results)
+ return results
+ except Exception as e:
+ logger.error(f"Parsing failed: {str(e)}")
+ raise FileParserError("Document parsing failed", exception=e)
+
+ def _cache_result(self, file_path: str, result: list):
+ cache_key = f'{hash_sha256(file_path)}_ori'
+ self.db.put(cache_key, json.dumps(result, ensure_ascii=False))
+ logger.info(f'The parsing result of {file_path} has been cached')
+
+ def _flatten_result(self, result: list) -> str:
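+ """Join the text/table content of every parsed paragraph into a single string."""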
+ return PARAGRAPH_SPLIT_SYMBOL.join(
+ para.get('text', para.get('table', ''))
+ for page in result for para in page['content']
+ )
+
+ def parse_zip(self, file_path: str) -> List[dict]:
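+ """Extract a ZIP archive into the tool workspace and parse every supported file inside it."""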
+ extract_dir = os.path.join(self.data_root, f"zip_{hash_sha256(file_path)}")
+ os.makedirs(extract_dir, exist_ok=True)
+
+ results = []
+ for extracted_file in parse_zip(file_path, extract_dir):
+ if (ft := get_file_type(extracted_file)) in self.parsers:
+ try:
+ results.extend(self.parsers[ft](extracted_file))
+ except Exception as e:
+ logger.warning(f"Skip files {extracted_file}: {str(e)}")
+
+ if not results:
+ raise ValueError("No parseable content found in the ZIP file")
+ return results
+
+
+
+def main():
+ load_dotenv()
+
+ mcp.run(transport="stdio")
+
+
+# Make the module callable
+def __call__():
+ """
+ Make the module callable for uvx.
+ This function is called when the module is executed directly.
+ """
+ main()
+
+
+sys.modules[__name__].__call__ = __call__
+
+# Run the server when the script is executed directly
+if __name__ == "__main__":
+ main()
+ parse_file_by_idp("/private/tmp/usda_1959_standards.pdf")
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/settings.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/settings.py
new file mode 100644
index 000000000..1ee003c97
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/settings.py
@@ -0,0 +1,6 @@
+# qwen_agent settings
+import os
+
+DEFAULT_WORKSPACE = os.path.join(os.path.expanduser('~'), '.qwen_agent')
+DEFAULT_MAX_INPUT_TOKENS = 30000
+MAX_LLM_CALL_PER_RUN = 10
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/storage.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/storage.py
new file mode 100644
index 000000000..e3acac759
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/storage.py
@@ -0,0 +1,38 @@
+# qwen_agent tools storage
+import os
+import json
+from typing import Any, Dict, Optional
+
+
+class KeyNotExistsError(Exception):
+ """Exception raised when a key doesn't exist in storage"""
+ pass
+
+
+class Storage:
+ """Simple file-based storage implementation"""
+
+ def __init__(self, config: Dict[str, Any]):
+ self.storage_root_path = config.get('storage_root_path', './storage')
+ os.makedirs(self.storage_root_path, exist_ok=True)
+
+ def get(self, key: str) -> str:
+ """Get value by key"""
+ file_path = os.path.join(self.storage_root_path, f"{key}.json")
+ if not os.path.exists(file_path):
+ raise KeyNotExistsError(f"Key '{key}' not found")
+
+ with open(file_path, 'r', encoding='utf-8') as f:
+ return f.read()
+
+ def put(self, key: str, value: str) -> None:
+ """Put value by key"""
+ file_path = os.path.join(self.storage_root_path, f"{key}.json")
+ with open(file_path, 'w', encoding='utf-8') as f:
+ f.write(value)
+
+ def delete(self, key: str) -> None:
+ """Delete value by key"""
+ file_path = os.path.join(self.storage_root_path, f"{key}.json")
+ if os.path.exists(file_path):
+ os.remove(file_path)
diff --git a/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/utils.py b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/utils.py
new file mode 100644
index 000000000..e3f0bf679
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/mcp_tools/qwen/utils.py
@@ -0,0 +1,345 @@
+import base64
+import copy
+import hashlib
+import json
+import os
+import re
+import shutil
+import signal
+import socket
+import sys
+import time
+import traceback
+import urllib.parse
+from io import BytesIO
+from typing import Any, List, Literal, Optional, Tuple, Union
+
+import json5
+import requests
+from pydantic import BaseModel
+
+
+def append_signal_handler(sig, handler):
+ """
+ Installs a new signal handler while preserving any existing handler.
+ If an existing handler is present, it will be called _after_ the new handler.
+ """
+
+ old_handler = signal.getsignal(sig)
+ if not callable(old_handler):
+ old_handler = None
+ if sig == signal.SIGINT:
+
+ def old_handler(*args, **kwargs):
+ raise KeyboardInterrupt
+ elif sig == signal.SIGTERM:
+
+ def old_handler(*args, **kwargs):
+ raise SystemExit
+
+ def new_handler(*args, **kwargs):
+ handler(*args, **kwargs)
+ if old_handler is not None:
+ old_handler(*args, **kwargs)
+
+ signal.signal(sig, new_handler)
+
+
+def get_local_ip() -> str:
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ # doesn't even have to be reachable
+ s.connect(('10.255.255.255', 1))
+ ip = s.getsockname()[0]
+ except Exception:
+ ip = '127.0.0.1'
+ finally:
+ s.close()
+ return ip
+
+
+def hash_sha256(text: str) -> str:
+ hash_object = hashlib.sha256(text.encode())
+ key = hash_object.hexdigest()
+ return key
+
+
+def print_traceback(is_error: bool = True):
+ tb = ''.join(traceback.format_exception(*sys.exc_info(), limit=3))
+ if is_error:
+ print(f"ERROR: {tb}")
+ else:
+ print(f"WARNING: {tb}")
+
+
+CHINESE_CHAR_RE = re.compile(r'[\u4e00-\u9fff]')
+
+
+def has_chinese_chars(data: Any) -> bool:
+ text = f'{data}'
+ return bool(CHINESE_CHAR_RE.search(text))
+
+
+def get_basename_from_url(path_or_url: str, need_rm_uuid: bool = False) -> str:
+ if re.match(r'^[A-Za-z]:\\', path_or_url):
+ # "C:\\a\\b\\c" -> "C:/a/b/c"
+ path_or_url = path_or_url.replace('\\', '/')
+
+ # "/mnt/a/b/c" -> "c"
+ # "https://github.com/here?k=v" -> "here"
+ # "https://github.com/" -> ""
+ basename = urllib.parse.urlparse(path_or_url).path
+ basename = os.path.basename(basename)
+ basename = urllib.parse.unquote(basename)
+ basename = basename.strip()
+
+ # "https://github.com/" -> "" -> "github.com"
+ if not basename:
+ basename = [x.strip() for x in path_or_url.split('/') if x.strip()][-1]
+
+ new_basename = basename
+ if need_rm_uuid:
+ try:
+ # Hotfix: rm uuid
+ if len(basename) > 38 and basename[8] == '-' and basename[13] == '-' and basename[18] == '-' and basename[
+ 23] == '-' and basename[36] == '_':
+ new_basename = basename[37:]
+ except Exception:
+ new_basename = basename
+ return new_basename
+
+
+def is_http_url(path_or_url: str) -> bool:
+ if path_or_url.startswith('https://') or path_or_url.startswith('http://'):
+ return True
+ return False
+
+
+def is_image(path_or_url: str) -> bool:
+ filename = get_basename_from_url(path_or_url).lower()
+ for ext in ['jpg', 'jpeg', 'png', 'webp']:
+ if filename.endswith(ext):
+ return True
+ return False
+
+
+def sanitize_chrome_file_path(file_path: str) -> str:
+ if os.path.exists(file_path):
+ return file_path
+
+ # Dealing with "file:///...":
+ new_path = urllib.parse.urlparse(file_path)
+ new_path = urllib.parse.unquote(new_path.path)
+ new_path = sanitize_windows_file_path(new_path)
+ if os.path.exists(new_path):
+ return new_path
+
+ return sanitize_windows_file_path(file_path)
+
+
+def sanitize_windows_file_path(file_path: str) -> str:
+ # For Linux and macOS.
+ if os.path.exists(file_path):
+ return file_path
+
+ # For native Windows, drop the leading '/' in '/C:/'
+ win_path = file_path
+ if win_path.startswith('/'):
+ win_path = win_path[1:]
+ if os.path.exists(win_path):
+ return win_path
+
+ # For Windows + WSL.
+ if re.match(r'^[A-Za-z]:/', win_path):
+ wsl_path = f'/mnt/{win_path[0].lower()}/{win_path[3:]}'
+ if os.path.exists(wsl_path):
+ return wsl_path
+
+ # For native Windows, replace / with \.
+ win_path = win_path.replace('/', '\\')
+ if os.path.exists(win_path):
+ return win_path
+
+ return file_path
+
+
+def save_url_to_local_work_dir(url: str, save_dir: str, save_filename: str = '') -> str:
+ if not save_filename:
+ save_filename = get_basename_from_url(url)
+ new_path = os.path.join(save_dir, save_filename)
+ if os.path.exists(new_path):
+ os.remove(new_path)
+ # print(f'Downloading {url} to {new_path}...')
+ start_time = time.time()
+ if not is_http_url(url):
+ url = sanitize_chrome_file_path(url)
+ shutil.copy(url, new_path)
+ else:
+ headers = {
+ 'User-Agent':
+ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
+ }
+ response = requests.get(url, headers=headers)
+ if response.status_code == 200:
+ with open(new_path, 'wb') as file:
+ file.write(response.content)
+ else:
+ raise ValueError('Can not download this file. Please check your network or the file link.')
+ end_time = time.time()
+ # print(f'Finished downloading {url} to {new_path}. Time spent: {end_time - start_time} seconds.')
+ return new_path
+
+
+def save_text_to_file(path: str, text: str) -> None:
+ with open(path, 'w', encoding='utf-8') as fp:
+ fp.write(text)
+
+
+def read_text_from_file(path: str) -> str:
+ try:
+ with open(path, 'r', encoding='utf-8') as file:
+ file_content = file.read()
+ except UnicodeDecodeError:
+ print_traceback(is_error=False)
+ from charset_normalizer import from_path
+ results = from_path(path)
+ file_content = str(results.best())
+ return file_content
+
+
+def contains_html_tags(text: str) -> bool:
+ pattern = r'<(p|span|div|li|html|script)[^>]*?'
+ return bool(re.search(pattern, text))
+
+
+def get_content_type_by_head_request(path: str) -> str:
+ try:
+ response = requests.head(path, timeout=5)
+ content_type = response.headers.get('Content-Type', '')
+ return content_type
+ except requests.RequestException:
+ return 'unk'
+
+
+def get_file_type(path: str) -> Literal['pdf', 'docx', 'pptx', 'csv', 'tsv', 'xlsx', 'xls', 'zip', 'mp3', 'jsonl', 'pdb', 'py', 'xml', 'image', 'html', 'txt', 'unk']:
+ f_type = get_basename_from_url(path).split('.')[-1].lower()
+ if is_image(path):
+ return "image"
+ if f_type in ['pdf', 'docx', 'pptx', 'csv', 'tsv', 'xlsx', 'xls','zip','mp3','jsonl','pdb','py','xml']:
+ # Specially supported file types
+ return f_type
+
+ if is_http_url(path):
+ # The HTTP header information for the response is obtained by making a HEAD request to the target URL,
+ # where the Content-type field usually indicates the Type of Content to be returned
+ content_type = get_content_type_by_head_request(path)
+ if 'application/pdf' in content_type:
+ return 'pdf'
+ elif 'application/msword' in content_type:
+ return 'docx'
+
+ # Assuming that the URL is HTML by default,
+ # because the file downloaded by the request may contain html tags
+ return 'html'
+ else:
+ # Determine by reading local HTML file
+ try:
+ content = read_text_from_file(path)
+ except Exception:
+ print_traceback()
+ return 'unk'
+
+ if contains_html_tags(content):
+ return 'html'
+ else:
+ return 'txt'
+
+
+def extract_urls(text: str) -> List[str]:
+ pattern = re.compile(r'https?://\S+')
+ urls = re.findall(pattern, text)
+ return urls
+
+
+def extract_markdown_urls(md_text: str) -> List[str]:
+ pattern = r'!?\[[^\]]*\]\(([^\)]+)\)'
+ urls = re.findall(pattern, md_text)
+ return urls
+
+
+def extract_code(text: str) -> str:
+ # Match triple backtick blocks first
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
+ if triple_match:
+ text = triple_match.group(1)
+ else:
+ try:
+ text = json5.loads(text)['code']
+ except Exception:
+ print_traceback(is_error=False)
+ # If no code blocks found, return original text
+ return text
+
+
+def json_loads(text: str) -> dict:
+ text = text.strip('\n')
+ if text.startswith('```') and text.endswith('\n```'):
+ text = '\n'.join(text.split('\n')[1:-1])
+ try:
+ return json.loads(text)
+ except json.decoder.JSONDecodeError as json_err:
+ try:
+ return json5.loads(text)
+ except ValueError:
+ raise json_err
+
+
+class PydanticJSONEncoder(json.JSONEncoder):
+
+ def default(self, obj):
+ if isinstance(obj, BaseModel):
+ return obj.model_dump()
+ return super().default(obj)
+
+
+def json_dumps_pretty(obj: dict, ensure_ascii=False, indent=2, **kwargs) -> str:
+ return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent, cls=PydanticJSONEncoder, **kwargs)
+
+
+def json_dumps_compact(obj: dict, ensure_ascii=False, indent=None, **kwargs) -> str:
+ return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent, cls=PydanticJSONEncoder, **kwargs)
+
+# qwen_agent tokenization utilities
+import re
+from typing import List, Union
+
+
+class SimpleTokenizer:
+ """Simple tokenizer implementation for basic token counting"""
+
+ def __init__(self):
+ # Simple tokenization rules based on spaces and punctuation
+ self.word_pattern = re.compile(r'\b\w+\b|[^\w\s]')
+
+ def tokenize(self, text: str) -> List[str]:
+ """Tokenize text into tokens"""
+ if not text:
+ return []
+ return self.word_pattern.findall(text)
+
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
+ """Convert tokens back to string"""
+ return ' '.join(tokens)
+
+
+# Create global tokenizer instance
+tokenizer = SimpleTokenizer()
+
+
+def count_tokens(text: Union[str, List[str]]) -> int:
+ """Count tokens in text"""
+ if isinstance(text, list):
+ text = ' '.join(text)
+ if not text:
+ return 0
+ return len(tokenizer.tokenize(text))
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/__init__.py b/train/examples/train_gaia_with_aworld_verl/rollout/__init__.py
index 2ce7f72e6..1769851eb 100644
--- a/train/examples/train_gaia_with_aworld_verl/rollout/__init__.py
+++ b/train/examples/train_gaia_with_aworld_verl/rollout/__init__.py
@@ -1,8 +1,8 @@
# Import prompts first (no dependencies)
# Import gaia next (depends on prompts, but not on custom_agent_loop)
-from .gaia import (
- build_gaia_agent,
- build_gaia_task,
+from train.adapter.verl.utils import (
+ build_context_aware_agent,
+ build_task,
)
from .prompts import (
GAIA_SYSTEM_PROMPT,
@@ -20,8 +20,8 @@
__all__ = [
"GAIA_SYSTEM_PROMPT",
- "build_gaia_agent",
- "build_gaia_task",
+ "build_context_aware_agent",
+ "build_task",
"build_mcp_config",
"episode_memory_summary_rule",
"episode_memory_summary_schema",
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/encode_run.py b/train/examples/train_gaia_with_aworld_verl/rollout/encode_run.py
new file mode 100644
index 000000000..ec91d4c41
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/rollout/encode_run.py
@@ -0,0 +1,189 @@
+# coding: utf-8
+# Copyright (c) 2025 inclusionAI.
+import json
+import asyncio
+import sys
+import os
+from typing import List, Any, Dict
+
+from aworld.models.openai_tokenizer import openai_tokenizer
+
+os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+
+# Add the project root to sys.path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../'))
+
+from train.adapter.verl.aworld_agent_loop import AworldAgentLoop
+from aworld.agents.llm_agent import Agent
+from aworld.config import AgentConfig
+from transformers import AutoTokenizer
+
+
+class MockAworldAgentLoop(AworldAgentLoop):
+ """Mock AworldAgentLoop for testing encode functionality"""
+
+ def __init__(self, tokenizer, agent, config):
+ self.tokenizer = tokenizer
+ self.agent = agent
+ self.config = config
+
+ async def build_agents(self):
+ return self.agent
+
+
+async def load_trajectory_from_json(file_path: str) -> List[Any]:
+ """读取 traj JSON 文件并转换为 trajectory 格式(消息列表)"""
+ with open(file_path, 'r', encoding='utf-8') as f:
+ # 尝试读取 JSON,可能是单行格式
+ content = f.read().strip()
+
+ # 尝试解析 JSON
+ try:
+ data = json.loads(content)
+ except json.JSONDecodeError as e:
+ # 如果不是标准 JSON,尝试使用 ast.literal_eval
+ import ast
+ try:
+ data = ast.literal_eval(content)
+ except Exception as e2:
+ raise ValueError(f"无法解析文件: {file_path}, JSON错误: {e}, literal_eval错误: {e2}")
+
+ # If data is a list, check whether it is already a list of messages
+ if isinstance(data, list):
+ # If the list elements are message-shaped (have a 'role' field), return them directly
+ if len(data) > 0 and isinstance(data[0], dict) and 'role' in data[0]:
+ return data
+ # Otherwise it may be a trajectory format; extract the messages
+ messages = []
+ for item in data:
+ if isinstance(item, dict):
+ # exp_data format: extract its messages
+ if 'exp_data' in item and isinstance(item['exp_data'], dict):
+ exp_messages = item['exp_data'].get('messages', [])
+ if exp_messages:
+ messages.extend(exp_messages)
+ # The item itself is a message
+ elif 'role' in item:
+ messages.append(item)
+ return messages if messages else data
+
+ # If data is a dict, it may contain a 'trajectory' or 'messages' field
+ elif isinstance(data, dict):
+ if 'trajectory' in data:
+ traj = data['trajectory']
+ # If trajectory is a list, extract the messages
+ if isinstance(traj, list):
+ messages = []
+ for item in traj:
+ if isinstance(item, dict) and 'exp_data' in item:
+ exp_messages = item['exp_data'].get('messages', [])
+ if exp_messages:
+ messages.extend(exp_messages)
+ return messages if messages else traj
+ return traj
+ elif 'messages' in data:
+ return data['messages']
+ else:
+ # Otherwise, take all values of the dict as a list
+ values = list(data.values())
+ # If any value is a list, try to extract messages from it
+ messages = []
+ for value in values:
+ if isinstance(value, list):
+ for item in value:
+ if isinstance(item, dict) and 'role' in item:
+ messages.append(item)
+ # Return the extracted messages if any; otherwise return all values
+ return messages if messages else values
+ else:
+ return [data]
+
+
+async def test_encode():
+ """测试 encode 功能"""
+ path = "/Users/hgc/hgc_repo/aworldcore/pipelines/logs/trajectory/1093217cdb524b148c5275027c615e4a/traj_1093217cdb524b148c5275027c615e4a.json"
+
+ print(f"正在读取文件: {path}")
+
+ # 读取 trajectory
+ trajectory = await load_trajectory_from_json(path)
+ print(f"成功读取 trajectory,长度: {len(trajectory)}")
+
+ # 打印前几个元素的结构
+ if trajectory:
+ print(f"\n第一个元素类型: {type(trajectory[0])}")
+ if isinstance(trajectory[0], dict):
+ print(f"第一个元素的键: {list(trajectory[0].keys())[:10]}")
+ # 如果是消息格式,打印 role
+ if 'role' in trajectory[0]:
+ print(f"第一个消息的 role: {trajectory[0].get('role')}")
+
+ # Create a mock tokenizer (adjust to the actual setup as needed)
+ # A simple tokenizer is used here; in real use it should be loaded according to the configuration
+ model_name = os.getenv("LLM_MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct")
+ print(f"\nLoading tokenizer: {model_name}")
+ try:
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
+ except Exception as e:
+ print(f"Warning: failed to load tokenizer {model_name}: {e}")
+ print("Falling back to the default tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained("gpt2", trust_remote_code=True)
+
+ # Create a mock agent
+ agent = Agent(
+ conf=AgentConfig(
+ llm_model_name=model_name,
+ llm_base_url=os.getenv("LLM_BASE_URL", ""),
+ llm_api_key=os.getenv("LLM_API_KEY", "")
+ ),
+ name="test_agent",
+ system_prompt="You are a helpful assistant."
+ )
+
+ # Create a mock config
+ class MockConfig:
+ class ActorRolloutRef:
+ class Rollout:
+ response_length = 128000
+ rollout = Rollout()
+ actor_rollout_ref = ActorRolloutRef()
+
+ config = MockConfig()
+
+ # Create a mock loop
+ loop = MockAworldAgentLoop(tokenizer=tokenizer, agent=agent, config=config)
+
+ # Call convert_memory_trajectory_agent_output
+ print("\nCalling convert_memory_trajectory_agent_output...")
+ try:
+ # with open("model_config/qwen_chat_template.jinja", "r") as f:
+ # chat_template = f.read()
+ # chat_template = "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}"
+
+ output = await loop.convert_memory_trajectory_agent_output(trajectory=trajectory, chat_template=None)
+ print(f"\n✅ 成功执行 encode!")
+ print(f"prompt_ids 长度: {len(output.prompt_ids)}")
+ print(f"response_ids 长度: {len(output.response_ids)}")
+ print(f"response_mask 长度: {len(output.response_mask)}")
+
+ # 打印mask前的response_ids(解码后的内容)
+ mask_before_decoded = tokenizer.decode(output.response_ids, skip_special_tokens=True)
+ print(f"\nmask前的response_ids (解码后): {mask_before_decoded}")
+
+ # 计算mask后的response_ids(只保留response_mask中对应为1的位置)
+ masked_response_ids = [output.response_ids[i] for i in range(len(output.response_ids))
+ if i < len(output.response_mask) and output.response_mask[i] == 1]
+ mask_after_decoded = tokenizer.decode(masked_response_ids, skip_special_tokens=True)
+ print(f"mask后的response_ids (解码后): {mask_after_decoded}")
+ print(f"mask后的response_ids长度: {len(masked_response_ids)}")
+
+ print(f"num_turns: {output.num_turns}")
+ print(f"metrics: {output.metrics}")
+ except Exception as e:
+ print(f"\n❌ 执行失败: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ asyncio.run(test_encode())
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/parallel.py b/train/examples/train_gaia_with_aworld_verl/rollout/parallel.py
deleted file mode 100644
index bc9a8f91a..000000000
--- a/train/examples/train_gaia_with_aworld_verl/rollout/parallel.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import logging
-import os
-import traceback
-from datetime import datetime
-
-from aworld.core.task import TaskResponse
-from aworld.evaluations.base import EvalTarget, EvalDataCase
-from aworld.runner import Runners
-from aworld.runners.state_manager import RuntimeStateManager
-from train.examples.train_gaia_with_aworld_verl.mcp_tools import build_mcp_config
-from train.examples.train_gaia_with_aworld_verl.rollout import build_gaia_agent, build_gaia_task
-
-logging.basicConfig(level=logging.INFO, force=True, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
-log_path = os.path.join("logs", "eval_digest.log")
-
-# Use RotatingFileHandler for size-based rotation (100MB per file, keep 10 files)
-from logging.handlers import RotatingFileHandler
-
-file_handler = RotatingFileHandler(
- log_path,
- maxBytes=30 * 1024 * 1024, # 100MB per file
- backupCount=10, # Keep 10 backup files
- encoding='utf-8'
-)
-eval_digest_logger = logging.getLogger("eval_digest")
-eval_digest_logger.setLevel(level=logging.INFO)
-
-eval_digest_logger.addHandler(file_handler)
-
-
-class ParallelGaiaEvalTarget(EvalTarget[dict]):
-
- def __init__(
- self
- ):
- super().__init__()
-
- async def build_gaia_task(self, user_input: str, session_id, task_id):
- if 'screen_shot' in os.getenv("ENV_PLUGINS", ""):
- from ..env.hooks import PostToolCallRolloutHook
-
- agent = build_gaia_agent(llm_model_name=os.getenv("LLM_MODEL_NAME"),
- llm_base_url=os.getenv("LLM_BASE_URL"),
- llm_api_key=os.getenv("LLM_API_KEY"),
- mcp_config=await build_mcp_config(user_input=user_input, session_id=session_id, task_id=task_id))
- return await build_gaia_task(user_input=user_input, target=agent, timeout=1200,
- session_id=session_id, task_id=task_id)
-
-
- async def predict(self, index: int, o_input: EvalDataCase[dict]) -> dict:
- batch_id = o_input.run_id
- input = o_input.case_data
- session_id = f"{batch_id}_session#{input['id']}"
- task_id = f"{batch_id}_task#{input['id']}"
- task = await self.build_gaia_task(user_input=input['prompt'], session_id=session_id, task_id=task_id)
- task_id = task.id
-
- try:
- result = await Runners.run_task(task=task)
- os.makedirs(f"logs/trajectory/{batch_id}", exist_ok=True)
- with open(f"logs/trajectory/{batch_id}/traj_{index+1}.json", "a") as f:
- f.write(str(result[task_id].trajectory))
- os.makedirs(f"logs/results/{batch_id}", exist_ok=True)
- cur_time = datetime.now().strftime('%Y%m%d%H%M%S')
- with open(f"logs/results/{batch_id}/{task_id}_{cur_time}_{o_input.eval_case_id}.txt", "w") as f:
- f.write(result[task_id].answer)
-
- # 任务结束后,查询state_manager获取所有节点并绘制火焰图
- try:
- state_manager = RuntimeStateManager.instance()
- if state_manager:
- nodes = state_manager.query_by_task(task_id)
- if nodes:
- os.makedirs(f"logs/flame_graphs/{batch_id}", exist_ok=True)
- flame_graph_path = f"logs/flame_graphs/{batch_id}/flame_{task_id}_{cur_time}.html"
- from train.examples.train_gaia_with_aworld_verl.log_processor.analyze_state_manager import \
- plot_flame_graph
- plot_flame_graph(nodes, task_id, flame_graph_path)
- except Exception as flame_err:
- logging.warning(f"绘制火焰图失败: {flame_err}, trace: {traceback.format_exc()}")
-
- if isinstance(result, TaskResponse):
- return {"answer": result[task_id].answer, "trajectory": result[task_id].trajectory}
- if isinstance(result, dict):
- task_result = result[task_id]
- eval_digest_logger.info(
- f"eval_task_digest|{batch_id}|{task_id}|{task_result.time_cost:0.1f}|{task_result.usage}")
- return {"answer": task_result.answer, "trajectory": task_result.trajectory}
- else:
- return {"answer": result}
- except Exception as err:
- print(f"err is {err}, trace is {traceback.format_exc()}")
- return {"answer": str(err)}
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/rollout_run.py b/train/examples/train_gaia_with_aworld_verl/rollout/rollout_run.py
index c0e02c48b..2b5024ae2 100644
--- a/train/examples/train_gaia_with_aworld_verl/rollout/rollout_run.py
+++ b/train/examples/train_gaia_with_aworld_verl/rollout/rollout_run.py
@@ -1,39 +1,96 @@
import asyncio
import logging
import os
+import traceback
from datetime import datetime
from dotenv import load_dotenv
load_dotenv('.env')
-from aworld.logs.util import logger
-
-
-from train.examples.train_gaia_with_aworld_verl.rollout.parallel import ParallelGaiaEvalTarget
+from train.examples.train_gaia_with_aworld_verl.mcp_tools.ip_pool import release_proxy_by_task_id
+from aworld.core.task import TaskResponse
+from aworld.evaluations.base import EvalTarget, EvalDataCase
+from aworld.runner import Runners
+from aworld.runners.state_manager import RuntimeStateManager
+from train.examples.train_gaia_with_aworld_verl.rollout import *
+from aworld.logs.util import logger
from aworld.config import EvaluationConfig, DataLoaderConfig
from aworld.evaluations.base import EvalResult, EvalTask
from aworld.runners.evaluate_runner import EvaluateRunner
-logging.basicConfig(level=logging.INFO, force=True, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-
-log_path = os.path.join("logs", "eval_digest.log")
-
-# Use RotatingFileHandler for size-based rotation (100MB per file, keep 10 files)
-from logging.handlers import RotatingFileHandler
-
-file_handler = RotatingFileHandler(
- log_path,
- maxBytes=30 * 1024 * 1024, # 100MB per file
- backupCount=10, # Keep 10 backup files
- encoding='utf-8'
-)
-eval_digest_logger = logging.getLogger("eval_digest")
-eval_digest_logger.setLevel(level=logging.INFO)
+# Import scorer to register it with the global scorer registry
+from train.examples.train_gaia_with_aworld_verl.rollout.scorer import FlightJudgeLLMScorer  # noqa: F401
+
+class ParallelGaiaEvalTarget(EvalTarget[dict]):
+
+ def __init__(
+ self
+ ):
+ super().__init__()
+
+ async def build_gaia_task(self, user_input: str, session_id, task_id):
+ if 'screen_shot' in os.getenv("ENV_PLUGINS", ""):
+ from train.examples.train_gaia_with_aworld_verl.mcp_tools.hooks import PostToolCallRolloutHook
+
+ agent = build_context_aware_agent(llm_model_name=os.getenv("LLM_MODEL_NAME"),
+ llm_base_url=os.getenv("LLM_BASE_URL"),
+ llm_api_key=os.getenv("LLM_API_KEY"),
+ mcp_config=await build_mcp_config())
+ return await build_task(user_input=user_input, target=agent, timeout=1200,
+ session_id=session_id, task_id=task_id)
+
+
+ async def predict(self, index: int, o_input: EvalDataCase[dict]) -> dict:
+ batch_id = o_input.run_id
+ input = o_input.case_data
+ session_id = f"{batch_id}_session#{input['id']}"
+ task_id = f"{batch_id}_task#{input['id']}"
+ task = await self.build_gaia_task(user_input=input['prompt'], session_id=session_id, task_id=task_id)
+ task_id = task.id
+
+ try:
+ result = await Runners.run_task(task=task)
+ os.makedirs(f"logs/trajectory/{batch_id}", exist_ok=True)
+ with open(f"logs/trajectory/{batch_id}/traj_{index+1}.json", "a") as f:
+ f.write(str(result[task_id].trajectory))
+ os.makedirs(f"logs/results/{batch_id}", exist_ok=True)
+ cur_time = datetime.now().strftime('%Y%m%d%H%M%S')
+ with open(f"logs/results/{batch_id}/{task_id}_{cur_time}_{o_input.eval_case_id}.txt", "w") as f:
+ f.write(result[task_id].answer)
+
+ # After the task finishes, query the state_manager for all nodes and draw a flame graph
+ try:
+ state_manager = RuntimeStateManager.instance()
+ if state_manager:
+ nodes = state_manager.query_by_task(task_id)
+ if nodes:
+ os.makedirs(f"logs/flame_graphs/{batch_id}", exist_ok=True)
+ flame_graph_path = f"logs/flame_graphs/{batch_id}/flame_{task_id}_{cur_time}.html"
+ from train.examples.train_gaia_with_aworld_verl.log_processor.analyze_state_manager import \
+ plot_flame_graph
+ plot_flame_graph(nodes, task_id, flame_graph_path)
+ except Exception as flame_err:
+ logging.warning(f"绘制火焰图失败: {flame_err}, trace: {traceback.format_exc()}")
+
+ if isinstance(result, TaskResponse):
+ return {"answer": result[task_id].answer, "trajectory": result[task_id].trajectory}
+ if isinstance(result, dict):
+ task_result = result[task_id]
+ logger.info(
+ f"eval_task_digest|{batch_id}|{task_id}|{task_result.time_cost:0.1f}|{task_result.usage}")
+ return {"answer": task_result.answer, "trajectory": task_result.trajectory}
+ else:
+ return {"answer": result}
+ except Exception as err:
+ print(f"err is {err}, trace is {traceback.format_exc()}")
+ return {"answer": str(err)}
+ finally:
+ # Release the proxy IP back to the IP pool after the task finishes
+ if os.getenv("IP_POOL_ENABLE", "False") == "True":
+ release_proxy_by_task_id(task_id)
-eval_digest_logger.addHandler(file_handler)
async def batch_run():
logger.info(f"runner_log|pid={os.getpid()}|ppid={os.getppid()}")
@@ -46,6 +103,10 @@ async def batch_run():
eval_target=eval_target,
eval_dataset_query_column="prompt",
eval_criterias=[
+ {
+ "metric_name": "flight_judge",
+ "threshold": 0.5,
+ }
] if os.getenv('ENABLE_SCORE', 'True') == 'True' else [],
eval_dataset_id_or_file_path=os.getenv(
'EVAL_DATASET_PATH',
@@ -55,7 +116,7 @@ async def batch_run():
# eval_dataset_load_config=DataLoaderConfig(sampler=RangeSampler(start_index=50, end_index=100)),
# eval_dataset_load_config=DataLoaderConfig(sampler=FixedSampler(ids = [12,14,16,24,25,26])),
repeat_times=1,
- parallel_num=20,
+ parallel_num=10,
skip_passed_cases=True,
)).run()
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/scorer/__init__.py b/train/examples/train_gaia_with_aworld_verl/rollout/scorer/__init__.py
new file mode 100644
index 000000000..c4113ea44
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/rollout/scorer/__init__.py
@@ -0,0 +1,7 @@
+from .flight_judge import FlightJudgeLLMScorer
+
+# Import custom_agent_loop last (depends on gaia and agent_loop)
+
+__all__ = [
+ "FlightJudgeLLMScorer"
+]
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/scorer/flight_judge.py b/train/examples/train_gaia_with_aworld_verl/rollout/scorer/flight_judge.py
new file mode 100644
index 000000000..a1f07848f
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/rollout/scorer/flight_judge.py
@@ -0,0 +1,154 @@
+import json
+
+from aworld.core.context.amni import TaskInput
+from aworld.evaluations.base import EvalDataCase, EvalCaseDataType, MetricResult
+from typing import Optional
+from aworld.evaluations.scorers.metrics import MetricNames
+from aworld.evaluations.scorers.scorer_registry import scorer_register
+from aworld.evaluations.scorers.llm_as_judge import LLMAsJudgeScorer
+import base64
+import os
+import glob
+
+def encode_image(image_source):
+ # If image_source is a path to an image file, read and base64-encode it; otherwise treat it as raw image bytes
+ if image_source is None:
+ raise ValueError("Image source is None, cannot encode image")
+ if isinstance(image_source, str):
+ with open(image_source, "rb") as image_file:
+ return base64.b64encode(image_file.read()).decode("utf-8")
+ else:
+ return base64.b64encode(image_source).decode("utf-8")
+
+def get_latest_file_os(directory='.'):
+ # Use glob.glob to get all paths, filter out files, then use max to find the latest one
+ files = (p for p in glob.glob(os.path.join(directory, '*')) if os.path.isfile(p))
+ return max(files, key=os.path.getmtime, default=None)
+
+@scorer_register(MetricNames.FLIGHT_JUDGE)
+class FlightJudgeLLMScorer(LLMAsJudgeScorer):
+
+ def build_pic_data(self, input: EvalDataCase[EvalCaseDataType]):
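+ """Build the multimodal judge input: the task prompt plus the latest browser screenshot (base64 PNG), when one exists."""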
+ task_prompt = """[Task Description]
+Your role is to act as an AI Agent Evaluator. Based on the user's query, the agent's execution path, and the final browser screenshot provided, you must determine if the agent's final answer successfully resolves the user's query.
+
+[Evaluation Criteria]
+1. Accuracy:
+The final answer must directly and accurately address the user's question.
+It must fulfill all explicit and implicit requirements mentioned in the query (e.g., location, date, direct flights, layovers, airline preferences, departure/arrival times, etc.).
+
+2. Factual Grounding:
+The final answer must be strictly grounded in the information visible in the final browser screenshot and be logically consistent with the agent's execution path.
+No fabricated or hallucinated information is allowed. Every piece of data in the answer (e.g., prices, times, flight numbers) must be verifiable from the provided evidence.
+
+3. Execution Integrity:
+The agent successfully retrieved the flight information by navigating the process unimpeded by anti-scraping measures, such as CAPTCHAs or login walls.
+
+[Output Format]
+Score:
+If the final answer meets all of the above criteria, the score is 1.
+If any criterion is not met, the score is 0.
+
+Explanation:
+You must provide an explanation for your score.
+For a score of 1, briefly explain how the criteria were met.
+For a score of 0, you must clearly state which criterion was violated and provide a specific example of the failure.
+
+Please output in the following standard JSON format without any additional explanatory text:
+{{"score":0/1, "explanation":"explain why the final answer is correct or incorrect."}}
+
+Here is the task: {task}
+"""
+ task_prompt = """[Task Description]
+Based on the answer, execution flow, and final browser screenshot, determine whether the flight query execution process encountered connection issues or anti-scraping mechanisms, including web pages that cannot be opened, user login verification, slider verification, etc.
+Note: Only issues that affect the flight query process, making it impossible to obtain final flight information or preventing flight information from loading, should be considered. If pop-up prompts appear but do not affect information retrieval, they should not be counted.
+Only when no anti-scraping mechanisms are encountered at every step of the execution process can it be concluded that the above problems were not encountered.
+
+[Output Format]
+score: score of 0 means the above problems were not encountered, score of 1 means the above problems were encountered.
+explanation: If the above problems were encountered, the specific problem encountered must be explained; if the above problems were not encountered, leave it empty.
+Output in JSON format.
+Examples:
+{{"score":1, "explanation":"User login verification"}}
+{{"score":0, "explanation":""}}
+
+[Start Task]
+{task}
+"""
+
+ base_path = os.getenv("SCREEN_SHOT_PATH", "./logs/screen_shot")
+ screenshot_dir = os.path.join(base_path, input.run_id + "_task#" + input.case_data['id'])
+ latest_screenshot = get_latest_file_os(screenshot_dir)
+ if latest_screenshot is None:
+ return [
+ {
+ "type": "text",
+ "text": task_prompt
+ }
+ ]
+
+ image_base64 = encode_image(latest_screenshot)
+
+ return [
+ {
+ "type": "text",
+ "text": task_prompt
+ },
+ {
+ "type": "image_url",
+ "image_url": {
+ "url": "data:image/png;base64," + image_base64
+ }
+ }
+ ]
+
+ def build_judge_prompt(self, index: int, input: EvalDataCase[EvalCaseDataType], output: dict) -> str:
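+ # The judge prompt text is supplied through the multimodal payload built in build_judge_data, so no separate prompt is returned here.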
+ return ""
+
+ def build_judge_data(self, index: int, input: EvalDataCase[EvalCaseDataType], output: dict) -> [str, TaskInput]:
+ question_column = self.eval_config.eval_dataset_query_column or 'question'
+ response_column = self.eval_config.eval_output_answer_column or 'answer'
+ if not output or 'trajectory' not in output:
+ return None
+ trajectory_list = [msg for key, msg in sorted(output.get('trajectory', {}).items())]
+
+ last_summary_idx = next(
+ (i for i in range(len(trajectory_list) - 1, -1, -1) if trajectory_list[i].get('memory_type') == 'summary'), -1
+ )
+
+ if last_summary_idx != -1:
+ messages_to_process = trajectory_list[:2] + trajectory_list[last_summary_idx:]
+ else:
+ messages_to_process = trajectory_list
+
+ new_trajectory = [
+ {"role": message["role"], "content": message["content"]}
+ for message in messages_to_process
+ ]
+ new_trajectory_str = json.dumps(new_trajectory, ensure_ascii=False)
+
+ # judge_data = f"""
+ # : {input.case_data.get(question_column, '')}
+ # [Trajectory]: {new_trajectory_str}
+ # [Final Answer]: {output.get(response_column, '')}
+ # """
+ judge_data = f"""
+ : {input.case_data.get(question_column, '')}
+ [Execution Flow]: {new_trajectory_str}
+ [Answer]: {output.get(response_column, '')}
+ """
+ pic_data = self.build_pic_data(input)
+ pic_data[0]['text'] = pic_data[0]['text'].format(task=judge_data)
+ return pic_data
+
+ def convert_judge_response_to_score(self, judge_response: str) -> Optional[dict[str, MetricResult]]:
+ json_output = self.fetch_json_from_result(judge_response)
+ if json_output:
+ return {
+ MetricNames.FLIGHT_JUDGE: MetricResult(
+ value=json_output.get('score', 0),
+ explanation=json_output.get('explanation', '')
+ )
+ }
+ return None
+
diff --git a/train/examples/train_gaia_with_aworld_verl/rollout/verl_contextaware_agent_loop.py b/train/examples/train_gaia_with_aworld_verl/rollout/verl_contextaware_agent_loop.py
new file mode 100644
index 000000000..e48fa7f6c
--- /dev/null
+++ b/train/examples/train_gaia_with_aworld_verl/rollout/verl_contextaware_agent_loop.py
@@ -0,0 +1,28 @@
+import os
+from typing import Union, Any
+
+from aworld.agents.llm_agent import Agent
+from aworld.config import TaskConfig
+from aworld.core.agent.swarm import Swarm
+from train.adapter.verl.aworld_agent_loop import AworldAgentLoop
+from train.examples.train_gaia_with_aworld_verl.mcp_tools import build_mcp_config
+from train.adapter.verl.utils import build_context_config, \
+ build_context_aware_task_config, build_context_aware_agent
+from train.adapter.verl.verl_provider import *
+
+class VerlAgentLoop(AworldAgentLoop):
+ async def build_context_config(self):
+ return build_context_config()
+
+ async def build_task_config(self) -> TaskConfig:
+ return build_context_aware_task_config()
+
+ async def build_agents(self) -> Union[Agent, Swarm]:
+ return build_context_aware_agent(llm_model_name=await self.get_llm_server_model_name(),
+ llm_base_url=await self.get_llm_server_address(),
+ # TODO use template env variables
+ llm_api_key=os.environ['LLM_API_KEY'],
+ llm_provider="verl",
+ mcp_config=await build_mcp_config(),
+ server_manager=self.server_manager,
+ tokenizer=self.tokenizer)
diff --git a/train/trainer/agent_trainer.py b/train/trainer/agent_trainer.py
index ae9a6180a..aac2f6330 100644
--- a/train/trainer/agent_trainer.py
+++ b/train/trainer/agent_trainer.py
@@ -4,8 +4,11 @@
from typing import Callable, Union, Type, Dict
from datasets import Dataset
+
from aworld.agents.llm_agent import Agent
+from aworld.config import TaskConfig
from aworld.core.common import Config
+from aworld.core.context.amni import AmniContextConfig
from aworld.logs.util import logger
from train.adapter.verl.verl_trainer import VerlTrainer
from train.trainer.trainer_processor import TrainerProcessor
@@ -25,6 +28,8 @@ def __init__(self,
reward_func: Union[str, Callable[..., float]] = None,
train_dataset: Union[str, Dataset] = None,
test_dataset: Union[str, Dataset] = None,
+ context_config: AmniContextConfig = None,
+ task_config: TaskConfig = None,
run_path: str = None,
train_engine_name: str = 'verl') -> None:
"""AgentTrainer initialization, 4 modules are required (agent, dataset, reward, config).
@@ -67,13 +72,13 @@ def __init__(self,
raise ValueError(f"{train_engine_name} train engine is not a TrainerProcessor")
# process prerequisite modules
- train_engine.check_agent(agent=agent)
+ agent_config = train_engine.check_agent(agent=agent, context_config=context_config, task_config=task_config)
train_engine.check_dataset(dataset=train_dataset, test_dataset=test_dataset)
train_engine.check_reward(reward_func=reward_func)
real_config = train_engine.check_config(config=config)
train_engine.mark_initialized()
- logger.info(f"Train config: {real_config}")
+ logger.info(f"Agent Config: {agent_config} \n Train config: {real_config}")
self.train_processor = train_engine
@staticmethod