Skip to content

Commit a81c8b1

Browse files
committed
feat: implement tier-aware max_turns for AutoAgents ReAct agent
Implement TierAwareReActAgent wrapper to override max_turns configuration based on context tier without forking AutoAgents framework. Changes: - Add TierAwareReActAgent<T> wrapper that delegates to ReActAgent<T> - Override config() method to return tier-specific max_turns - Implement AgentDeriveT, AgentHooks, and AgentExecutor traits - Update CodeGraphAgentBuilder to use TierAwareReActAgent wrapper - Add Clone derive to CodeGraphReActAgent for Arc-based sharing - Add unit tests for tier-aware max_turns (Small: 5, Medium: 10, Large: 15, Massive: 20) - Remove obsolete NOTE about hardcoded max_turns=10 limitation Max turns by tier: - Small (<50K tokens): 5 turns - Medium (50K-150K): 10 turns - Large (150K-500K): 15 turns - Massive (>500K): 20 turns This enables the AutoAgents ReAct agent to properly respect tier-aware prompting and iteration limits, improving both efficiency and effectiveness across different LLM context window sizes.
1 parent 7a11fc1 commit a81c8b1

File tree

1 file changed

+148
-13
lines changed

1 file changed

+148
-13
lines changed

crates/codegraph-mcp/src/autoagents/agent_builder.rs

Lines changed: 148 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,97 @@ impl ChatResponse for CodeGraphChatResponse {
305305
}
306306
}
307307

308+
// ============================================================================
309+
// Tier-Aware ReAct Agent Wrapper
310+
// ============================================================================
311+
312+
/// Wrapper around ReActAgent that overrides max_turns configuration
313+
/// This allows tier-aware max_turns without forking AutoAgents
314+
#[derive(Debug)]
315+
pub struct TierAwareReActAgent<T: AgentDeriveT> {
316+
inner: ReActAgent<T>,
317+
inner_derive: Arc<T>,
318+
max_turns: usize,
319+
}
320+
321+
impl<T: AgentDeriveT + AgentHooks + Clone> TierAwareReActAgent<T> {
322+
pub fn new(agent: T, max_turns: usize) -> Self {
323+
let agent_arc = Arc::new(agent);
324+
Self {
325+
inner: ReActAgent::new((*agent_arc).clone()),
326+
inner_derive: agent_arc,
327+
max_turns,
328+
}
329+
}
330+
}
331+
332+
impl<T: AgentDeriveT + AgentHooks + Clone> AgentDeriveT for TierAwareReActAgent<T> {
333+
type Output = T::Output;
334+
335+
fn description(&self) -> &'static str {
336+
self.inner_derive.description()
337+
}
338+
339+
fn name(&self) -> &'static str {
340+
self.inner_derive.name()
341+
}
342+
343+
fn output_schema(&self) -> Option<serde_json::Value> {
344+
self.inner_derive.output_schema()
345+
}
346+
347+
fn tools(&self) -> Vec<Box<dyn ToolT>> {
348+
self.inner_derive.tools()
349+
}
350+
}
351+
352+
impl<T: AgentDeriveT + AgentHooks + Clone> AgentHooks for TierAwareReActAgent<T> {}
353+
354+
impl<T: AgentDeriveT + AgentHooks + Clone> Clone for TierAwareReActAgent<T> {
355+
fn clone(&self) -> Self {
356+
Self {
357+
inner: ReActAgent::new((*self.inner_derive).clone()),
358+
inner_derive: Arc::clone(&self.inner_derive),
359+
max_turns: self.max_turns,
360+
}
361+
}
362+
}
363+
364+
#[async_trait]
365+
impl<T: AgentDeriveT + AgentHooks + Clone> AgentExecutor for TierAwareReActAgent<T> {
366+
type Output = <ReActAgent<T> as AgentExecutor>::Output;
367+
type Error = <ReActAgent<T> as AgentExecutor>::Error;
368+
369+
fn config(&self) -> ExecutorConfig {
370+
ExecutorConfig {
371+
max_turns: self.max_turns,
372+
}
373+
}
374+
375+
async fn execute(
376+
&self,
377+
task: &autoagents::core::agent::task::Task,
378+
context: Arc<Context>,
379+
) -> Result<Self::Output, Self::Error> {
380+
self.inner.execute(task, context).await
381+
}
382+
383+
async fn execute_stream(
384+
&self,
385+
task: &autoagents::core::agent::task::Task,
386+
context: Arc<Context>,
387+
) -> Result<
388+
std::pin::Pin<
389+
Box<
390+
dyn futures::Stream<Item = Result<Self::Output, Self::Error>> + Send,
391+
>,
392+
>,
393+
Self::Error,
394+
> {
395+
self.inner.execute_stream(task, context).await
396+
}
397+
}
398+
308399
// ============================================================================
309400
// Agent Builder
310401
// ============================================================================
@@ -319,13 +410,12 @@ use crate::autoagents::codegraph_agent::CodeGraphAgentOutput;
319410
use autoagents::core::agent::memory::SlidingWindowMemory;
320411
use autoagents::core::agent::prebuilt::executor::ReActAgent;
321412
use autoagents::core::agent::AgentBuilder;
322-
use autoagents::core::agent::{AgentDeriveT, AgentHooks, AgentOutputT, DirectAgentHandle, ExecutorConfig};
413+
use autoagents::core::agent::{AgentDeriveT, AgentExecutor, AgentHooks, AgentOutputT, Context, DirectAgentHandle, ExecutorConfig};
323414
use autoagents::core::error::Error as AutoAgentsError;
324415
use autoagents::core::tool::{shared_tools_to_boxes, ToolT};
325-
use autoagents_derive::AgentHooks;
326416

327417
/// Agent implementation for CodeGraph with manual tool registration
328-
#[derive(Debug)]
418+
#[derive(Debug, Clone)]
329419
pub struct CodeGraphReActAgent {
330420
tools: Vec<Arc<dyn ToolT>>,
331421
system_prompt: String,
@@ -371,11 +461,6 @@ impl AgentDeriveT for CodeGraphReActAgent {
371461

372462
impl AgentHooks for CodeGraphReActAgent {}
373463

374-
// NOTE: AutoAgents ReActAgent has max_turns hardcoded to 10 in version 8248b4e
375-
// We calculate tier-aware max_iterations (5-20) but can't override ReActAgent's config()
376-
// This is a known limitation - agents will stop at 10 turns regardless of tier
377-
// TODO: Update AutoAgents version or fork to allow configurable max_turns
378-
379464
/// Builder for CodeGraph AutoAgents workflows
380465
pub struct CodeGraphAgentBuilder {
381466
llm_adapter: Arc<CodeGraphChatAdapter>,
@@ -435,13 +520,13 @@ impl CodeGraphAgentBuilder {
435520
max_iterations,
436521
};
437522

438-
// Build ReAct agent with our CodeGraph agent
439-
let react_agent = ReActAgent::new(codegraph_agent);
523+
// Wrap in TierAwareReActAgent to override max_turns configuration
524+
let tier_aware_agent = TierAwareReActAgent::new(codegraph_agent, max_iterations);
440525

441526
// Build full agent with configuration
442527
// System prompt injected via AgentDeriveT::description() using Box::leak pattern
443528
use autoagents::core::agent::DirectAgent;
444-
let agent = AgentBuilder::<_, DirectAgent>::new(react_agent)
529+
let agent = AgentBuilder::<_, DirectAgent>::new(tier_aware_agent)
445530
.llm(self.llm_adapter)
446531
.memory(memory)
447532
.build()
@@ -457,7 +542,7 @@ impl CodeGraphAgentBuilder {
457542

458543
/// Handle for executing CodeGraph agent
459544
pub struct AgentHandle {
460-
pub agent: DirectAgentHandle<ReActAgent<CodeGraphReActAgent>>,
545+
pub agent: DirectAgentHandle<TierAwareReActAgent<CodeGraphReActAgent>>,
461546
pub tier: ContextTier,
462547
pub analysis_type: AnalysisType,
463548
}
@@ -543,11 +628,61 @@ mod tests {
543628
#[tokio::test]
544629
async fn test_chat_adapter_integration() {
545630
let mock_llm = Arc::new(MockCodeGraphLLM);
546-
let adapter = CodeGraphChatAdapter::new(mock_llm);
631+
let adapter = CodeGraphChatAdapter::new(mock_llm, ContextTier::Medium);
547632

548633
let messages = vec![ChatMessage::user().content("Hello").build()];
549634
let response = adapter.chat(&messages, None, None).await.unwrap();
550635

551636
assert_eq!(response.text(), Some("Echo: Hello".to_string()));
552637
}
638+
639+
#[test]
640+
fn test_tier_aware_max_turns_small() {
641+
// Test that Small tier gets 5 max_turns
642+
let agent = create_mock_codegraph_agent(ContextTier::Small);
643+
let wrapper = TierAwareReActAgent::new(agent, 5);
644+
645+
let config = wrapper.config();
646+
assert_eq!(config.max_turns, 5, "Small tier should have 5 max_turns");
647+
}
648+
649+
#[test]
650+
fn test_tier_aware_max_turns_medium() {
651+
// Test that Medium tier gets 10 max_turns
652+
let agent = create_mock_codegraph_agent(ContextTier::Medium);
653+
let wrapper = TierAwareReActAgent::new(agent, 10);
654+
655+
let config = wrapper.config();
656+
assert_eq!(config.max_turns, 10, "Medium tier should have 10 max_turns");
657+
}
658+
659+
#[test]
660+
fn test_tier_aware_max_turns_large() {
661+
// Test that Large tier gets 15 max_turns
662+
let agent = create_mock_codegraph_agent(ContextTier::Large);
663+
let wrapper = TierAwareReActAgent::new(agent, 15);
664+
665+
let config = wrapper.config();
666+
assert_eq!(config.max_turns, 15, "Large tier should have 15 max_turns");
667+
}
668+
669+
#[test]
670+
fn test_tier_aware_max_turns_massive() {
671+
// Test that Massive tier gets 20 max_turns
672+
let agent = create_mock_codegraph_agent(ContextTier::Massive);
673+
let wrapper = TierAwareReActAgent::new(agent, 20);
674+
675+
let config = wrapper.config();
676+
assert_eq!(config.max_turns, 20, "Massive tier should have 20 max_turns");
677+
}
678+
679+
// Helper function to create mock agent for testing
680+
fn create_mock_codegraph_agent(_tier: ContextTier) -> CodeGraphReActAgent {
681+
CodeGraphReActAgent {
682+
tools: vec![],
683+
system_prompt: "Test prompt".to_string(),
684+
analysis_type: AnalysisType::CodeSearch,
685+
max_iterations: 10,
686+
}
687+
}
553688
}

0 commit comments

Comments
 (0)