diff --git a/contributing/samples/rlm/.dockerignore b/contributing/samples/rlm/.dockerignore new file mode 100644 index 0000000000..4a4431d420 --- /dev/null +++ b/contributing/samples/rlm/.dockerignore @@ -0,0 +1,67 @@ +# Virtual environment +.venv/ +venv/ +env/ + +# Git +.git/ +.gitignore + +# Python cache +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# Build artifacts +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# Testing +.pytest_cache/ +.coverage +htmlcov/ +.tox/ +.nox/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# Local files +*.db +*.sqlite +logs/ +*.log +*.jsonl + +# Documentation +docs/ +*.md +!README.md +!adk_rlm/**/*.md + +# Tests +tests/ + +# Examples +examples/ + +# Deployment scripts (not needed in container) +deployment/ diff --git a/contributing/samples/rlm/.gitignore b/contributing/samples/rlm/.gitignore new file mode 100644 index 0000000000..40ae92369d --- /dev/null +++ b/contributing/samples/rlm/.gitignore @@ -0,0 +1,28 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# Virtual environments +.venv + +# Environment files +.env +.env.* + +# ADK generated files +.adk/ +*/.adk/ + +original_rlm/ +plans/ +CLAUDE.md +logs/ +.pytest_cache/ +sessions.db +corpora/ +*.db +.ruff_cache/ \ No newline at end of file diff --git a/contributing/samples/rlm/.python-version b/contributing/samples/rlm/.python-version new file mode 100644 index 0000000000..e4fba21835 --- /dev/null +++ b/contributing/samples/rlm/.python-version @@ -0,0 +1 @@ +3.12 diff --git a/contributing/samples/rlm/Dockerfile b/contributing/samples/rlm/Dockerfile new file mode 100644 index 0000000000..18e0a75de6 --- /dev/null +++ b/contributing/samples/rlm/Dockerfile @@ -0,0 +1,28 @@ +FROM python:3.12-slim + +WORKDIR /app + +# Install uv for faster package management (rarely changes) +RUN pip install uv + +# Copy only dependency specification first for better caching +COPY pyproject.toml README.md ./ + +# Create minimal package structure for dependency resolution +RUN mkdir -p adk_rlm && echo '__version__ = "0.1.0"' > adk_rlm/__init__.py + +# Install dependencies (this layer is cached unless pyproject.toml changes) +RUN uv pip install --system -e ".[all]" --index-url https://pypi.org/simple/ + +# Now copy the actual source code (changes frequently) +COPY adk_rlm/ adk_rlm/ + +# Create logs directory +RUN mkdir -p /app/logs + +# Expose port (Cloud Run uses PORT env var) +ENV PORT=8080 +EXPOSE 8080 + +# Run the web server +CMD ["sh", "-c", "python -m adk_rlm.web --host 0.0.0.0 --port $PORT"] diff --git a/contributing/samples/rlm/README.md b/contributing/samples/rlm/README.md new file mode 100644 index 0000000000..1aca9fb570 --- /dev/null +++ b/contributing/samples/rlm/README.md @@ -0,0 +1,307 @@ +# ADK-RLM + +A Python implementation of Recursive Language Models (RLM) using Google's Agent Development Kit (ADK) and Gemini models. + +RLM enables LLMs to handle near-infinite length contexts by programmatically examining, decomposing, and recursively calling themselves through a REPL environment. 
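+
+For intuition, at each step the root model writes and runs a short snippet in
+that REPL, for example peeking at a slice of `context` and delegating it to a
+sub-LLM call. An illustrative sketch (see "How It Works" below for the full
+iteration loop):
+
+```python
+part = context[:2000]                               # inspect a slice of the input
+note = llm_query(f"Extract the key facts: {part}")  # recursive sub-LLM call
+FINAL_VAR(note)                                     # mark the variable as the final answer
+```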
+ +![alt text](image.png) + +## Features + +- **Recursive LLM Calls**: LLMs can call sub-LLMs to analyze context chunks +- **Sandboxed Python REPL**: Safe code execution with restricted builtins +- **Streaming Events**: Real-time event streaming for UI integration +- **Multi-Turn Persistence**: Maintain state across conversation turns +- **JSONL Logging**: Compatible with the original RLM visualizer +- **Rich Console Output**: Terminal output with Tokyo Night theme +- **Usage Tracking**: Track token usage per model +- **File System Integration**: Extend the concept behind RLM to file system and drives (e.g., Sharepoint, etc.), with a lazy-loading approach. + +## Installation + +```bash +# Navigate to this sample directory +cd contributing/samples/rlm + +# Create virtual environment +uv venv +source .venv/bin/activate + +# Install dependencies +uv pip install -e . + +# Or install with all optional features +uv pip install -e ".[all]" +``` + +## UI Quickstart + +```bash +# Copy environment file +cp .env.example .env + +# Authenticate with Google Cloud (AI Platform API must be enabled) +gcloud auth application-default login +gcloud auth application-default set-quota-project YOUR_PROJECT_ID + +# Run the UI +python -m adk_rlm.web +``` + +## Quick Start + +```bash +# Copy environment file +cp .env.example .env + +# Authenticate with Google Cloud (AI Platform API must be enabled) +gcloud auth application-default login +gcloud auth application-default set-quota-project YOUR_PROJECT_ID +``` + +```python +from adk_rlm import completion + +result = completion( + context="Alice is 30 years old. Bob is 25 years old.", + prompt="Who is older and by how much?", +) + +print(result.response) # Alice is older by 5 years +``` + +## Usage + +### Basic Usage + +```python +from adk_rlm import completion + +# Simple completion with options +result = completion( + context="Your document or data here...", + prompt="What patterns do you see in the data?", + model="gemini-3-flash-preview", + sub_model="gemini-3-flash-preview", + max_iterations=10, + verbose=True, # Show Rich console output +) + +print(result.response) +print(f"Execution time: {result.execution_time:.2f}s") +``` + +### Streaming Events + +For real-time UI updates, use the `RLM` class with `run_streaming()`: + +```python +import asyncio +from adk_rlm import RLM, RLMEventType + +async def main(): + rlm = RLM(model="gemini-3-flash-preview") + + async for event in rlm.run_streaming(context, prompt): + event_type = event.custom_metadata.get("event_type") + + if event_type == RLMEventType.ITERATION_START.value: + print(f"Starting iteration {event.custom_metadata['iteration']}") + + elif event_type == RLMEventType.FINAL_ANSWER.value: + print(f"Answer: {event.custom_metadata['answer']}") + + rlm.close() + +asyncio.run(main()) +``` + +### Multi-Turn Sessions + +For persistent sessions where context accumulates: + +```python +import asyncio +from adk_rlm import RLM, RLMEventType + +async def run_query(rlm, context, prompt): + """Helper to run a query and return the answer.""" + async for event in rlm.run_streaming(context, prompt): + if event.custom_metadata: + if event.custom_metadata.get("event_type") == RLMEventType.FINAL_ANSWER.value: + return event.custom_metadata.get("answer") + return "" + +async def main(): + rlm = RLM( + model="gemini-3-flash-preview", + persistent=True, # Enable multi-turn persistence + ) + + try: + # First turn + result1 = await run_query(rlm, "Alice is 30 years old.", "How old is Alice?") + + # Second turn - has access to previous 
context + result2 = await run_query(rlm, "Bob is 25 years old.", "Who is older?") + print(result2) # Alice + + finally: + rlm.close() + +asyncio.run(main()) +``` + +### File Loading + +```python +from adk_rlm import completion + +# Load files using glob patterns +result = completion( + files=["./docs/**/*.md", "./data/*.csv"], + prompt="Summarize the key findings across all documents.", +) + +print(result.response) +``` + +### Google Cloud Storage + +Load files directly from GCS buckets: + +```python +from adk_rlm.files.sources import GCSFileSource +from adk_rlm.files.loader import FileLoader + +# Initialize GCS source (uses Application Default Credentials) +gcs = GCSFileSource(bucket="my-bucket") + +# Or with service account +gcs = GCSFileSource( + bucket="my-bucket", + credentials_path="/path/to/service-account.json" +) + +loader = FileLoader(sources={"gcs": gcs}) + +# Load files using gs:// URIs with glob patterns +files = loader.create_lazy_files([ + "gs://my-bucket/reports/*.pdf", + "gs://my-bucket/data/**/*.csv" +]) + +# Files load lazily - no download until content accessed +for f in files: + print(f.name) # No I/O + print(f.size_kb) # Metadata fetch only + print(f.content) # Full download + parse +``` + +Install the GCS dependency: +```bash +uv pip install -e ".[gcs]" +``` + +### JSONL Logging + +```python +from adk_rlm import completion + +result = completion( + context="Your data...", + prompt="Analyze this data.", + log_dir="./logs", # Enable JSONL logging +) + +# Logs are saved to ./logs/.jsonl +# Compatible with the RLM visualizer +``` + +## How It Works + +RLM provides the LLM with a Python REPL environment that includes: + +1. **`context`**: The input data/document to analyze +2. **`llm_query(prompt, model=None)`**: Function to make sub-LLM calls +3. **`llm_query_batched(prompts, model=None)`**: Batch sub-LLM calls +4. 
**`FINAL_VAR(var)`**: Mark a variable as the final answer + +The LLM iteratively writes and executes Python code to: +- Break down large contexts into manageable chunks +- Make recursive LLM calls to analyze each chunk +- Aggregate results and produce a final answer + +### Example LLM Code Execution + +```python +# The LLM might generate code like: +chunks = [context[i:i+1000] for i in range(0, len(context), 1000)] +summaries = [] +for chunk in chunks: + summary = llm_query(f"Summarize: {chunk}") + summaries.append(summary) +final_summary = llm_query(f"Combine summaries: {summaries}") +FINAL_VAR(final_summary) +``` + +## Project Structure + +``` +adk_rlm/ + __init__.py # Package exports + main.py # RLM class and completion() function + types.py # Data classes + prompts.py # System/user prompts + usage.py # UsageTracker + agents/ + rlm_agent.py # Core RLMAgent implementation + repl/ + local_repl.py # Sandboxed REPL environment + safe_builtins.py # Restricted Python builtins + callbacks/ + code_execution.py # Code parsing utilities + logging/ + rlm_logger.py # JSONL logger + verbose.py # Rich console output +``` + +## Running Tests + +```bash +# Run unit tests +python -m pytest tests/ --ignore=tests/test_e2e.py --ignore=tests/test_gcs_integration.py + +# Run E2E tests (requires Gemini API access) +RLM_E2E_TESTS=true python -m pytest tests/test_e2e.py -v + +# Run GCS integration tests (requires GCS bucket) +RLM_GCS_TEST_BUCKET="your-test-bucket" \ +RLM_GCS_TEST_FILE="test/sample.txt" \ +python -m pytest tests/test_gcs_integration.py -v +``` + +### Setting Up GCS Test Bucket + +```bash +# Create bucket +BUCKET_NAME="adk-rlm-test-$(date +%s)" +gcloud storage buckets create "gs://${BUCKET_NAME}" --location=us-central1 + +# Create test files +echo "Test content" | gcloud storage cp - "gs://${BUCKET_NAME}/test/sample.txt" +echo "Report Q1" | gcloud storage cp - "gs://${BUCKET_NAME}/test/report_q1.txt" + +# Run tests +RLM_GCS_TEST_BUCKET="${BUCKET_NAME}" python -m pytest tests/test_gcs_integration.py -v + +# Clean up when done +gcloud storage rm -r "gs://${BUCKET_NAME}" +``` + +## Requirements + +- Python 3.10+ +- Google Cloud authentication (application default credentials) +- Access to Gemini models (gemini-3-flash-preview, gemini-3-pro-preview) + diff --git a/contributing/samples/rlm/adk_rlm/__init__.py b/contributing/samples/rlm/adk_rlm/__init__.py new file mode 100644 index 0000000000..3865f30107 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/__init__.py @@ -0,0 +1,122 @@ +""" +ADK-RLM: Recursive Language Models implemented with Google ADK. + +This package provides an implementation of Recursive Language Models (RLM) +using Google's Agent Development Kit (ADK) framework. 
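+
+Typical usage is through the completion() convenience function (a minimal
+sketch mirroring the Quick Start in README.md):
+
+    from adk_rlm import completion
+
+    result = completion(
+        context="Alice is 30 years old. Bob is 25 years old.",
+        prompt="Who is older and by how much?",
+    )
+    print(result.response)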
+ +Features: +- RLM (Recursive Language Model) pattern with Gemini +- ADK-native agent with streaming events +- File handling with lazy loading and progressive disclosure +- Support for local files, PDFs, and text formats +""" + +# Import stdlib logging before any adk_rlm imports to avoid shadowing +# by the local adk_rlm.logging module +import logging as _logging + +# Configure library logging on import +# Logs warnings and above to stderr by default +_logger = _logging.getLogger(__name__) +if not _logger.handlers: + _handler = _logging.StreamHandler() + _handler.setFormatter( + _logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + ) + _logger.addHandler(_handler) + _logger.setLevel(_logging.WARNING) + + +def configure_logging( + level: int = _logging.WARNING, + format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) -> None: + """ + Configure logging for the adk_rlm package. + + This reconfigures the library's logging level and format. + By default, the library logs WARNING and above to stderr. + + Args: + level: Logging level (e.g., logging.DEBUG, logging.INFO, logging.WARNING). + format: Log message format string. + + Example: + import logging + import adk_rlm + + # Enable debug logging for adk_rlm + adk_rlm.configure_logging(level=logging.DEBUG) + + # Or configure manually: + logging.getLogger("adk_rlm").setLevel(logging.DEBUG) + """ + logger = _logging.getLogger(__name__) + logger.setLevel(level) + + # Update existing handler's level and format + for handler in logger.handlers: + handler.setLevel(level) + handler.setFormatter(_logging.Formatter(format)) + + +from adk_rlm.agents.rlm_agent import RLMAgent +from adk_rlm.code_executor import RLMCodeExecutor +from adk_rlm.events import RLMEventData +from adk_rlm.events import RLMEventType +from adk_rlm.files import FileLoader +from adk_rlm.files import FileMetadata +from adk_rlm.files import FileParser +from adk_rlm.files import FileSource +from adk_rlm.files import LazyFile +from adk_rlm.files import LazyFileCollection +from adk_rlm.files import LoadedFile +from adk_rlm.files import LocalFileSource +from adk_rlm.files import ParsedContent +from adk_rlm.files import PDFParser +from adk_rlm.files import TextParser +from adk_rlm.main import completion +from adk_rlm.main import RLM +from adk_rlm.types import CodeBlock +from adk_rlm.types import ModelUsageSummary +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMChatCompletion +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata +from adk_rlm.types import UsageSummary + +__all__ = [ + # Main class and convenience function + "RLM", + "completion", + # Logging configuration + "configure_logging", + # ADK components + "RLMAgent", + "RLMCodeExecutor", + # Event types + "RLMEventType", + "RLMEventData", + # RLM types + "CodeBlock", + "ModelUsageSummary", + "REPLResult", + "RLMChatCompletion", + "RLMIteration", + "RLMMetadata", + "UsageSummary", + # File handling + "FileLoader", + "FileMetadata", + "FileParser", + "FileSource", + "LazyFile", + "LazyFileCollection", + "LoadedFile", + "LocalFileSource", + "ParsedContent", + "PDFParser", + "TextParser", +] + +__version__ = "0.1.0" diff --git a/contributing/samples/rlm/adk_rlm/agent.py b/contributing/samples/rlm/adk_rlm/agent.py new file mode 100644 index 0000000000..2dc1db3e39 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/agent.py @@ -0,0 +1,10 @@ +""" +ADK agent entry point for the built-in web interface. 
+ +This module exposes the root_agent for ADK's web UI and CLI tools. +Run with: adk web adk_rlm +""" + +from adk_rlm.agents.rlm_agent import RLMAgent + +root_agent = RLMAgent() diff --git a/contributing/samples/rlm/adk_rlm/agents/__init__.py b/contributing/samples/rlm/adk_rlm/agents/__init__.py new file mode 100644 index 0000000000..d89e98fc7d --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/agents/__init__.py @@ -0,0 +1,5 @@ +"""Agent implementations for ADK-RLM.""" + +from adk_rlm.agents.rlm_agent import RLMAgent + +__all__ = ["RLMAgent"] diff --git a/contributing/samples/rlm/adk_rlm/agents/rlm_agent.py b/contributing/samples/rlm/adk_rlm/agents/rlm_agent.py new file mode 100644 index 0000000000..16cda12f75 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/agents/rlm_agent.py @@ -0,0 +1,641 @@ +""" +RLM Agent implementation using Google ADK BaseAgent. + +This is the core agent that implements the Recursive Language Model +pattern using ADK's agent abstractions and streaming events. +""" + +import asyncio +import logging +import time +from typing import Any +from typing import AsyncGenerator + +logger = logging.getLogger(__name__) + +from adk_rlm.callbacks.code_execution import find_code_blocks +from adk_rlm.callbacks.code_execution import find_final_answer +from adk_rlm.callbacks.code_execution import format_iteration +from adk_rlm.code_executor import RLMCodeExecutor +from adk_rlm.events import RLMEventData +from adk_rlm.events import RLMEventType +from adk_rlm.llm import AsyncLLMRateLimiter +from adk_rlm.logging.rlm_logger import RLMLogger +from adk_rlm.logging.verbose import VerbosePrinter +from adk_rlm.prompts import build_rlm_system_prompt +from adk_rlm.prompts import build_user_prompt +from adk_rlm.prompts import RLM_SYSTEM_PROMPT +from adk_rlm.types import CodeBlock +from adk_rlm.types import QueryMetadata +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata +from adk_rlm.usage import UsageTracker +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.genai import types +from pydantic import PrivateAttr + +from google import genai + + +class RLMAgent(BaseAgent): + """ + Recursive Language Model agent using Google ADK BaseAgent. + + This agent implements the RLM pattern where an LLM can execute Python + code in a REPL environment, including making recursive LLM calls. + It emits granular streaming events for UI integration. 
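+
+  Example (an illustrative sketch): typical callers go through the
+  higher-level RLM wrapper in adk_rlm.main, which runs this agent via an
+  ADK Runner and supplies the "rlm_context" and "rlm_prompt" session-state
+  keys that _run_async_impl reads.
+
+      agent = RLMAgent(
+          model="gemini-3-pro-preview",
+          max_iterations=10,
+          verbose=True,
+      )
+      # Events yielded by the agent carry RLM metadata (event type,
+      # iteration, code, output, ...) in event.custom_metadata; see
+      # adk_rlm.events.RLMEventType for the event names.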
+ """ + + # Pydantic model fields (public configuration) + model: str = "gemini-3-pro-preview" + sub_model: str | None = None + max_iterations: int = 30 + max_depth: int = 5 + current_depth: int = 0 # Current recursion depth (0 = root level) + custom_system_prompt: str | None = None + persistent: bool = False + + # Private attributes (not part of the model schema) + _code_executor: RLMCodeExecutor | None = PrivateAttr(default=None) + _client: genai.Client | None = PrivateAttr(default=None) + _usage_tracker: UsageTracker = PrivateAttr(default_factory=UsageTracker) + _logger: RLMLogger | None = PrivateAttr(default=None) + _parent_agent: str | None = PrivateAttr(default=None) + _verbose: VerbosePrinter = PrivateAttr( + default_factory=lambda: VerbosePrinter(enabled=False) + ) + _persistent_executor: RLMCodeExecutor | None = PrivateAttr(default=None) + _ancestry: list[dict] = PrivateAttr(default_factory=list) + + def __init__( + self, + name: str = "rlm_agent", + model: str = "gemini-3-pro-preview", + sub_model: str | None = None, + max_iterations: int = 30, + max_depth: int = 5, + current_depth: int = 0, + custom_system_prompt: str | None = None, + logger: RLMLogger | None = None, + parent_agent: str | None = None, + verbose: bool = False, + persistent: bool = False, + ancestry: list[dict] | None = None, + **kwargs, + ): + """ + Initialize the RLM Agent. + + Args: + name: Agent name for identification. + model: The main model to use for RLM reasoning. + sub_model: The model to use for sub-LLM calls (defaults to model). + max_iterations: Maximum number of RLM iterations. + max_depth: Maximum recursion depth for nested llm_query calls. + current_depth: Current recursion depth (0 = root level). + custom_system_prompt: Custom system prompt (uses default if None). + logger: Optional JSONL logger for trajectory logging. + parent_agent: Name of the parent agent (for nested agents). + verbose: Whether to print verbose Rich output. + persistent: Whether to persist REPL state across calls. + ancestry: List of ancestor agent context dicts for event tagging. + **kwargs: Additional arguments for BaseAgent. 
+ """ + super().__init__( + name=name, + model=model, + sub_model=sub_model, + max_iterations=max_iterations, + max_depth=max_depth, + current_depth=current_depth, + custom_system_prompt=custom_system_prompt, + persistent=persistent, + **kwargs, + ) + + # Initialize private attributes + self._client = genai.Client(vertexai=True, location="global") + self._usage_tracker = UsageTracker() + self._logger = logger + self._parent_agent = parent_agent + self._verbose = VerbosePrinter(enabled=verbose) + self._persistent_executor = None + self._ancestry = ancestry.copy() if ancestry else [] + + # Log/print metadata + if self._logger or verbose: + metadata = self._build_metadata() + if self._logger: + self._logger.log_metadata(metadata) + self._verbose.print_metadata(metadata) + + @property + def _effective_sub_model(self) -> str: + """Return the effective sub-model (defaults to main model).""" + return self.sub_model or self.model + + @property + def _system_prompt(self) -> str: + """Return the effective system prompt.""" + return self.custom_system_prompt or RLM_SYSTEM_PROMPT + + def _build_metadata(self) -> RLMMetadata: + """Build metadata about this RLM configuration.""" + return RLMMetadata( + root_model=self.model, + max_depth=self.max_depth, + max_iterations=self.max_iterations, + backend="gemini", + backend_kwargs={"model_name": self.model}, + environment_type="local", + environment_kwargs={}, + other_backends=[self._effective_sub_model] + if self._effective_sub_model != self.model + else None, + ) + + def _prepare_contents( + self, prompt: list[dict[str, Any]] + ) -> tuple[list[types.Content], str | None]: + """Convert message history to Gemini format.""" + system_instruction = None + contents = [] + + for msg in prompt: + role = msg.get("role") + content = msg.get("content", "") + + if role == "system": + system_instruction = content + elif role == "user": + contents.append( + types.Content(role="user", parts=[types.Part(text=content)]) + ) + elif role == "assistant": + contents.append( + types.Content(role="model", parts=[types.Part(text=content)]) + ) + + return contents, system_instruction + + async def _call_llm_async(self, message_history: list[dict[str, Any]]) -> str: + """Call the main LLM asynchronously.""" + contents, system_instruction = self._prepare_contents(message_history) + + # Build config with function calling disabled to prevent MALFORMED_FUNCTION_CALL errors + # when the model tries to use tools that aren't configured + config = types.GenerateContentConfig( + system_instruction=system_instruction, + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig(mode="NONE") + ), + ) + + async with AsyncLLMRateLimiter(): + response = await self._client.aio.models.generate_content( + model=self.model, + contents=contents, + config=config, + ) + + self._usage_tracker.add_from_response(self.model, response.usage_metadata) + + # Handle None/empty responses with detailed logging + if response.text is None or response.text == "": + # Extract debugging info from response + finish_reason = None + safety_ratings = None + block_reason = None + + if response.candidates: + candidate = response.candidates[0] + finish_reason = getattr(candidate, "finish_reason", None) + safety_ratings = getattr(candidate, "safety_ratings", None) + if hasattr(response, "prompt_feedback"): + block_reason = getattr(response.prompt_feedback, "block_reason", None) + + logger.warning( + "LLM returned empty response: model=%s, finish_reason=%s, " + "block_reason=%s, safety_ratings=%s, 
usage=%s", + self.model, + finish_reason, + block_reason, + safety_ratings, + response.usage_metadata, + ) + + # Return informative message instead of empty string + reason_parts = [] + if finish_reason: + reason_parts.append(f"finish_reason={finish_reason}") + if block_reason: + reason_parts.append(f"block_reason={block_reason}") + reason_str = ", ".join(reason_parts) if reason_parts else "unknown reason" + + return f"[LLM returned empty response: {reason_str}]" + + return response.text + + def _create_rlm_event( + self, + ctx: InvocationContext, + event_type: RLMEventType, + **data, + ) -> Event: + """Create an ADK Event with RLM-specific metadata.""" + event_data = RLMEventData(event_type=event_type, **data) + metadata = event_data.to_dict() + + # Add agent identification for proper UI rendering + metadata["agent_name"] = self.name + metadata["agent_depth"] = self.current_depth + metadata["ancestry"] = self._ancestry + + return Event( + invocation_id=ctx.invocation_id, + author=self.name, + custom_metadata=metadata, + ) + + async def _run_async_impl( + self, ctx: InvocationContext + ) -> AsyncGenerator[Event, None]: + """ + Core RLM iteration loop with granular event streaming. + + This is the main entry point called by the ADK Runner. + """ + start_time = time.perf_counter() + + # Emit run start + yield self._create_rlm_event( + ctx, + RLMEventType.RUN_START, + model=self.model, + metadata={ + "sub_model": self._effective_sub_model, + "max_iterations": self.max_iterations, + }, + ) + + try: + # Get context from session state + context_payload = ( + ctx.session.state.get("rlm_context") if ctx.session else None + ) + root_prompt = ctx.session.state.get("rlm_prompt") if ctx.session else None + + if context_payload is None: + yield self._create_rlm_event( + ctx, + RLMEventType.RUN_ERROR, + error="No context provided in session state", + ) + return + + # Get conversation history if present (list of {role, content} messages) + conversation_history = ( + ctx.session.state.get("rlm_conversation_history") + if ctx.session + else None + ) + + # Create or reuse code executor + if self.persistent and self._persistent_executor is not None: + executor = self._persistent_executor + executor.add_context(context_payload) + else: + executor = RLMCodeExecutor( + sub_model=self._effective_sub_model, + current_depth=self.current_depth, + max_depth=self.max_depth, + max_iterations=self.max_iterations, + usage_tracker=self._usage_tracker, + logger=self._logger, + parent_agent=self.name, + ancestry=self._ancestry, + ) + executor.load_context(context_payload) + if self.persistent: + self._persistent_executor = executor + + # Build initial message history + query_metadata = QueryMetadata(context_payload) + message_history = build_rlm_system_prompt( + self._system_prompt, query_metadata + ) + + # Prepend conversation history if present (multi-turn conversation) + if conversation_history: + # Insert previous conversation turns after system prompt + # Format: user messages become "user", assistant messages become "assistant" + conv_messages = [] + for msg in conversation_history: + role = msg.get("role", "user") + content = msg.get("content", "") + if role == "user": + conv_messages.append({ + "role": "user", + "content": f"[Previous question from user]: {content}", + }) + elif role == "assistant": + conv_messages.append({ + "role": "assistant", + "content": f"[Your previous answer]: {content}", + }) + # Add a separator to indicate the new question + if conv_messages: + conv_messages.append({ + "role": "user", + 
"content": ( + "[End of conversation history. The user is now asking a" + " follow-up question below. Use the context from your" + " previous answers to provide a coherent response.]" + ), + }) + conv_messages.append({ + "role": "assistant", + "content": ( + "I understand. I'll consider my previous answers when" + " responding to the follow-up question." + ), + }) + # Insert after system prompt (first 2 messages) + message_history = ( + message_history[:2] + conv_messages + message_history[2:] + ) + + final_answer = None + + # Main iteration loop + for i in range(self.max_iterations): + # Emit iteration start + yield self._create_rlm_event( + ctx, + RLMEventType.ITERATION_START, + iteration=i + 1, + ) + + # Build current prompt + context_count = executor.get_context_count() + history_count = executor.get_history_count() + current_prompt = message_history + [ + build_user_prompt(root_prompt, i, context_count, history_count) + ] + + # Emit LLM call start + yield self._create_rlm_event( + ctx, + RLMEventType.LLM_CALL_START, + iteration=i + 1, + model=self.model, + ) + + # Call LLM + response_text = await self._call_llm_async(current_prompt) + + # Emit LLM call end + yield self._create_rlm_event( + ctx, + RLMEventType.LLM_CALL_END, + iteration=i + 1, + response_preview=response_text[:500] if response_text else None, + response_full=response_text, + ) + + # Find and execute code blocks + code_block_strs = find_code_blocks(response_text) + code_blocks = [] + + for j, code_str in enumerate(code_block_strs): + # Emit code found + yield self._create_rlm_event( + ctx, + RLMEventType.CODE_FOUND, + iteration=i + 1, + block_index=j, + code=code_str[:200] if code_str else None, + code_full=code_str, + ) + + # Emit code execution start + yield self._create_rlm_event( + ctx, + RLMEventType.CODE_EXEC_START, + iteration=i + 1, + block_index=j, + ) + + # Set iteration context so child events can reference parent iteration + executor.set_iteration_context(i + 1, j) + + # Execute code asynchronously while streaming child events in real-time + from google.adk.code_executors.code_execution_utils import CodeExecutionInput + + # Reset queue state BEFORE starting the task to avoid race conditions + executor.reset_event_state() + + # Start code execution as a background task + exec_task = asyncio.create_task( + executor.execute_code_async( + ctx, CodeExecutionInput(code=code_str) + ) + ) + + # Poll for child events while execution runs + async for child_event in executor.poll_child_events(): + yield child_event + + # Wait for execution to complete + exec_result = await exec_task + + # Yield any remaining events that arrived after polling stopped + for remaining_event in executor.pop_child_events(): + yield remaining_event + + # Create REPLResult for compatibility + repl_result = REPLResult( + stdout=exec_result.stdout, + stderr=exec_result.stderr, + locals=executor.locals.copy(), + execution_time=0.0, + ) + code_blocks.append(CodeBlock(code=code_str, result=repl_result)) + + # Emit code execution end + yield self._create_rlm_event( + ctx, + RLMEventType.CODE_EXEC_END, + iteration=i + 1, + block_index=j, + output=exec_result.stdout[:1000] if exec_result.stdout else None, + output_full=exec_result.stdout, + error=exec_result.stderr[:500] if exec_result.stderr else None, + error_full=exec_result.stderr, + has_error=bool(exec_result.stderr), + ) + + # Create iteration object for logging + iteration = RLMIteration( + prompt=current_prompt, + response=response_text, + code_blocks=code_blocks, + ) + + # Check for final 
answer in response text + final_answer = find_final_answer(response_text, None) + + # Also check for FINAL_ANSWER variable + if final_answer is None and executor.final_answer: + final_answer = executor.final_answer + + # Also check REPL locals for FINAL_VAR pattern + if final_answer is None: + final_answer = find_final_answer(response_text, executor._repl) + + iteration.final_answer = final_answer + + # Log iteration + if self._logger: + self._logger.log( + iteration, + depth=self.current_depth, + agent_name=self.name, + parent_agent=self._parent_agent, + ) + + # Verbose output + self._verbose.print_iteration(iteration, i + 1) + + if final_answer is not None: + # Emit final detected + yield self._create_rlm_event( + ctx, + RLMEventType.FINAL_DETECTED, + iteration=i + 1, + source="text" + if find_final_answer(response_text, None) + else "variable", + ) + + # Emit final answer + yield self._create_rlm_event( + ctx, + RLMEventType.FINAL_ANSWER, + answer=final_answer, + total_iterations=i + 1, + execution_time_ms=(time.perf_counter() - start_time) * 1000, + ) + + # Store history if persistent + if self.persistent: + executor.add_history(message_history) + + # Emit run end + yield self._create_rlm_event( + ctx, + RLMEventType.RUN_END, + success=True, + total_iterations=i + 1, + ) + + self._verbose.print_final_answer(final_answer) + self._verbose.print_summary( + i + 1, + time.perf_counter() - start_time, + self._usage_tracker.get_summary().to_dict(), + ) + return + + # Emit iteration end + yield self._create_rlm_event( + ctx, + RLMEventType.ITERATION_END, + iteration=i + 1, + ) + + # Format iteration for next prompt + new_messages = format_iteration(iteration) + message_history.extend(new_messages) + + # Max iterations reached - generate fallback answer + fallback = await self._generate_fallback_answer_async(message_history) + + yield self._create_rlm_event( + ctx, + RLMEventType.FINAL_ANSWER, + answer=fallback, + total_iterations=self.max_iterations, + execution_time_ms=(time.perf_counter() - start_time) * 1000, + ) + + yield self._create_rlm_event( + ctx, + RLMEventType.RUN_END, + success=True, + total_iterations=self.max_iterations, + fallback=True, + ) + + if self.persistent: + executor.add_history(message_history) + + self._verbose.print_final_answer(fallback) + self._verbose.print_summary( + self.max_iterations, + time.perf_counter() - start_time, + self._usage_tracker.get_summary().to_dict(), + ) + + except Exception as e: + yield self._create_rlm_event( + ctx, + RLMEventType.RUN_ERROR, + error=str(e), + metadata={"error_type": type(e).__name__}, + ) + raise + + async def _generate_fallback_answer_async( + self, message_history: list[dict[str, Any]] + ) -> str: + """Generate a default answer when max iterations is reached.""" + fallback_prompt = message_history + [{ + "role": "user", + "content": ( + "Please provide a final answer to the user's question based on the" + " information gathered so far." 
+ ), + }] + response = await self._call_llm_async(fallback_prompt) + + if self._logger: + self._logger.log( + RLMIteration( + prompt=fallback_prompt, + response=response, + final_answer=response, + code_blocks=[], + ), + depth=self.current_depth, + agent_name=self.name, + parent_agent=self._parent_agent, + ) + + return response + + def close(self) -> None: + """Clean up persistent environment.""" + if self._persistent_executor is not None: + self._persistent_executor.cleanup() + self._persistent_executor = None + + def __enter__(self) -> "RLMAgent": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + self.close() + return False diff --git a/contributing/samples/rlm/adk_rlm/callbacks/__init__.py b/contributing/samples/rlm/adk_rlm/callbacks/__init__.py new file mode 100644 index 0000000000..173ee9508a --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/callbacks/__init__.py @@ -0,0 +1,13 @@ +"""Callbacks for ADK-RLM.""" + +from adk_rlm.callbacks.code_execution import find_code_blocks +from adk_rlm.callbacks.code_execution import find_final_answer +from adk_rlm.callbacks.code_execution import format_execution_result +from adk_rlm.callbacks.code_execution import format_iteration + +__all__ = [ + "find_code_blocks", + "find_final_answer", + "format_execution_result", + "format_iteration", +] diff --git a/contributing/samples/rlm/adk_rlm/callbacks/code_execution.py b/contributing/samples/rlm/adk_rlm/callbacks/code_execution.py new file mode 100644 index 0000000000..d74ee9ed4d --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/callbacks/code_execution.py @@ -0,0 +1,194 @@ +""" +Code execution utilities for ADK-RLM. + +This module provides functions for parsing code blocks from LLM responses +and processing execution results. +""" + +import re +from typing import TYPE_CHECKING + +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMIteration + +if TYPE_CHECKING: + from adk_rlm.repl.local_repl import LocalREPL + + +def _extract_balanced_parens(text: str, start_pos: int) -> str | None: + """ + Extract content inside balanced parentheses starting at start_pos. + + Args: + text: The full text. + start_pos: Position of the opening parenthesis. + + Returns: + The content inside the balanced parentheses, or None if unbalanced. + """ + if start_pos >= len(text) or text[start_pos] != "(": + return None + + depth = 0 + start = start_pos + 1 # Skip the opening paren + + for i in range(start_pos, len(text)): + if text[i] == "(": + depth += 1 + elif text[i] == ")": + depth -= 1 + if depth == 0: + return text[start:i] + + # Unbalanced - return everything after the opening paren + return text[start:] + + +def find_code_blocks(text: str | None) -> list[str]: + """ + Find REPL code blocks in text wrapped in triple backticks. + + Args: + text: The text to search for code blocks. + + Returns: + List of code block contents (without the ```repl markers). + """ + if text is None: + return [] + + pattern = r"```repl\s*\n(.*?)\n```" + results = [] + + for match in re.finditer(pattern, text, re.DOTALL): + code_content = match.group(1).strip() + results.append(code_content) + + return results + + +def find_final_answer( + text: str | None, repl: "LocalREPL | None" = None +) -> str | None: + """ + Find FINAL(...) or FINAL_VAR(...) statement in response. + + Args: + text: The response text to parse. + repl: Optional REPL environment for FINAL_VAR resolution. + + Returns: + The final answer string, or None if no final answer pattern is found. 
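+
+  Example (illustrative):
+
+      find_final_answer("FINAL(42)")               # -> "42"
+      find_final_answer("FINAL_VAR(answer)", repl)
+      # -> the printed value of `answer` in the REPL, or None if the
+      #    variable is not defined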
+ """ + if text is None: + return None + + # Check for FINAL_VAR pattern first - must be at start of line + # Use regex to find the start, then balanced parens for content + final_var_match = re.search(r"^\s*FINAL_VAR\(", text, re.MULTILINE) + if final_var_match: + paren_start = final_var_match.end() - 1 # Position of '(' + variable_name = _extract_balanced_parens(text, paren_start) + if variable_name is not None: + variable_name = variable_name.strip().strip('"').strip("'") + if repl is not None: + result = repl.execute_code(f"print(FINAL_VAR({variable_name!r}))") + final_answer = result.stdout.strip() + if final_answer == "": + final_answer = result.stderr.strip() or "" + # Check if FINAL_VAR returned an error (variable not found) + if final_answer.startswith( + "Error: Variable '" + ) and final_answer.endswith("' not found"): + return None + return final_answer + return None + + # Check for FINAL pattern - must be at start of line + # Use regex to find the start, then balanced parens for content + final_match = re.search(r"^\s*FINAL\(", text, re.MULTILINE) + if final_match: + paren_start = final_match.end() - 1 # Position of '(' + content = _extract_balanced_parens(text, paren_start) + if content is not None: + return content.strip() + + return None + + +def format_execution_result(result: REPLResult) -> str: + """ + Format the execution result as a string for display. + + Args: + result: The REPLResult object to format. + + Returns: + A formatted string representation of the result. + """ + result_parts = [] + + if result.stdout: + result_parts.append(f"\n{result.stdout}") + + if result.stderr: + result_parts.append(f"\n{result.stderr}") + + # Show some key variables (excluding internal ones) + important_vars = {} + for key, value in result.locals.items(): + if not key.startswith("_") and key not in [ + "__builtins__", + "__name__", + "__doc__", + ]: + # Only show simple types or short representations + if isinstance(value, (str, int, float, bool, list, dict, tuple)): + important_vars[key] = "" + + if important_vars: + result_parts.append(f"REPL variables: {list(important_vars.keys())}\n") + + return "\n\n".join(result_parts) if result_parts else "No output" + + +def format_iteration( + iteration: RLMIteration, max_character_length: int = 20000 +) -> list[dict[str, str]]: + """ + Format an RLM iteration to append to the message history. + + Args: + iteration: The iteration to format. + max_character_length: Maximum character length for results. + + Returns: + A list of messages to add to the next prompt. + """ + # Handle None responses - use empty string to avoid corrupting message history + response_content = ( + iteration.response if iteration.response is not None else "" + ) + messages = [{"role": "assistant", "content": response_content}] + + for code_block in iteration.code_blocks: + code = code_block.code + result = code_block.result + result_str = format_execution_result(result) + + if len(result_str) > max_character_length: + result_str = ( + result_str[:max_character_length] + + f"... 
+ [{len(result_str) - max_character_length} chars...]" + ) + + execution_message = { + "role": "user", + "content": ( + f"Code executed:\n```python\n{code}\n```\n\nREPL" + f" output:\n{result_str}" + ), + } + messages.append(execution_message) + + return messages diff --git a/contributing/samples/rlm/adk_rlm/cli.py b/contributing/samples/rlm/adk_rlm/cli.py new file mode 100644 index 0000000000..95cc7dd70a --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/cli.py @@ -0,0 +1,1129 @@ +""" +Interactive CLI for ADK-RLM with real-time streaming output. + +A conversational REPL that uses the ADK Runner and renders events +to the terminal using Rich. Supports slash commands for configuration. +Sessions are persisted using ADK's DatabaseSessionService. +""" + +import argparse +import asyncio +from datetime import datetime +import os +from pathlib import Path +import time +import uuid + +from adk_rlm import RLM +from adk_rlm import RLMEventType +from google.adk.events import Event +from google.adk.events import EventActions +from google.adk.sessions import DatabaseSessionService +from google.adk.sessions import Session +from rich.console import Console +from rich.console import Group +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.prompt import Prompt +from rich.rule import Rule +from rich.style import Style +from rich.syntax import Syntax +from rich.table import Table +from rich.text import Text + +# Default configuration +DEFAULT_DB_URL = os.environ.get( + "RLM_CLI_DB_URL", "sqlite+aiosqlite:///./cli_sessions.db" +) +APP_NAME = "adk_rlm_cli" +DEFAULT_USER_ID = "default_user" + +# Tokyo Night Color Theme +COLORS = { + "primary": "#7AA2F7", + "secondary": "#BB9AF7", + "success": "#9ECE6A", + "warning": "#E0AF68", + "error": "#F7768E", + "text": "#A9B1D6", + "muted": "#565F89", + "accent": "#7DCFFF", + "border": "#3B4261", +} + + +class RLMDisplay: + """Manages the Rich display for RLM execution.""" + + def __init__(self, console: Console): + self.console = console + self.current_iteration = 0 + self.status_text = "Initializing..." + self.last_response_preview = "" + self.last_code = "" + self.last_output = "" + self.final_answer = None + self.total_iterations = 0 + self.execution_time_ms = 0 + + def reset(self): + """Reset display state for a new query.""" + self.current_iteration = 0 + self.status_text = "Thinking..." + self.last_response_preview = "" + self.last_code = "" + self.last_output = "" + self.final_answer = None + self.total_iterations = 0 + self.execution_time_ms = 0 + + def build_display(self) -> Panel: + """Build the current display panel.""" + content_parts = [] + + # Status line + status = Text() + status.append("◆ ", style=Style(color=COLORS["accent"])) + status.append(self.status_text, style=Style(color=COLORS["text"])) + content_parts.append(status) + + # Current iteration + if self.current_iteration > 0: + iter_text = Text() + iter_text.append( + f"\nIteration {self.current_iteration}", + style=Style(color=COLORS["primary"], bold=True), + ) + content_parts.append(iter_text) + + # Last response preview + if self.last_response_preview: + response_text = Text() + response_text.append( + "\n\nLLM Response: ", style=Style(color=COLORS["muted"]) + ) + preview = self.last_response_preview[:300] + if len(self.last_response_preview) > 300: + preview += "..." 
+ response_text.append(preview, style=Style(color=COLORS["text"])) + content_parts.append(response_text) + + # Last code block + if self.last_code: + code_header = Text("\n\nCode: ", style=Style(color=COLORS["success"])) + content_parts.append(code_header) + code_preview = self.last_code[:200] + if len(self.last_code) > 200: + code_preview += "\n..." + content_parts.append( + Text(code_preview, style=Style(color=COLORS["success"], dim=True)) + ) + + # Last output + if self.last_output: + output_text = Text() + output_text.append("\n\nOutput: ", style=Style(color=COLORS["muted"])) + output_preview = self.last_output[:200] + if len(self.last_output) > 200: + output_preview += "..." + output_text.append(output_preview, style=Style(color=COLORS["accent"])) + content_parts.append(output_text) + + title = Text() + title.append("◆ ", style=Style(color=COLORS["accent"])) + title.append("RLM", style=Style(color=COLORS["primary"], bold=True)) + title.append(" ━ Processing", style=Style(color=COLORS["muted"])) + + return Panel( + Group(*content_parts), + title=title, + title_align="left", + border_style=COLORS["border"], + padding=(1, 2), + ) + + def update_from_event(self, event) -> None: + """Update display state from an event.""" + if not event.custom_metadata: + return + + event_type = event.custom_metadata.get("event_type") + + if event_type == RLMEventType.RUN_START.value: + self.status_text = "Starting..." + + elif event_type == RLMEventType.ITERATION_START.value: + self.current_iteration = event.custom_metadata.get("iteration", 0) + self.status_text = f"Iteration {self.current_iteration} - Thinking..." + self.last_code = "" + self.last_output = "" + + elif event_type == RLMEventType.LLM_CALL_START.value: + self.status_text = f"Iteration {self.current_iteration} - Calling LLM..." + + elif event_type == RLMEventType.LLM_CALL_END.value: + self.last_response_preview = event.custom_metadata.get( + "response_preview", "" + ) + self.status_text = f"Iteration {self.current_iteration} - Processing..." + + elif event_type == RLMEventType.CODE_FOUND.value: + self.last_code = event.custom_metadata.get("code", "") + self.status_text = f"Iteration {self.current_iteration} - Found code..." + + elif event_type == RLMEventType.CODE_EXEC_START.value: + self.status_text = f"Iteration {self.current_iteration} - Executing..." + + elif event_type == RLMEventType.CODE_EXEC_END.value: + output = event.custom_metadata.get("output", "") + error = event.custom_metadata.get("error", "") + self.last_output = output or error or "(no output)" + self.status_text = f"Iteration {self.current_iteration} - Code executed" + + elif event_type == RLMEventType.FINAL_DETECTED.value: + self.status_text = "Final answer detected!" + + elif event_type == RLMEventType.FINAL_ANSWER.value: + self.final_answer = event.custom_metadata.get("answer", "") + self.total_iterations = event.custom_metadata.get("total_iterations", 0) + self.execution_time_ms = event.custom_metadata.get("execution_time_ms", 0) + + elif event_type == RLMEventType.RUN_END.value: + self.status_text = "Complete!" 
+ + elif event_type == RLMEventType.RUN_ERROR.value: + error = event.custom_metadata.get("error", "Unknown error") + self.status_text = f"Error: {error}" + + +class InteractiveCLI: + """Interactive REPL for ADK-RLM with session persistence.""" + + DEFAULT_LOG_DIR = "./logs" + + def __init__( + self, + model: str = "gemini-3-pro-preview", + sub_model: str | None = None, + max_iterations: int = 30, + verbose: bool = False, + log_dir: str | None = None, + db_url: str = DEFAULT_DB_URL, + session_id: str | None = None, + ): + self.console = Console() + self.model = model + self.sub_model = sub_model + self.max_iterations = max_iterations + self.verbose = verbose + self.log_dir = log_dir or self.DEFAULT_LOG_DIR + self.db_url = db_url + + # Session management + self.session_service: DatabaseSessionService | None = None + self.session: Session | None = None + self.session_id = session_id or str(uuid.uuid4()) + + # State (will be synced with session) + self.files: list[str] = [] + self.conversation: list[dict] = [] + self.rlm: RLM | None = None + self.display = RLMDisplay(self.console) + + def _init_rlm(self): + """Initialize or reinitialize the RLM instance.""" + if self.rlm: + self.rlm.close() + + self.rlm = RLM( + model=self.model, + sub_model=self.sub_model, + max_iterations=self.max_iterations, + persistent=True, # Keep REPL state across turns + log_dir=self.log_dir, + ) + + async def _init_session_service(self): + """Initialize the session service.""" + if self.session_service is None: + self.session_service = DatabaseSessionService(db_url=self.db_url) + + async def _load_or_create_session(self) -> Session: + """Load existing session or create a new one.""" + await self._init_session_service() + + session = await self.session_service.get_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=self.session_id, + ) + + if session is None: + session = await self.session_service.create_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=self.session_id, + state={ + "title": f"Session {datetime.now().strftime('%Y-%m-%d %H:%M')}", + "model": self.model, + "sub_model": self.sub_model, + "max_iterations": self.max_iterations, + "files": [], + "conversation": [], + }, + ) + else: + # Restore state from session + self.model = session.state.get("model", self.model) + self.sub_model = session.state.get("sub_model", self.sub_model) + self.max_iterations = session.state.get( + "max_iterations", self.max_iterations + ) + self.files = session.state.get("files", []) + self.conversation = session.state.get("conversation", []) + + self.session = session + return session + + async def _update_session_state(self, state_updates: dict): + """Update session state and persist to database.""" + if not self.session or not self.session_service: + return + + # Update in-memory state + self.session.state.update(state_updates) + + # Create a state-update event to persist changes + event = Event( + author="system", + timestamp=time.time(), + actions=EventActions(state_delta=state_updates), + ) + + await self.session_service.append_event(self.session, event) + + async def _sync_state_to_session(self): + """Sync current state to session.""" + await self._update_session_state({ + "model": self.model, + "sub_model": self.sub_model, + "max_iterations": self.max_iterations, + "files": self.files, + "conversation": self.conversation, + }) + + def print_welcome(self): + """Print welcome message.""" + title = Text() + title.append("◆ ", style=Style(color=COLORS["accent"])) + title.append("ADK-RLM", 
style=Style(color=COLORS["primary"], bold=True)) + title.append(" ━ Interactive Mode", style=Style(color=COLORS["muted"])) + + help_text = Text() + help_text.append( + "Type a message to chat, or use slash commands:\n\n", + style=Style(color=COLORS["text"]), + ) + help_text.append(" /files ", style=Style(color=COLORS["accent"])) + help_text.append(" ", style=Style(color=COLORS["muted"])) + help_text.append( + "Add files to context\n", style=Style(color=COLORS["text"]) + ) + help_text.append( + " /clear ", style=Style(color=COLORS["accent"]) + ) + help_text.append( + "Clear files and conversation\n", style=Style(color=COLORS["text"]) + ) + help_text.append( + " /status ", style=Style(color=COLORS["accent"]) + ) + help_text.append( + "Show current configuration\n", style=Style(color=COLORS["text"]) + ) + help_text.append(" /model ", style=Style(color=COLORS["accent"])) + help_text.append(" ", style=Style(color=COLORS["muted"])) + help_text.append("Change model\n", style=Style(color=COLORS["text"])) + help_text.append(" /iterations ", style=Style(color=COLORS["accent"])) + help_text.append(" ", style=Style(color=COLORS["muted"])) + help_text.append("Set max iterations\n", style=Style(color=COLORS["text"])) + help_text.append( + " /logs ", style=Style(color=COLORS["accent"]) + ) + help_text.append("Show log file path\n", style=Style(color=COLORS["text"])) + help_text.append("\n", style=Style(color=COLORS["text"])) + help_text.append( + " /sessions ", style=Style(color=COLORS["secondary"]) + ) + help_text.append("List saved sessions\n", style=Style(color=COLORS["text"])) + help_text.append( + " /new ", style=Style(color=COLORS["secondary"]) + ) + help_text.append("Create new session\n", style=Style(color=COLORS["text"])) + help_text.append(" /load ", style=Style(color=COLORS["secondary"])) + help_text.append(" ", style=Style(color=COLORS["muted"])) + help_text.append("Load a session\n", style=Style(color=COLORS["text"])) + help_text.append(" /delete ", style=Style(color=COLORS["secondary"])) + help_text.append(" ", style=Style(color=COLORS["muted"])) + help_text.append("Delete a session\n", style=Style(color=COLORS["text"])) + help_text.append(" /title ", style=Style(color=COLORS["secondary"])) + help_text.append(" ", style=Style(color=COLORS["muted"])) + help_text.append("Set session title\n", style=Style(color=COLORS["text"])) + help_text.append("\n", style=Style(color=COLORS["text"])) + help_text.append( + " /help ", style=Style(color=COLORS["accent"]) + ) + help_text.append("Show this help\n", style=Style(color=COLORS["text"])) + help_text.append( + " /quit ", style=Style(color=COLORS["accent"]) + ) + help_text.append("Exit\n", style=Style(color=COLORS["text"])) + + panel = Panel( + help_text, + title=title, + title_align="left", + border_style=COLORS["border"], + padding=(1, 2), + ) + + self.console.print() + self.console.print(panel) + self.console.print() + + def print_status(self): + """Print current configuration status.""" + table = Table( + show_header=False, + show_edge=False, + box=None, + padding=(0, 2), + ) + table.add_column("key", style=Style(color=COLORS["muted"]), width=16) + table.add_column("value", style=Style(color=COLORS["text"])) + + # Session info + session_title = ( + self.session.state.get("title", "Untitled") + if self.session + else "Untitled" + ) + table.add_row( + "Session", Text(session_title, style=Style(color=COLORS["primary"])) + ) + table.add_row( + "Session ID", + Text(self.session_id[:8] + "...", style=Style(color=COLORS["muted"])), + ) + table.add_row( + 
"Messages", + Text( + str(len(self.conversation)), style=Style(color=COLORS["secondary"]) + ), + ) + + table.add_row( + "Model", Text(self.model, style=Style(color=COLORS["accent"])) + ) + table.add_row( + "Sub-model", + Text( + self.sub_model or self.model, style=Style(color=COLORS["secondary"]) + ), + ) + table.add_row( + "Max iterations", + Text(str(self.max_iterations), style=Style(color=COLORS["warning"])), + ) + + if self.files: + files_str = ", ".join(self.files[:5]) + if len(self.files) > 5: + files_str += f" (+{len(self.files) - 5} more)" + table.add_row( + "Files", Text(files_str, style=Style(color=COLORS["success"])) + ) + else: + table.add_row("Files", Text("(none)", style=Style(color=COLORS["muted"]))) + + log_path = self.rlm.log_path if self.rlm else None + if log_path: + table.add_row( + "Log file", Text(log_path, style=Style(color=COLORS["muted"])) + ) + + title = Text() + title.append("◇ ", style=Style(color=COLORS["secondary"])) + title.append("Status", style=Style(color=COLORS["secondary"])) + + panel = Panel( + table, + title=title, + title_align="left", + border_style=COLORS["muted"], + padding=(1, 2), + ) + self.console.print(panel) + + def print_answer(self, answer: str, iterations: int, time_ms: float): + """Print the final answer.""" + title = Text() + title.append("★ ", style=Style(color=COLORS["warning"])) + title.append("Answer", style=Style(color=COLORS["warning"], bold=True)) + title.append( + f" ({iterations} iterations, {time_ms/1000:.1f}s)", + style=Style(color=COLORS["muted"]), + ) + + # Try to render as markdown + try: + content = Markdown(answer) + except Exception: + content = Text(answer, style=Style(color=COLORS["text"])) + + panel = Panel( + content, + title=title, + title_align="left", + border_style=COLORS["warning"], + padding=(1, 2), + ) + + self.console.print() + self.console.print(panel) + + def print_error(self, message: str): + """Print an error message.""" + self.console.print(f"[red]Error:[/red] {message}") + + def print_info(self, message: str): + """Print an info message.""" + self.console.print(f"[{COLORS['accent']}]ℹ[/{COLORS['accent']}] {message}") + + async def handle_command(self, command: str) -> bool: + """ + Handle a slash command. + + Returns True if the REPL should continue, False to exit. + """ + parts = command.split(maxsplit=1) + cmd = parts[0].lower() + args = parts[1] if len(parts) > 1 else "" + + if cmd in ("/quit", "/exit", "/q"): + self.console.print(f"[{COLORS['muted']}]Goodbye![/{COLORS['muted']}]") + return False + + elif cmd == "/help": + self.print_welcome() + + elif cmd == "/status": + self.print_status() + + elif cmd == "/clear": + self.files = [] + self.conversation = [] + self._init_rlm() + await self._update_session_state({ + "files": [], + "conversation": [], + }) + self.print_info("Cleared files and conversation state.") + + elif cmd == "/files": + if not args: + if self.files: + self.console.print( + f"[{COLORS['muted']}]Current files:[/{COLORS['muted']}]" + ) + for f in self.files: + self.console.print(f" • {f}") + else: + self.print_info("No files loaded. Use /files to add files.") + else: + # Add new files + new_files = args.split() + try: + resolved = self.rlm.file_loader.create_lazy_files(new_files) + if len(resolved) == 0: + self.print_error(f"No files found matching: {' '.join(new_files)}") + else: + self.files.extend(new_files) + await self._update_session_state({"files": self.files}) + self.print_info( + f"Added {len(resolved)} file(s). 
Total patterns:" + f" {len(self.files)}" + ) + for f in resolved.names[:5]: + self.console.print( + f" [{COLORS['success']}]✓[/{COLORS['success']}] {f}" + ) + if len(resolved) > 5: + self.console.print( + f" [{COLORS['muted']}]... and {len(resolved) - 5}" + f" more[/{COLORS['muted']}]" + ) + except Exception as e: + self.print_error(f"Could not resolve files: {e}") + + elif cmd == "/model": + if not args: + self.console.print( + "Current model:" + f" [{COLORS['accent']}]{self.model}[/{COLORS['accent']}]" + ) + else: + self.model = args.strip() + self._init_rlm() + await self._update_session_state({"model": self.model}) + self.print_info(f"Model changed to: {self.model}") + + elif cmd == "/submodel": + if not args: + self.console.print( + "Current sub-model:" + f" [{COLORS['accent']}]{self.sub_model or self.model}[/{COLORS['accent']}]" + ) + else: + self.sub_model = args.strip() + self._init_rlm() + await self._update_session_state({"sub_model": self.sub_model}) + self.print_info(f"Sub-model changed to: {self.sub_model}") + + elif cmd == "/iterations": + if not args: + self.console.print( + "Max iterations:" + f" [{COLORS['accent']}]{self.max_iterations}[/{COLORS['accent']}]" + ) + else: + try: + self.max_iterations = int(args.strip()) + self._init_rlm() + await self._update_session_state( + {"max_iterations": self.max_iterations} + ) + self.print_info(f"Max iterations set to: {self.max_iterations}") + except ValueError: + self.print_error("Invalid number") + + elif cmd == "/logs": + log_path = self.rlm.log_path if self.rlm else None + if log_path: + self.console.print( + f"Log file: [{COLORS['accent']}]{log_path}[/{COLORS['accent']}]" + ) + else: + self.console.print( + "Log directory:" + f" [{COLORS['accent']}]{self.log_dir}[/{COLORS['accent']}]" + ) + + # Session management commands + elif cmd == "/sessions": + await self._list_sessions() + + elif cmd == "/new": + await self._new_session() + + elif cmd == "/load": + if not args: + self.print_error("Usage: /load ") + else: + await self._load_session(args.strip()) + + elif cmd == "/delete": + if not args: + self.print_error("Usage: /delete ") + else: + await self._delete_session(args.strip()) + + elif cmd == "/title": + if not args: + title = ( + self.session.state.get("title", "Untitled") + if self.session + else "Untitled" + ) + self.console.print( + f"Session title: [{COLORS['primary']}]{title}[/{COLORS['primary']}]" + ) + else: + await self._update_session_state({"title": args.strip()}) + self.print_info(f"Session title set to: {args.strip()}") + + else: + self.print_error( + f"Unknown command: {cmd}. Type /help for available commands." 
+ ) + + return True + + async def _list_sessions(self): + """List all saved sessions.""" + await self._init_session_service() + response = await self.session_service.list_sessions( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + ) + + if not response.sessions: + self.print_info("No saved sessions.") + return + + table = Table( + show_header=True, + header_style=Style(color=COLORS["primary"], bold=True), + border_style=COLORS["border"], + ) + table.add_column("ID", style=Style(color=COLORS["muted"]), width=10) + table.add_column("Title", style=Style(color=COLORS["text"])) + table.add_column( + "Messages", style=Style(color=COLORS["secondary"]), justify="right" + ) + table.add_column("Updated", style=Style(color=COLORS["muted"])) + + # Sort by last update time descending + sessions = sorted( + response.sessions, + key=lambda s: s.last_update_time or 0, + reverse=True, + ) + + for s in sessions: + conv = s.state.get("conversation", []) + updated = ( + datetime.fromtimestamp(s.last_update_time).strftime("%Y-%m-%d %H:%M") + if s.last_update_time + else "Unknown" + ) + is_current = "→ " if s.id == self.session_id else " " + table.add_row( + is_current + s.id[:8], + s.state.get("title", "Untitled"), + str(len(conv)), + updated, + ) + + self.console.print() + self.console.print(table) + self.console.print( + f"\n[{COLORS['muted']}]Use /load to switch" + f" sessions[/{COLORS['muted']}]" + ) + + async def _new_session(self): + """Create a new session.""" + await self._init_session_service() + + # Close current RLM + if self.rlm: + self.rlm.close() + self.rlm = None + + # Create new session + self.session_id = str(uuid.uuid4()) + self.files = [] + self.conversation = [] + + self.session = await self.session_service.create_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=self.session_id, + state={ + "title": f"Session {datetime.now().strftime('%Y-%m-%d %H:%M')}", + "model": self.model, + "sub_model": self.sub_model, + "max_iterations": self.max_iterations, + "files": [], + "conversation": [], + }, + ) + + self._init_rlm() + self.print_info(f"Created new session: {self.session_id[:8]}...") + + async def _load_session(self, session_id_prefix: str): + """Load a session by ID or prefix.""" + await self._init_session_service() + + # Find session matching prefix + response = await self.session_service.list_sessions( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + ) + + matching = [ + s for s in response.sessions if s.id.startswith(session_id_prefix) + ] + + if not matching: + self.print_error(f"No session found matching: {session_id_prefix}") + return + if len(matching) > 1: + self.print_error( + f"Multiple sessions match '{session_id_prefix}'. Be more specific." 
+ ) + return + + # Close current RLM + if self.rlm: + self.rlm.close() + self.rlm = None + + # Load the session + target = matching[0] + self.session_id = target.id + self.session = target + self.model = target.state.get("model", self.model) + self.sub_model = target.state.get("sub_model", self.sub_model) + self.max_iterations = target.state.get( + "max_iterations", self.max_iterations + ) + self.files = target.state.get("files", []) + self.conversation = target.state.get("conversation", []) + + self._init_rlm() + title = target.state.get("title", "Untitled") + self.print_info(f"Loaded session: {title} ({self.session_id[:8]}...)") + self.print_info( + f" {len(self.conversation)} messages, {len(self.files)} file patterns" + ) + + async def _delete_session(self, session_id_prefix: str): + """Delete a session by ID or prefix.""" + await self._init_session_service() + + # Find session matching prefix + response = await self.session_service.list_sessions( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + ) + + matching = [ + s for s in response.sessions if s.id.startswith(session_id_prefix) + ] + + if not matching: + self.print_error(f"No session found matching: {session_id_prefix}") + return + if len(matching) > 1: + self.print_error( + f"Multiple sessions match '{session_id_prefix}'. Be more specific." + ) + return + + target = matching[0] + + if target.id == self.session_id: + self.print_error( + "Cannot delete the current session. Switch to another session first." + ) + return + + await self.session_service.delete_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=target.id, + ) + title = target.state.get("title", "Untitled") + self.print_info(f"Deleted session: {title} ({target.id[:8]}...)") + + async def run_query(self, prompt: str): + """Run a query and stream the results.""" + self.display.reset() + + # Add user message to conversation + self.conversation.append({ + "role": "user", + "content": prompt, + "timestamp": datetime.now().isoformat(), + }) + + # Extract conversation history for the agent (exclude current message) + # Only include role and content, not timestamp + conversation_history = None + if len(self.conversation) > 1: + conversation_history = [ + {"role": msg["role"], "content": msg["content"]} + for msg in self.conversation[:-1] + ] + + # Build file context + if self.files: + try: + file_ctx = self.rlm.file_loader.build_context(self.files, lazy=True) + file_count = file_ctx.get("file_count", 0) + if file_count == 0: + self.print_error( + f"No files found matching patterns: {' '.join(self.files)}" + ) + self.print_info( + "Check that the paths exist. Use /files to see current patterns." + ) + # Remove the user message we just added + self.conversation.pop() + return + ctx = file_ctx + except Exception as e: + self.print_error(f"Failed to load files: {e}") + self.conversation.pop() + return + else: + ctx = { + "info": "No files loaded. 
The user is asking a question.", + } + + try: + with Live( + self.display.build_display(), + console=self.console, + refresh_per_second=10, + transient=True, + ) as live: + async for event in self.rlm.run_streaming( + ctx, prompt, conversation_history + ): + self.display.update_from_event(event) + live.update(self.display.build_display()) + + # Print final answer and save to conversation + if self.display.final_answer: + self.conversation.append({ + "role": "assistant", + "content": self.display.final_answer, + "timestamp": datetime.now().isoformat(), + }) + + # Auto-generate title from first exchange (like web.py) + title = self.session.state.get("title", "") if self.session else "" + if title.startswith("Session ") and len(self.conversation) == 2: + first_msg = self.conversation[0]["content"] + title = first_msg[:50] + ("..." if len(first_msg) > 50 else "") + await self._update_session_state({"title": title}) + + # Save conversation to session + await self._update_session_state({"conversation": self.conversation}) + + self.print_answer( + self.display.final_answer, + self.display.total_iterations, + self.display.execution_time_ms, + ) + else: + # Remove user message if no answer + self.conversation.pop() + self.print_error("No answer received") + + except KeyboardInterrupt: + self.conversation.pop() + self.console.print( + f"\n[{COLORS['warning']}]Interrupted[/{COLORS['warning']}]" + ) + except Exception as e: + self.conversation.pop() + self.print_error(str(e)) + + async def run(self): + """Main REPL loop.""" + # Initialize session + await self._load_or_create_session() + self._init_rlm() + + self.print_welcome() + + # Handle pending files from command line + pending_files = getattr(self, "_pending_files", None) + if pending_files: + try: + resolved = self.rlm.file_loader.create_lazy_files(pending_files) + if len(resolved) == 0: + self.print_error( + f"No files found matching: {' '.join(pending_files)}" + ) + self.print_info( + "Use /files to add files, or check your paths" + ) + else: + self.files.extend(pending_files) + await self._update_session_state({"files": self.files}) + self.print_info( + f"Loaded {len(resolved)} file(s) from {len(pending_files)}" + " pattern(s)" + ) + for f in resolved.names[:5]: + self.console.print( + f" [{COLORS['success']}]✓[/{COLORS['success']}] {f}" + ) + if len(resolved) > 5: + self.console.print( + f" [{COLORS['muted']}]... 
and {len(resolved) - 5}" + f" more[/{COLORS['muted']}]" + ) + except Exception as e: + self.print_error(f"Could not load files: {e}") + self._pending_files = None + + # Show session info on startup + title = ( + self.session.state.get("title", "Untitled") + if self.session + else "Untitled" + ) + msg_count = len(self.conversation) + if msg_count > 0: + self.print_info(f"Resumed session: {title} ({msg_count} messages)") + else: + self.print_info(f"Session: {title}") + + while True: + try: + # Get input + self.console.print() + user_input = Prompt.ask( + f"[{COLORS['accent']}]>[/{COLORS['accent']}]", + console=self.console, + ) + + if not user_input.strip(): + continue + + # Handle slash commands + if user_input.startswith("/"): + if not await self.handle_command(user_input): + break + continue + + # Run query + await self.run_query(user_input) + + except KeyboardInterrupt: + self.console.print( + f"\n[{COLORS['muted']}]Use /quit to exit[/{COLORS['muted']}]" + ) + except EOFError: + break + + # Cleanup + if self.rlm: + self.rlm.close() + + +async def run_interactive( + model: str = "gemini-3-pro-preview", + sub_model: str | None = None, + max_iterations: int = 30, + files: list[str] | None = None, + log_dir: str | None = None, + db_url: str = DEFAULT_DB_URL, + session_id: str | None = None, +): + """Run the interactive CLI.""" + cli = InteractiveCLI( + model=model, + sub_model=sub_model, + max_iterations=max_iterations, + log_dir=log_dir, + db_url=db_url, + session_id=session_id, + ) + + # Pre-load any files specified on command line (will be added after session init) + if files: + cli._pending_files = files + else: + cli._pending_files = None + + # Run will initialize session and RLM + await cli.run() + + +def main(): + """Main entry point for the CLI.""" + parser = argparse.ArgumentParser( + description="ADK-RLM: Interactive Recursive Language Model CLI", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Start interactive mode + python -m adk_rlm.cli + + # Start with files pre-loaded + python -m adk_rlm.cli --files "./docs/**/*.md" "./data/*.csv" + + # Use a specific model + python -m adk_rlm.cli --model gemini-3-pro-preview + + # Resume a specific session + python -m adk_rlm.cli --session abc12345 + """, + ) + + parser.add_argument( + "--files", + "-f", + type=str, + nargs="+", + help="File paths or glob patterns to pre-load", + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="gemini-3-pro-preview", + help="Main model to use (default: gemini-3-pro-preview)", + ) + parser.add_argument( + "--sub-model", + "-s", + type=str, + help="Sub-model for recursive calls (defaults to main model)", + ) + parser.add_argument( + "--max-iterations", + "-i", + type=int, + default=30, + help="Maximum number of iterations (default: 30)", + ) + parser.add_argument( + "--log-dir", + "-l", + type=str, + default="./logs", + help="Directory for JSONL logs (default: ./logs)", + ) + parser.add_argument( + "--db-url", + type=str, + default=DEFAULT_DB_URL, + help=( + "SQLAlchemy database URL for sessions (default:" + " sqlite+aiosqlite:///./cli_sessions.db)" + ), + ) + parser.add_argument( + "--session", + type=str, + help="Session ID to resume (creates new session if not specified)", + ) + + args = parser.parse_args() + + # Run the interactive CLI + asyncio.run( + run_interactive( + model=args.model, + sub_model=args.sub_model, + max_iterations=args.max_iterations, + files=args.files, + log_dir=args.log_dir, + db_url=args.db_url, + session_id=args.session, + ) + 
) + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/adk_rlm/code_executor.py b/contributing/samples/rlm/adk_rlm/code_executor.py new file mode 100644 index 0000000000..014dff0858 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/code_executor.py @@ -0,0 +1,985 @@ +""" +RLM Code Executor using ADK's BaseCodeExecutor. + +This module provides a custom code executor that wraps LocalREPL +and provides llm_query() and FINAL() functions for the RLM pattern. +""" + +import asyncio +import concurrent.futures +import logging +from queue import Empty +from queue import Queue +import threading +import time +from typing import Any +from typing import AsyncGenerator +from typing import TYPE_CHECKING +import uuid + +from google.genai import types + +from google import genai + +logger = logging.getLogger(__name__) +from adk_rlm.events import RLMEventData +from adk_rlm.events import RLMEventType +from adk_rlm.llm import AsyncLLMRateLimiter +from adk_rlm.llm import llm_rate_limit +from adk_rlm.repl.local_repl import LocalREPL +from adk_rlm.usage import UsageTracker +from google.adk.agents.invocation_context import InvocationContext +from google.adk.code_executors import BaseCodeExecutor +from google.adk.code_executors.code_execution_utils import CodeExecutionInput +from google.adk.code_executors.code_execution_utils import CodeExecutionResult +from google.adk.events.event import Event +from pydantic import PrivateAttr + +if TYPE_CHECKING: + from adk_rlm.logging.rlm_logger import RLMLogger + + +class RLMCodeExecutor(BaseCodeExecutor): + """ + Code executor that provides llm_query() and FINAL() functions. + + This executor wraps the LocalREPL and provides the RLM-specific + functions for recursive LLM calls and final answer detection. + + When current_depth < max_depth, llm_query() creates a nested RLM + execution that can itself execute code and make further llm_query calls. + When current_depth >= max_depth, llm_query() falls back to a simple + LLM call without code execution capability. 
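+
+  Illustrative sketch of a model-emitted repl code block (assumes the
+  loaded context is a list of text chunks; real prompts and chunking
+  strategies will vary):
+
+      # llm_query(), llm_query_batched() and the FINAL_ANSWER variable
+      # are provided/detected by this executor inside the REPL.
+      summaries = llm_query_batched(
+          [f"Summarize this chunk:\n{chunk}" for chunk in context],
+          recursive=False,
+      )
+      # Assigning FINAL_ANSWER ends the run with this answer.
+      FINAL_ANSWER = llm_query(
+          "Merge these partial summaries into one answer:\n"
+          + "\n".join(summaries)
+      )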
+ """ + + stateful: bool = True # Persist namespace across code blocks + + # Use ```repl delimiter instead of ```python + code_block_delimiters: list[tuple[str, str]] = [ + ("```repl\n", "\n```"), + ] + + # Private attributes (not part of the Pydantic schema) + _sub_model: str = PrivateAttr(default="gemini-3-flash-preview") + _current_depth: int = PrivateAttr(default=0) + _max_depth: int = PrivateAttr(default=5) + _max_iterations: int = PrivateAttr(default=30) + _repl: LocalREPL | None = PrivateAttr(default=None) + _final_answer: str | None = PrivateAttr(default=None) + _usage_tracker: UsageTracker = PrivateAttr(default_factory=UsageTracker) + _logger: "RLMLogger | None" = PrivateAttr(default=None) + _parent_agent: str | None = PrivateAttr(default=None) + _current_iteration: int = PrivateAttr(default=0) + _current_block_index: int = PrivateAttr(default=0) + + # Real-time event streaming via thread-safe queue + _event_queue: Queue = PrivateAttr(default_factory=Queue) + _execution_complete: threading.Event = PrivateAttr( + default_factory=threading.Event + ) + + # Ancestry tracking for nested agents + _ancestry: list[dict] = PrivateAttr(default_factory=list) + + # Counter for unique child agent names + _child_agent_counter: int = PrivateAttr(default=0) + + def __init__( + self, + sub_model: str = "gemini-3-flash-preview", + current_depth: int = 0, + max_depth: int = 5, + max_iterations: int = 30, + usage_tracker: UsageTracker | None = None, + logger: "RLMLogger | None" = None, + parent_agent: str | None = None, + ancestry: list[dict] | None = None, + **kwargs, + ): + """ + Initialize the RLM code executor. + + Args: + sub_model: The model to use for sub-LLM queries. + current_depth: Current recursion depth (0 = root level). + max_depth: Maximum recursion depth for nested RLM calls. + max_iterations: Maximum iterations for nested RLM calls. + usage_tracker: Optional usage tracker to record token usage. + logger: Optional logger for recording iterations. + parent_agent: Name of the parent agent that created this executor. + ancestry: List of ancestor agent context dicts for event tagging. + **kwargs: Additional arguments for BaseCodeExecutor. + """ + super().__init__(**kwargs) + self._sub_model = sub_model + self._current_depth = current_depth + self._max_depth = max_depth + self._max_iterations = max_iterations + self._repl = None + self._final_answer = None + self._usage_tracker = usage_tracker or UsageTracker() + self._logger = logger + self._parent_agent = parent_agent + self._ancestry = ancestry.copy() if ancestry else [] + + # Initialize queue and threading event + self._event_queue = Queue() + self._execution_complete = threading.Event() + + def _create_llm_query_fn(self): + """Create the llm_query function for the REPL environment. + + When current_depth < max_depth, this creates a nested RLM execution + that can itself execute code and make further llm_query calls. + When at max_depth, falls back to a simple LLM call. + """ + + def llm_query( + prompt: str, + context: Any = None, + model: str | None = None, + recursive: bool = True, + ) -> str: + """ + Query an LLM with the given prompt. + + Args: + prompt: The prompt to send to the LLM. + context: Optional context object(s) to pass to the child agent. + Can be a LazyFile, LazyFileCollection, dict, list, or string. + The child agent can access this via its `context` variable. + model: Optional model override. + recursive: If True and depth allows, use recursive RLM execution. + If False, always use simple LLM call. 
+ + Returns: + The LLM's response text. + """ + target_model = model or self._sub_model + + # Check if we can do recursive execution + can_recurse = recursive and (self._current_depth < self._max_depth) + + if can_recurse: + # Create a nested RLM execution + return self._run_recursive_rlm( + prompt, target_model, context_obj=context + ) + else: + # Simple LLM call (no code execution) + return self._simple_llm_call(prompt, target_model) + + return llm_query + + def _simple_llm_call( + self, + prompt: str, + model: str, + batch_index: int | None = None, + batch_size: int | None = None, + ) -> str: + """Make a simple LLM call without code execution capability. + + Emits SUB_LLM_START and SUB_LLM_END events for UI visibility and logs + the call to the JSONL logger. + + Args: + prompt: The prompt to send to the LLM. + model: The model to use. + batch_index: Position within a batch (0-indexed), if part of a batch. + batch_size: Total number of items in the batch, if part of a batch. + + Returns: + The LLM's response text, or an error message if the call failed. + """ + # Emit start event + self._emit_sub_llm_event( + RLMEventType.SUB_LLM_START, + model=model, + prompt=prompt, + batch_index=batch_index, + batch_size=batch_size, + ) + + start_time = time.perf_counter() + error_msg = None + response_text = None + + try: + # Create a fresh client for each simple LLM call to avoid + # "Event loop is closed" errors when called from thread pool. + # A shared client may hold references to an event loop that + # is no longer valid in this thread context. + client = genai.Client(vertexai=True, location="global") + # Disable function calling to prevent MALFORMED_FUNCTION_CALL errors + config = types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig(mode="NONE") + ) + ) + with llm_rate_limit(): + response = client.models.generate_content( + model=model, + contents=prompt, + config=config, + ) + self._usage_tracker.add_from_response(model, response.usage_metadata) + + # Handle None/empty responses with detailed logging + if response.text is None or response.text == "": + finish_reason = None + block_reason = None + if response.candidates: + finish_reason = getattr(response.candidates[0], "finish_reason", None) + if hasattr(response, "prompt_feedback"): + block_reason = getattr(response.prompt_feedback, "block_reason", None) + + logger.warning( + "Simple LLM call returned empty response: model=%s, " + "finish_reason=%s, block_reason=%s, prompt_preview=%s", + model, + finish_reason, + block_reason, + prompt[:100] if prompt else None, + ) + + reason_parts = [] + if finish_reason: + reason_parts.append(f"finish_reason={finish_reason}") + if block_reason: + reason_parts.append(f"block_reason={block_reason}") + reason_str = ( + ", ".join(reason_parts) if reason_parts else "unknown reason" + ) + response_text = f"[LLM returned empty response: {reason_str}]" + else: + response_text = response.text + except Exception as e: + error_msg = str(e) + response_text = f"Error: LLM query failed - {e}" + + execution_time_ms = (time.perf_counter() - start_time) * 1000 + + # Emit end event + self._emit_sub_llm_event( + RLMEventType.SUB_LLM_END, + model=model, + response=response_text if not error_msg else None, + error=error_msg, + execution_time_ms=execution_time_ms, + batch_index=batch_index, + batch_size=batch_size, + ) + + # Log to JSONL + self._log_simple_llm_call( + prompt=prompt, + response=response_text, + model=model, + execution_time_ms=execution_time_ms, + 
batch_index=batch_index, + batch_size=batch_size, + error=error_msg, + ) + + return response_text + + def _get_current_ancestry_entry(self) -> dict: + """Get the current agent's context for ancestry chain.""" + return { + "agent": self._parent_agent, + "depth": self._current_depth, + "iteration": self._current_iteration, + "block_index": self._current_block_index, + } + + def _emit_sub_llm_event( + self, + event_type: RLMEventType, + model: str, + prompt: str | None = None, + response: str | None = None, + error: str | None = None, + execution_time_ms: float | None = None, + batch_index: int | None = None, + batch_size: int | None = None, + ) -> None: + """Emit a sub-LLM event for simple (non-recursive) LLM calls. + + Args: + event_type: The type of event (SUB_LLM_START or SUB_LLM_END). + model: The model being used. + prompt: The prompt (for START events). + response: The response (for END events). + error: Error message if the call failed. + execution_time_ms: Execution time in milliseconds (for END events). + batch_index: Position within a batch (0-indexed). + batch_size: Total number of items in the batch. + """ + event_data = RLMEventData( + event_type=event_type, + model=model, + prompt_preview=prompt[:200] if prompt else None, + response_preview=response[:500] if response else None, + response_full=response, + error=error, + execution_time_ms=execution_time_ms, + iteration=self._current_iteration, + block_index=self._current_block_index, + batch_index=batch_index, + batch_size=batch_size, + metadata={"recursive": False}, + ) + + metadata = event_data.to_dict() + metadata["agent_name"] = self._parent_agent + metadata["agent_depth"] = self._current_depth + metadata["ancestry"] = self._ancestry + [self._get_current_ancestry_entry()] + + event = Event( + invocation_id=str(uuid.uuid4()), + author=self._parent_agent or "code_executor", + custom_metadata=metadata, + ) + + self._event_queue.put(event) + + def _log_simple_llm_call( + self, + prompt: str, + response: str, + model: str, + execution_time_ms: float, + batch_index: int | None = None, + batch_size: int | None = None, + error: str | None = None, + ) -> None: + """Log a simple (non-recursive) LLM call to the JSONL logger. + + Args: + prompt: The prompt sent to the LLM. + response: The response received (or error message if failed). + model: The model used. + execution_time_ms: Execution time in milliseconds. + batch_index: Position within a batch (0-indexed). + batch_size: Total number of items in the batch. + error: Error message if the call failed. + """ + if self._logger is None: + return + + self._logger.log_simple_llm_call( + prompt=prompt, + response=response, + model=model, + execution_time_ms=execution_time_ms, + depth=self._current_depth, + agent_name=self._parent_agent, + parent_iteration=self._current_iteration, + parent_block_index=self._current_block_index, + batch_index=batch_index, + batch_size=batch_size, + error=error, + ) + + def _run_recursive_rlm( + self, + prompt: str, + model: str, + context_obj: Any = None, + parallel_batch_id: str | None = None, + batch_index: int | None = None, + batch_size: int | None = None, + ) -> str: + """Run a nested RLM execution at depth + 1 with real-time event streaming. + + Args: + prompt: The prompt to send to the child agent. + model: The model to use for the child agent. + context_obj: Optional context object to pass to the child agent. + This becomes the child's `context` variable directly. + parallel_batch_id: Optional UUID identifying a parallel batch. 
+ batch_index: Optional position within the batch (0-indexed). + batch_size: Optional total number of items in the batch. + """ + next_depth = self._current_depth + 1 + + # Generate unique child agent name using counter + child_index = self._child_agent_counter + self._child_agent_counter += 1 + nested_agent_name = f"rlm_agent_depth_{next_depth}_{child_index}" + + # Build child's ancestry = parent's ancestry + current context + child_ancestry = self._ancestry + [self._get_current_ancestry_entry()] + + # Reference to the event queue for the nested function + event_queue = self._event_queue + + # Capture context_obj for the nested async function + child_context = context_obj + + async def run_nested_async(): + """Run the nested agent async and stream events to queue.""" + import uuid + + from adk_rlm.agents.rlm_agent import RLMAgent + from google.adk.agents.invocation_context import InvocationContext + from google.adk.sessions import InMemorySessionService + from google.adk.sessions import Session + + # Create a nested RLM agent at the next depth level + nested_agent = RLMAgent( + name=nested_agent_name, + model=model, + sub_model=self._sub_model, + max_iterations=self._max_iterations, + max_depth=self._max_depth, + current_depth=next_depth, + logger=self._logger, + parent_agent=self._parent_agent, + ancestry=child_ancestry, # Pass ancestry to child + verbose=False, + ) + + # Create mock session with context + # If context_obj is provided, it becomes the child's `context` variable directly + # Otherwise, fall back to {"query": prompt} for backwards compatibility + rlm_context = ( + child_context if child_context is not None else {"query": prompt} + ) + + mock_session = Session( + id=str(uuid.uuid4()), + app_name="adk_rlm", + user_id="default_user", + state={ + "rlm_context": rlm_context, + "rlm_prompt": prompt, + }, + ) + mock_session_service = InMemorySessionService() + mock_ctx = InvocationContext( + invocation_id=str(uuid.uuid4()), + session=mock_session, + session_service=mock_session_service, + agent=nested_agent, + ) + + final_answer = None + + try: + async for event in nested_agent._run_async_impl(mock_ctx): + # Only add ancestry if not already present (preserve nested info) + if event.custom_metadata and "ancestry" not in event.custom_metadata: + event.custom_metadata["ancestry"] = child_ancestry + event.custom_metadata["agent_name"] = nested_agent_name + event.custom_metadata["agent_depth"] = next_depth + # Add parent info for backwards compatibility + event.custom_metadata["parent_agent"] = self._parent_agent + event.custom_metadata["parent_iteration"] = self._current_iteration + event.custom_metadata["parent_block_index"] = ( + self._current_block_index + ) + # Add batch metadata if this is part of a parallel batch + if parallel_batch_id is not None: + event.custom_metadata["parallel_batch_id"] = parallel_batch_id + event.custom_metadata["batch_index"] = batch_index + event.custom_metadata["batch_size"] = batch_size + + # Push to queue immediately for real-time streaming + event_queue.put(event) + + # Check for final answer + if event.custom_metadata: + from adk_rlm.events import RLMEventType + + event_type = event.custom_metadata.get("event_type") + if event_type == RLMEventType.FINAL_ANSWER.value: + final_answer = event.custom_metadata.get("answer") + + # Merge usage + self._usage_tracker.merge(nested_agent._usage_tracker) + finally: + # Properly close the nested agent's genai client before event loop closes + # This prevents "Event loop is closed" errors during cleanup + if 
nested_agent._client is not None: + try: + await nested_agent._client.aio.aclose() + except Exception: + pass # Ignore cleanup errors + + return final_answer + + try: + # Run async in a thread pool to avoid event loop conflicts + try: + asyncio.get_running_loop() + # Already in an event loop, use thread pool + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, run_nested_async()) + final_answer = future.result() + except RuntimeError: + # No running loop, safe to use asyncio.run directly + final_answer = asyncio.run(run_nested_async()) + + if final_answer is None: + return "[Recursive RLM returned no result]" + return final_answer + + except Exception as e: + # Fall back to simple call on error + return ( + f"[Recursive RLM at depth {next_depth} failed: {e}]\n" + + self._simple_llm_call(prompt, model) + ) + + def _create_llm_query_batched_fn(self): + """Create the llm_query_batched function for the REPL environment. + + When recursive=True, runs child agents in parallel using ThreadPoolExecutor. + When recursive=False, uses async gather for simple parallel LLM calls. + """ + + def llm_query_batched( + prompts: list[str], + contexts: list[Any] | None = None, + model: str | None = None, + recursive: bool = False, + ) -> list[str]: + """ + Query an LLM with multiple prompts concurrently. + + Args: + prompts: List of prompts to send. + contexts: Optional list of context objects (same length as prompts). + If provided, each prompt gets paired with its context. + model: Optional model override. + recursive: If True, use recursive RLM execution for each prompt. + Default is False for performance (simple LLM calls). + + Returns: + List of LLM response texts in the same order as prompts. + """ + if contexts is not None and len(contexts) != len(prompts): + raise ValueError( + f"contexts length ({len(contexts)}) must match prompts length" + f" ({len(prompts)})" + ) + + target_model = model or self._sub_model + + if recursive and self._current_depth < self._max_depth: + # Parallel recursive execution using ThreadPoolExecutor + return self._run_parallel_recursive(prompts, contexts, target_model) + + # Simple async batched calls (no recursion) - emit events for each query + batch_size = len(prompts) + + # Capture references needed in the async functions + usage_tracker = self._usage_tracker + emit_event = self._emit_sub_llm_event + log_call = self._log_simple_llm_call + + async def query_single( + client: genai.Client, prompt: str, batch_index: int + ) -> str: + # Emit start event + emit_event( + RLMEventType.SUB_LLM_START, + model=target_model, + prompt=prompt, + batch_index=batch_index, + batch_size=batch_size, + ) + + start_time = time.perf_counter() + error_msg = None + response_text = None + + try: + # Disable function calling to prevent MALFORMED_FUNCTION_CALL errors + config = types.GenerateContentConfig( + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode="NONE" + ) + ) + ) + async with AsyncLLMRateLimiter(): + response = await client.aio.models.generate_content( + model=target_model, + contents=prompt, + config=config, + ) + usage_tracker.add_from_response(target_model, response.usage_metadata) + + # Handle None/empty responses with detailed logging + if response.text is None or response.text == "": + finish_reason = None + block_reason = None + if response.candidates: + finish_reason = getattr( + response.candidates[0], "finish_reason", None + ) + if hasattr(response, "prompt_feedback"): + block_reason = getattr( 
+ response.prompt_feedback, "block_reason", None + ) + + logger.warning( + "Batched LLM call returned empty response: model=%s, " + "batch_index=%s/%s, finish_reason=%s, block_reason=%s", + target_model, + batch_index, + batch_size, + finish_reason, + block_reason, + ) + + reason_parts = [] + if finish_reason: + reason_parts.append(f"finish_reason={finish_reason}") + if block_reason: + reason_parts.append(f"block_reason={block_reason}") + reason_str = ( + ", ".join(reason_parts) if reason_parts else "unknown reason" + ) + response_text = f"[LLM returned empty response: {reason_str}]" + else: + response_text = response.text + except Exception as e: + error_msg = str(e) + response_text = f"Error: LLM query failed - {e}" + + execution_time_ms = (time.perf_counter() - start_time) * 1000 + + # Emit end event + emit_event( + RLMEventType.SUB_LLM_END, + model=target_model, + response=response_text if not error_msg else None, + error=error_msg, + execution_time_ms=execution_time_ms, + batch_index=batch_index, + batch_size=batch_size, + ) + + # Log to JSONL + log_call( + prompt=prompt, + response=response_text, + model=target_model, + execution_time_ms=execution_time_ms, + batch_index=batch_index, + batch_size=batch_size, + error=error_msg, + ) + + return response_text + + async def run_all(): + # Create a fresh client in this event loop to avoid + # "Event loop is closed" errors from cross-thread usage + client = genai.Client(vertexai=True, location="global") + tasks = [query_single(client, p, i) for i, p in enumerate(prompts)] + return await asyncio.gather(*tasks) + + # Run in a new event loop if we're not already in one + try: + asyncio.get_running_loop() + # If we're in a running loop, create a new one in a thread + with concurrent.futures.ThreadPoolExecutor() as pool: + future = pool.submit(asyncio.run, run_all()) + return future.result() + except RuntimeError: + # No running loop, safe to use asyncio.run + return asyncio.run(run_all()) + + return llm_query_batched + + def _run_parallel_recursive( + self, + prompts: list[str], + contexts: list[Any] | None, + model: str, + ) -> list[str]: + """Run multiple recursive RLM calls in parallel. + + This spawns child agents concurrently using ThreadPoolExecutor. + Rate limiting is handled by the global LLM semaphore. + + Args: + prompts: List of prompts to send. + contexts: Optional list of context objects (same length as prompts). + model: The model to use for child agents. + + Returns: + List of results in the same order as prompts. 
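+
+    REPL-side sketch that routes here via llm_query_batched(recursive=True)
+    (the prompt strings and the `section_a`/`section_b` variables are
+    placeholders):
+
+        answers = llm_query_batched(
+            ["Summarize section A", "Summarize section B"],
+            contexts=[section_a, section_b],
+            recursive=True,
+        )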
+ """ + contexts = contexts or [None] * len(prompts) + batch_id = str(uuid.uuid4()) + batch_size = len(prompts) + + def run_one(idx: int) -> tuple[int, str]: + """Run a single recursive RLM call and return (index, result).""" + prompt = prompts[idx] + context = contexts[idx] + try: + result = self._run_recursive_rlm( + prompt, + model, + context_obj=context, + parallel_batch_id=batch_id, + batch_index=idx, + batch_size=batch_size, + ) + return (idx, result) + except Exception as e: + return (idx, f"[Error in batch item {idx}: {e}]") + + results = [None] * len(prompts) + + with concurrent.futures.ThreadPoolExecutor() as pool: + futures = [pool.submit(run_one, i) for i in range(len(prompts))] + + for future in concurrent.futures.as_completed(futures): + try: + idx, result = future.result() + results[idx] = result + except Exception: + # This shouldn't happen since run_one catches exceptions, + # but handle it just in case + pass + + # Replace any None results with error messages + for i, result in enumerate(results): + if result is None: + results[i] = f"[Error: batch item {i} returned no result]" + + return results + + def _ensure_repl(self) -> LocalREPL: + """Ensure the REPL is initialized.""" + if self._repl is None: + self._repl = LocalREPL( + llm_query_fn=self._create_llm_query_fn(), + llm_query_batched_fn=self._create_llm_query_batched_fn(), + ) + return self._repl + + def execute_code( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """ + Execute code in the RLM REPL environment. + + Args: + invocation_context: The ADK invocation context. + code_execution_input: The code to execute. + + Returns: + CodeExecutionResult with stdout/stderr. + """ + repl = self._ensure_repl() + + # Execute code + result = repl.execute_code(code_execution_input.code) + + # Check for FINAL answer in namespace + if "FINAL_ANSWER" in repl.locals: + self._final_answer = str(repl.locals["FINAL_ANSWER"]) + + return CodeExecutionResult( + stdout=result.stdout, + stderr=result.stderr, + output_files=[], + ) + + def reset_event_state(self) -> None: + """Reset the event queue and completion flag. + + This should be called BEFORE starting execute_code_async to avoid + race conditions between the execution task and event polling. + + Note: We intentionally do NOT reset _child_agent_counter here. + Keeping it monotonically increasing ensures unique agent names + across all iterations and code blocks within a run. + """ + self._event_queue = Queue() + self._execution_complete.clear() + + async def execute_code_async( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """ + Execute code in the RLM REPL environment asynchronously. + + This runs the code execution in a thread pool to avoid blocking + the event loop, which is important when the code calls llm_query() + with recursive=True and spawns child agents. + + Note: Call reset_event_state() BEFORE creating the task to avoid + race conditions with poll_child_events(). + + Args: + invocation_context: The ADK invocation context. + code_execution_input: The code to execute. + + Returns: + CodeExecutionResult with stdout/stderr. 
+ """ + # Run the synchronous execute_code in a thread pool + # This allows the event loop to continue processing (e.g., sending websocket events) + # while the code execution (which may spawn child agents) runs + result = await asyncio.to_thread( + self._execute_code_with_completion, + invocation_context, + code_execution_input, + ) + return result + + def _execute_code_with_completion( + self, + invocation_context: InvocationContext, + code_execution_input: CodeExecutionInput, + ) -> CodeExecutionResult: + """Execute code and signal completion when done.""" + try: + return self.execute_code(invocation_context, code_execution_input) + finally: + self._execution_complete.set() + + async def poll_child_events(self) -> AsyncGenerator[Event, None]: + """ + Poll for child agent events during code execution. + + This async generator yields events as they arrive from child agents + running in the thread pool. It should be called in a loop while + code execution is running. + + Yields: + Event objects from child agents as they arrive. + """ + while ( + not self._execution_complete.is_set() or not self._event_queue.empty() + ): + try: + event = self._event_queue.get_nowait() + yield event + except Empty: + # Small sleep to avoid busy-wait + await asyncio.sleep(0.01) + + def load_context(self, context_payload: dict | list | str) -> None: + """ + Load context into the REPL environment. + + Args: + context_payload: The context data to load. + """ + repl = self._ensure_repl() + repl.load_context(context_payload) + + def add_context(self, context_payload: dict | list | str) -> int: + """ + Add additional context to the REPL environment. + + Args: + context_payload: The context data to add. + + Returns: + The context index. + """ + repl = self._ensure_repl() + return repl.add_context(context_payload) + + def get_context_count(self) -> int: + """Return the number of contexts loaded.""" + if self._repl is None: + return 0 + return self._repl.get_context_count() + + def get_history_count(self) -> int: + """Return the number of conversation histories stored.""" + if self._repl is None: + return 0 + return self._repl.get_history_count() + + def add_history(self, message_history: list[dict[str, Any]]) -> int: + """ + Store a conversation's message history. + + Args: + message_history: The list of message dicts. + + Returns: + The history index. + """ + repl = self._ensure_repl() + return repl.add_history(message_history) + + @property + def final_answer(self) -> str | None: + """Return the final answer if detected via FINAL_ANSWER variable.""" + return self._final_answer + + def reset_final_answer(self) -> None: + """Reset the final answer state.""" + self._final_answer = None + + @property + def locals(self) -> dict[str, Any]: + """Return the REPL locals for variable inspection.""" + if self._repl is None: + return {} + return self._repl.locals + + @property + def usage_tracker(self) -> UsageTracker: + """Return the usage tracker.""" + return self._usage_tracker + + def set_iteration_context(self, iteration: int, block_index: int) -> None: + """Set the current iteration context for child event tagging. + + Args: + iteration: The current parent iteration number (1-indexed). + block_index: The current code block index within the iteration. + """ + self._current_iteration = iteration + self._current_block_index = block_index + + def pop_child_events(self) -> list: + """Get and clear any remaining child agent events from the queue. + + This is provided for backwards compatibility. 
With the new streaming + architecture, events are yielded in real-time via poll_child_events(). + + Returns: + List of remaining events from the queue, cleared after retrieval. + """ + events = [] + while not self._event_queue.empty(): + try: + events.append(self._event_queue.get_nowait()) + except Empty: + break + return events + + def cleanup(self) -> None: + """Clean up the REPL environment.""" + if self._repl: + self._repl.cleanup() + self._repl = None + self._final_answer = None + # Clear the event queue + while not self._event_queue.empty(): + try: + self._event_queue.get_nowait() + except Empty: + break + self._execution_complete.clear() diff --git a/contributing/samples/rlm/adk_rlm/events.py b/contributing/samples/rlm/adk_rlm/events.py new file mode 100644 index 0000000000..4ced4454f4 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/events.py @@ -0,0 +1,92 @@ +""" +RLM Event types for streaming execution updates. + +These events provide granular visibility into RLM execution, enabling +various interfaces (CLI, Web UI, API) to show real-time progress. +""" + +from dataclasses import dataclass +from dataclasses import field +from enum import Enum +from typing import Any + + +class RLMEventType(str, Enum): + """Event types emitted during RLM execution.""" + + # Lifecycle events + RUN_START = "rlm.run.start" # Agent starting execution + RUN_END = "rlm.run.end" # Agent finished (with final answer) + RUN_ERROR = "rlm.run.error" # Agent encountered error + + # Iteration events + ITERATION_START = "rlm.iteration.start" # Starting iteration N + ITERATION_END = "rlm.iteration.end" # Completed iteration N + + # LLM events + LLM_CALL_START = "rlm.llm.start" # Calling main LLM + LLM_CALL_END = "rlm.llm.end" # Main LLM response received + LLM_RESPONSE = "rlm.llm.response" # Streaming LLM response chunk + + # Code execution events + CODE_FOUND = "rlm.code.found" # Found code block in response + CODE_EXEC_START = "rlm.code.start" # Starting code execution + CODE_EXEC_END = "rlm.code.end" # Code execution completed + CODE_OUTPUT = "rlm.code.output" # Code produced output + + # Sub-LLM events (from llm_query calls) + SUB_LLM_START = "rlm.sub_llm.start" # Sub-LLM query started + SUB_LLM_END = "rlm.sub_llm.end" # Sub-LLM query completed + SUB_LLM_BATCH = "rlm.sub_llm.batch" # Batched sub-LLM queries + + # Final answer + FINAL_DETECTED = "rlm.final.detected" # FINAL() pattern found + FINAL_ANSWER = "rlm.final.answer" # Final answer content + + +@dataclass +class RLMEventData: + """Structured data for RLM events.""" + + event_type: RLMEventType + iteration: int | None = None + code: str | None = None # Truncated code preview for sidebar + code_full: str | None = None # Full code for modal display + output: str | None = None # Truncated output preview for sidebar + output_full: str | None = None # Full output for modal display + error: str | None = None # Truncated error preview for sidebar + error_full: str | None = None # Full error for modal display + model: str | None = None + prompt_preview: str | None = None # First N chars of prompt + response_preview: str | None = None # First N chars of response + response_full: str | None = None # Full LLM response for modal display + answer: str | None = None + token_count: int | None = None + execution_time_ms: float | None = None + total_iterations: int | None = None + success: bool | None = None + fallback: bool | None = None + block_index: int | None = None + has_error: bool | None = None + source: str | None = None # "text" or "variable" for FINAL 
detection + # Parallel batch metadata (for llm_query_batched with recursive=True) + parallel_batch_id: str | None = None # UUID identifying the batch + batch_index: int | None = None # Position within the batch (0-indexed) + batch_size: int | None = None # Total number of items in the batch + metadata: dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary, excluding None values.""" + result = {"event_type": self.event_type.value} + for key, value in self.__dict__.items(): + if key != "event_type" and value is not None: + if key == "metadata" and not value: + continue + result[key] = value + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "RLMEventData": + """Create from dictionary.""" + event_type = RLMEventType(data.pop("event_type")) + return cls(event_type=event_type, **data) diff --git a/contributing/samples/rlm/adk_rlm/files/__init__.py b/contributing/samples/rlm/adk_rlm/files/__init__.py new file mode 100644 index 0000000000..48ab4e026d --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/__init__.py @@ -0,0 +1,65 @@ +""" +File handling module for ADK-RLM. + +This module provides functionality for loading and parsing files from +various sources (local filesystem, cloud storage) and formats (text, PDF). + +Features: +- Progressive disclosure via lazy loading (Level 0/1/2) +- Pluggable file sources (local, SharePoint, GDrive, S3, HTTP) +- Pluggable file parsers (text, PDF, Office documents) +- Glob pattern support for batch file operations + +Example: + ```python + from adk_rlm.files import FileLoader, LocalFileSource + + # Basic usage with local files + loader = FileLoader() + files = loader.create_lazy_files(["./docs/**/*.pdf"]) + + # Level 0 - no I/O + for f in files: + print(f.name, f.extension) + + # Level 1 - metadata only + large_files = [f for f in files if f.size_mb > 10] + + # Level 2 - full content + for f in large_files: + print(f.content[:1000]) + ``` +""" + +from adk_rlm.files.base import FileMetadata +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.base import ParsedContent +from adk_rlm.files.lazy import LazyFile +from adk_rlm.files.lazy import LazyFileCollection +from adk_rlm.files.loader import FileLoader +from adk_rlm.files.loader import FileSpec +from adk_rlm.files.parsers import FileParser +from adk_rlm.files.parsers import PDFParser +from adk_rlm.files.parsers import TextParser +from adk_rlm.files.sources import FileSource +from adk_rlm.files.sources import LocalFileSource + +__all__ = [ + # Base types + "FileMetadata", + "LoadedFile", + "ParsedContent", + # Lazy loading + "LazyFile", + "LazyFileCollection", + # Loader + "FileLoader", + "FileSpec", + # Sources + "FileSource", + "LocalFileSource", + # Parsers + "FileParser", + "TextParser", + "PDFParser", +] diff --git a/contributing/samples/rlm/adk_rlm/files/base.py b/contributing/samples/rlm/adk_rlm/files/base.py new file mode 100644 index 0000000000..7d3374f4ff --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/base.py @@ -0,0 +1,101 @@ +""" +Base types for the file handling module. + +This module defines the core data structures used throughout the file +handling system: FileMetadata, LoadedFile, and ParsedContent. 
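+
+Example (constructing metadata by hand; in practice a FileSource builds it):
+    ```python
+    meta = FileMetadata(
+        name="report.pdf",
+        path="./docs/report.pdf",
+        source_type="local",
+        size_bytes=2_500_000,
+    )
+    meta.extension  # ".pdf"
+    meta.size_mb    # ~2.4
+    ```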
+""" + +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from pathlib import Path +from typing import Any + + +@dataclass +class FileMetadata: + """Metadata about a loaded file.""" + + name: str + path: str # Original path/URI + source_type: str # "local", "sharepoint", "gdrive", etc. + size_bytes: int + mime_type: str | None = None + last_modified: datetime | None = None + extra: dict[str, Any] = field(default_factory=dict) + + @property + def size_kb(self) -> float: + """File size in KB.""" + return self.size_bytes / 1024 + + @property + def size_mb(self) -> float: + """File size in MB.""" + return self.size_bytes / (1024 * 1024) + + @property + def extension(self) -> str: + """File extension (lowercase, with leading dot).""" + if "." in self.name: + return "." + self.name.rsplit(".", 1)[-1].lower() + return "" + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "name": self.name, + "path": self.path, + "source_type": self.source_type, + "size_bytes": self.size_bytes, + "mime_type": self.mime_type, + "last_modified": ( + self.last_modified.isoformat() if self.last_modified else None + ), + "extra": self.extra, + } + + +@dataclass +class LoadedFile: + """A file loaded from a source with raw content.""" + + metadata: FileMetadata + content: bytes + + def as_text(self, encoding: str = "utf-8") -> str: + """Decode content as text.""" + return self.content.decode(encoding) + + +@dataclass +class ParsedContent: + """Parsed content from a file.""" + + text: str # Extracted text content + metadata: dict[str, Any] = field( + default_factory=dict + ) # Parser-specific metadata + chunks: list[str] | None = None # Optional: pre-chunked content (e.g., pages) + tables: list[dict[str, Any]] | None = None # Optional: extracted tables + images: list[bytes] | None = None # Optional: extracted images + + @property + def has_tables(self) -> bool: + """Check if tables were extracted.""" + return self.tables is not None and len(self.tables) > 0 + + @property + def has_chunks(self) -> bool: + """Check if content was pre-chunked.""" + return self.chunks is not None and len(self.chunks) > 0 + + @property + def chunk_count(self) -> int: + """Number of chunks (0 if not chunked).""" + return len(self.chunks) if self.chunks else 0 + + @property + def table_count(self) -> int: + """Number of tables extracted (0 if none).""" + return len(self.tables) if self.tables else 0 diff --git a/contributing/samples/rlm/adk_rlm/files/lazy.py b/contributing/samples/rlm/adk_rlm/files/lazy.py new file mode 100644 index 0000000000..b4705afb2c --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/lazy.py @@ -0,0 +1,464 @@ +""" +Lazy file loading with progressive disclosure. 
+ +This module provides LazyFile and LazyFileCollection classes that support +on-demand loading of file content at three levels: + +- Level 0 (free): name, path, extension - from initial listing +- Level 1 (cheap): size, modified_date, mime_type - metadata request +- Level 2 (expensive): content, tables, chunks - full download + parse +""" + +from __future__ import annotations + +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +import fnmatch +from typing import Any +from typing import Iterator +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from adk_rlm.files.base import FileMetadata + from adk_rlm.files.base import LoadedFile + from adk_rlm.files.base import ParsedContent + from adk_rlm.files.parsers.base import FileParser + from adk_rlm.files.sources.base import FileSource + + +@dataclass +class LazyFile: + """ + A lazy file reference that loads content on first access. + + Progressive disclosure file access: + - Level 0 (free): name, path, extension - from initial listing + - Level 1 (cheap): size, modified_date, mime_type - metadata request + - Level 2 (expensive): content, tables, chunks - full download + parse + + Metadata (name, path) is available immediately. + Content and parsed data load on-demand. + + Example: + ```python + # Level 0 - instant, no I/O + print(file.name) # "report.pdf" + print(file.extension) # ".pdf" + + # Level 1 - stat/HEAD request + print(file.size_kb) # 1024.5 + print(file.modified_date) + + # Level 2 - full download + parse + print(file.content[:100]) # First 100 chars + print(file.tables) # Extracted tables + ``` + """ + + path: str + source: "FileSource" + parser: "FileParser | None" = None + + # Cached data (loaded on demand) + _metadata: "FileMetadata | None" = field(default=None, repr=False) + _loaded: "LoadedFile | None" = field(default=None, repr=False) + _parsed: "ParsedContent | None" = field(default=None, repr=False) + + # ========================================================================= + # Level 0: Always available (from path) - No I/O required + # ========================================================================= + + @property + def name(self) -> str: + """Filename - available without loading.""" + return self.path.split("/")[-1].split("\\")[-1] + + @property + def extension(self) -> str: + """File extension (lowercase, with leading dot) - available without loading.""" + if "." in self.name: + return "." + self.name.rsplit(".", 1)[-1].lower() + return "" + + @property + def is_loaded(self) -> bool: + """Check if content has been loaded.""" + return self._loaded is not None + + @property + def is_parsed(self) -> bool: + """Check if content has been parsed.""" + return self._parsed is not None + + @property + def level(self) -> int: + """ + Current loading level. 
+ + - 0: path only (no I/O) + - 1: metadata loaded + - 2: content parsed + """ + if self._parsed is not None: + return 2 + if self._metadata is not None: + return 1 + return 0 + + # ========================================================================= + # Level 1: Metadata (lazy, cached) - Cheap I/O (stat/HEAD) + # ========================================================================= + + def _ensure_metadata(self) -> None: + """Load just metadata (not full content).""" + if self._metadata is None: + # If we already have the loaded file, use its metadata + if self._loaded is not None: + self._metadata = self._loaded.metadata + else: + # Sources can implement get_metadata() for efficiency + self._metadata = self.source.get_metadata(self.path) + + @property + def metadata(self) -> "FileMetadata": + """Full file metadata - Level 1 (triggers metadata load).""" + self._ensure_metadata() + assert self._metadata is not None + return self._metadata + + @property + def size(self) -> int: + """File size in bytes - Level 1.""" + return self.metadata.size_bytes + + @property + def size_kb(self) -> float: + """File size in KB - Level 1.""" + return self.size / 1024 + + @property + def size_mb(self) -> float: + """File size in MB - Level 1.""" + return self.size / (1024 * 1024) + + @property + def modified_date(self) -> datetime | None: + """Last modified date - Level 1.""" + return self.metadata.last_modified + + @property + def mime_type(self) -> str | None: + """MIME type - Level 1.""" + return self.metadata.mime_type + + # ========================================================================= + # Level 2: Full content (lazy, cached) - Expensive I/O (full download) + # ========================================================================= + + def _ensure_loaded(self) -> None: + """Load file if not already loaded.""" + if self._loaded is None: + self._loaded = self.source.load(self.path) + # Also set metadata from loaded file + self._metadata = self._loaded.metadata + + def _ensure_parsed(self) -> None: + """Parse file if not already parsed.""" + self._ensure_loaded() + if self._parsed is None: + if self.parser is None: + raise ValueError(f"No parser configured for {self.name}") + assert self._loaded is not None + self._parsed = self.parser.parse(self._loaded) + + @property + def raw_content(self) -> bytes: + """Raw file bytes - Level 2 (triggers download).""" + self._ensure_loaded() + assert self._loaded is not None + return self._loaded.content + + @property + def content(self) -> str: + """Parsed text content - Level 2 (triggers download + parse).""" + self._ensure_parsed() + assert self._parsed is not None + return self._parsed.text + + @property + def tables(self) -> list[dict[str, Any]] | None: + """Extracted tables - Level 2 (triggers download + parse).""" + self._ensure_parsed() + assert self._parsed is not None + return self._parsed.tables + + @property + def chunks(self) -> list[str] | None: + """Pre-chunked content (e.g., pages) - Level 2 (triggers download + parse).""" + self._ensure_parsed() + assert self._parsed is not None + return self._parsed.chunks + + @property + def parsed_metadata(self) -> dict[str, Any]: + """Parser-specific metadata - Level 2 (triggers download + parse).""" + self._ensure_parsed() + assert self._parsed is not None + return self._parsed.metadata + + # ========================================================================= + # Utility methods + # ========================================================================= + + def read(self, encoding: 
str = "utf-8") -> str: + """ + Read raw content as text - Level 2 (triggers download only, no parse). + + This is useful when you want raw text without parsing overhead. + + Args: + encoding: Text encoding to use + + Returns: + File content as string + """ + self._ensure_loaded() + assert self._loaded is not None + return self._loaded.content.decode(encoding) + + def preload_metadata(self) -> "LazyFile": + """ + Eagerly load metadata (Level 1). + + Useful for batch metadata loading. + + Returns: + self (for chaining) + """ + self._ensure_metadata() + return self + + def preload(self) -> "LazyFile": + """ + Eagerly load and parse content (Level 2). + + Useful when you know you'll need the content. + + Returns: + self (for chaining) + """ + self._ensure_parsed() + return self + + def __str__(self) -> str: + level_names = {0: "path", 1: "metadata", 2: "content"} + return ( + f"" + ) + + def __repr__(self) -> str: + return self.__str__() + + +@dataclass +class LazyFileCollection: + """ + A collection of lazy files with helpful access patterns. + + Provides filtering and batch operations without loading files. + + Example: + ```python + files = LazyFileCollection([...]) + + # No loading required + print(files.names) # All filenames + pdfs = files.by_extension(".pdf") # Filter by extension + + # Selective loading + for pdf in pdfs[:3]: # Only load first 3 + print(pdf.content) + ``` + """ + + files: list[LazyFile] = field(default_factory=list) + + def __len__(self) -> int: + return len(self.files) + + def __iter__(self) -> Iterator[LazyFile]: + return iter(self.files) + + def __getitem__(self, idx: int | slice) -> LazyFile | list[LazyFile]: + result = self.files[idx] + if isinstance(idx, slice): + return result + return result + + def __bool__(self) -> bool: + return len(self.files) > 0 + + # ========================================================================= + # Level 0 operations (no I/O) + # ========================================================================= + + @property + def names(self) -> list[str]: + """List all filenames (no loading required).""" + return [f.name for f in self.files] + + @property + def paths(self) -> list[str]: + """List all file paths (no loading required).""" + return [f.path for f in self.files] + + @property + def extensions(self) -> set[str]: + """Set of all file extensions (no loading required).""" + return {f.extension for f in self.files} + + def by_extension(self, ext: str) -> "LazyFileCollection": + """ + Filter files by extension (no loading required). + + Args: + ext: Extension to filter by (with or without leading dot) + + Returns: + New LazyFileCollection with matching files + """ + if not ext.startswith("."): + ext = "." + ext + ext = ext.lower() + return LazyFileCollection([f for f in self.files if f.extension == ext]) + + def by_name(self, pattern: str) -> "LazyFileCollection": + """ + Filter files by name pattern (no loading required). + + Uses fnmatch for glob-style matching. + + Args: + pattern: Glob pattern (e.g., "report*.pdf", "*2024*") + + Returns: + New LazyFileCollection with matching files + """ + return LazyFileCollection( + [f for f in self.files if fnmatch.fnmatch(f.name, pattern)] + ) + + def search(self, keyword: str) -> "LazyFileCollection": + """ + Search for files with keyword in name (case-insensitive). 
+
+    Args:
+      keyword: Keyword to search for in filename
+
+    Returns:
+      New LazyFileCollection with matching files
+    """
+    keyword_lower = keyword.lower()
+    return LazyFileCollection(
+        [f for f in self.files if keyword_lower in f.name.lower()]
+    )
+
+  # =========================================================================
+  # Status tracking
+  # =========================================================================
+
+  @property
+  def loaded_count(self) -> int:
+    """Count of files whose raw content has been downloaded."""
+    return sum(1 for f in self.files if f.is_loaded)
+
+  @property
+  def parsed_count(self) -> int:
+    """Count of files that have been parsed (Level 2)."""
+    return sum(1 for f in self.files if f.is_parsed)
+
+  @property
+  def metadata_count(self) -> int:
+    """Count of files with metadata loaded (Level 1+)."""
+    return sum(1 for f in self.files if f.level >= 1)
+
+  # =========================================================================
+  # Batch loading operations
+  # =========================================================================
+
+  def load_all_metadata(self) -> "LazyFileCollection":
+    """
+    Eagerly load metadata for all files (Level 1).
+
+    Returns:
+      self (for chaining)
+    """
+    for f in self.files:
+      f.preload_metadata()
+    return self
+
+  def load_all(self) -> "LazyFileCollection":
+    """
+    Eagerly load and parse all files (Level 2).
+
+    Warning: This loads all files into memory. Use with caution.
+
+    Returns:
+      self (for chaining)
+    """
+    for f in self.files:
+      f.preload()
+    return self
+
+  def get_all_content(self) -> list[str]:
+    """
+    Get parsed content from all files.
+
+    Triggers loading for any unloaded files.
+
+    Returns:
+      List of text content from all files
+    """
+    return [f.content for f in self.files]
+
+  # =========================================================================
+  # Statistics (may require Level 1)
+  # =========================================================================
+
+  @property
+  def total_size(self) -> int:
+    """Total size in bytes of all files (triggers metadata load)."""
+    return sum(f.size for f in self.files)
+
+  @property
+  def total_size_mb(self) -> float:
+    """Total size in MB of all files (triggers metadata load)."""
+    return self.total_size / (1024 * 1024)
+
+  def summary(self) -> str:
+    """
+    Get a summary of the collection.
+
+    Returns the summary without triggering any loading.
+    """
+    ext_counts: dict[str, int] = {}
+    for f in self.files:
+      ext = f.extension or "(no ext)"
+      ext_counts[ext] = ext_counts.get(ext, 0) + 1
+
+    lines = [f"LazyFileCollection with {len(self.files)} files:"]
+    for ext, count in sorted(ext_counts.items()):
+      lines.append(f"  {ext}: {count}")
+    lines.append(f"  Loaded: {self.loaded_count}/{len(self.files)}")
+    return "\n".join(lines)
+
+  def __str__(self) -> str:
+    return (
+        f"<LazyFileCollection files={len(self.files)}"
+        f" loaded={self.loaded_count} parsed={self.parsed_count}>"
+    )
+
+  def __repr__(self) -> str:
+    return self.__str__()
diff --git a/contributing/samples/rlm/adk_rlm/files/loader.py b/contributing/samples/rlm/adk_rlm/files/loader.py
new file mode 100644
index 0000000000..73412aaf45
--- /dev/null
+++ b/contributing/samples/rlm/adk_rlm/files/loader.py
@@ -0,0 +1,365 @@
+"""
+FileLoader orchestrator for ADK-RLM.
+
+Coordinates file loading from various sources and parsing into
+content usable by the RLM system.
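+
+A rough sketch of typical use (the paths are placeholders):
+
+```python
+loader = FileLoader()
+context = loader.build_context(["docs/report.pdf", "notes/*.md"], lazy=True)
+files = context["files"]  # a LazyFileCollection when lazy=True
+```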
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.base import ParsedContent +from adk_rlm.files.lazy import LazyFile +from adk_rlm.files.lazy import LazyFileCollection +from adk_rlm.files.parsers.base import FileParser +from adk_rlm.files.parsers.pdf import PDFParser +from adk_rlm.files.parsers.text import TextParser +from adk_rlm.files.sources.base import FileSource +from adk_rlm.files.sources.local import LocalFileSource + + +@dataclass +class FileSpec: + """ + Specification for a file to load. + + Allows explicit control over source selection. + + Example: + ```python + # Auto-detect source + spec = FileSpec(path="report.pdf") + + # Explicit source + spec = FileSpec(path="doc.pdf", source=my_source) + ``` + """ + + path: str + source: FileSource | None = None # None = auto-detect + + +class FileLoader: + """ + Orchestrates file loading and parsing. + + Handles: + - Auto-detecting file sources from paths/URIs + - Resolving glob patterns + - Loading files from various sources + - Parsing files into text content + - Creating lazy file collections for efficient access + + Example: + ```python + loader = FileLoader() + + # Eager loading - parse immediately + contents = loader.load_files(["report.pdf", "data.csv"]) + + # Lazy loading - parse on demand + files = loader.create_lazy_files(["report.pdf", "*.md"]) + for f in files: + print(f.name) # No I/O + print(f.content) # Triggers load + parse + ``` + """ + + def __init__( + self, + sources: dict[str, FileSource] | None = None, + parsers: list[FileParser] | None = None, + base_path: str | Path | None = None, + ): + """ + Initialize FileLoader. + + Args: + sources: Dictionary of named file sources. + Default includes "local" source. + parsers: List of file parsers. + Default includes TextParser and PDFParser. + base_path: Base path for local file source. + """ + # Default sources + self.sources: dict[str, FileSource] = sources or { + "local": LocalFileSource(base_path), + } + + # Default parsers (order matters - first matching parser wins) + self.parsers: list[FileParser] = parsers or [ + TextParser(), + PDFParser(), + ] + + def register_source(self, name: str, source: FileSource) -> None: + """ + Register a file source. + + Args: + name: Name to register source under + source: FileSource implementation + """ + self.sources[name] = source + + def register_parser(self, parser: FileParser) -> None: + """ + Register a file parser. + + New parsers are added to the end of the list. + + Args: + parser: FileParser implementation + """ + self.parsers.append(parser) + + def _detect_source(self, path: str) -> FileSource: + """ + Auto-detect the appropriate source for a path. + + Args: + path: File path or URI + + Returns: + Appropriate FileSource for the path + """ + if path.startswith("sharepoint://"): + source = self.sources.get("sharepoint") + if source is None: + raise ValueError( + "SharePoint source not configured. Register a SharePointSource with" + " loader.register_source('sharepoint', source)" + ) + return source + + elif path.startswith("gdrive://"): + source = self.sources.get("gdrive") + if source is None: + raise ValueError( + "Google Drive source not configured. 
Register a GoogleDriveSource" + " with loader.register_source('gdrive', source)" + ) + return source + + elif path.startswith("s3://"): + source = self.sources.get("s3") + if source is None: + raise ValueError( + "S3 source not configured. " + "Register an S3Source with loader.register_source('s3', source)" + ) + return source + + elif path.startswith("gs://"): + source = self.sources.get("gcs") + if source is None: + raise ValueError( + "GCS source not configured. Register a GCSFileSource with" + " loader.register_source('gcs', source)" + ) + return source + + elif path.startswith(("http://", "https://")): + source = self.sources.get("http") + if source is None: + raise ValueError( + "HTTP source not configured. " + "Register an HTTPSource with loader.register_source('http', source)" + ) + return source + + else: + # Default to local filesystem + return self.sources["local"] + + def _find_parser(self, file: LoadedFile) -> FileParser: + """ + Find appropriate parser for a file. + + Args: + file: LoadedFile to find parser for + + Returns: + FileParser that can handle the file + + Raises: + ValueError: If no parser found + """ + for parser in self.parsers: + if parser.can_parse(file): + return parser + raise ValueError(f"No parser found for file: {file.metadata.name}") + + def _find_parser_by_path(self, path: str) -> FileParser | None: + """ + Find appropriate parser based on file path/extension. + + Args: + path: File path + + Returns: + FileParser if found, None otherwise + """ + # Get extension from path + name = path.split("/")[-1].split("\\")[-1] + if "." not in name: + return None + + ext = "." + name.rsplit(".", 1)[-1].lower() + + for parser in self.parsers: + if ext in parser.supported_extensions: + return parser + return None + + def load_files( + self, + files: list[str | FileSpec], + ) -> list[ParsedContent]: + """ + Load and parse multiple files (eager loading). + + Args: + files: List of file paths, URIs, or FileSpecs + + Returns: + List of ParsedContent objects + """ + results: list[ParsedContent] = [] + + for file_ref in files: + # Normalize to FileSpec + if isinstance(file_ref, str): + file_ref = FileSpec(path=file_ref) + + # Detect source + source = file_ref.source or self._detect_source(file_ref.path) + + # Resolve patterns (e.g., globs) + resolved_paths = source.resolve(file_ref.path) + + # Load and parse each file + for path in resolved_paths: + loaded = source.load(path) + parser = self._find_parser(loaded) + parsed = parser.parse(loaded) + results.append(parsed) + + return results + + def create_lazy_files( + self, + files: list[str | FileSpec], + ) -> LazyFileCollection: + """ + Create lazy file references (deferred loading). + + Files are not loaded until their content is accessed. + + Args: + files: List of file paths, URIs, or FileSpecs + + Returns: + LazyFileCollection with lazy file references + """ + lazy_files: list[LazyFile] = [] + + for file_ref in files: + # Normalize to FileSpec + if isinstance(file_ref, str): + file_ref = FileSpec(path=file_ref) + + # Detect source + source = file_ref.source or self._detect_source(file_ref.path) + + # Resolve patterns + resolved_paths = source.resolve(file_ref.path) + + # Create lazy file for each resolved path + for path in resolved_paths: + parser = self._find_parser_by_path(path) + lazy_file = LazyFile( + path=path, + source=source, + parser=parser, + ) + lazy_files.append(lazy_file) + + return LazyFileCollection(lazy_files) + + def load_single(self, path: str) -> ParsedContent: + """ + Load and parse a single file. 
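+
+    For example (a sketch; the path is hypothetical):
+
+    ```python
+    parsed = loader.load_single("data/report.pdf")
+    print(parsed.text[:200])
+    ```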
+ + Args: + path: File path or URI + + Returns: + ParsedContent for the file + """ + results = self.load_files([path]) + if not results: + raise FileNotFoundError(f"File not found: {path}") + return results[0] + + def create_lazy_file(self, path: str) -> LazyFile: + """ + Create a single lazy file reference. + + Args: + path: File path or URI + + Returns: + LazyFile reference + """ + collection = self.create_lazy_files([path]) + if not collection: + raise FileNotFoundError(f"File not found: {path}") + return collection[0] + + def build_context( + self, + files: list[str | FileSpec], + lazy: bool = True, + ) -> dict[str, Any]: + """ + Build a context dictionary for RLM consumption. + + Args: + files: List of file paths, URIs, or FileSpecs + lazy: If True, use lazy loading. If False, load immediately. + + Returns: + Context dictionary with files and metadata + """ + if lazy: + file_collection = self.create_lazy_files(files) + return { + "files": file_collection, + "file_count": len(file_collection), + "file_names": file_collection.names, + } + else: + parsed_files = self.load_files(files) + if len(parsed_files) == 1: + return { + "content": parsed_files[0].text, + "metadata": parsed_files[0].metadata, + "tables": parsed_files[0].tables, + } + else: + return { + "files": [ + { + "content": pf.text, + "metadata": pf.metadata, + "tables": pf.tables, + } + for pf in parsed_files + ], + "file_count": len(parsed_files), + } diff --git a/contributing/samples/rlm/adk_rlm/files/parsers/__init__.py b/contributing/samples/rlm/adk_rlm/files/parsers/__init__.py new file mode 100644 index 0000000000..b40252410b --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/parsers/__init__.py @@ -0,0 +1,16 @@ +""" +File parser implementations for ADK-RLM. + +This module provides various file parsers for extracting text and +structured data from different file formats. +""" + +from adk_rlm.files.parsers.base import FileParser +from adk_rlm.files.parsers.pdf import PDFParser +from adk_rlm.files.parsers.text import TextParser + +__all__ = [ + "FileParser", + "PDFParser", + "TextParser", +] diff --git a/contributing/samples/rlm/adk_rlm/files/parsers/base.py b/contributing/samples/rlm/adk_rlm/files/parsers/base.py new file mode 100644 index 0000000000..45a8c323b2 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/parsers/base.py @@ -0,0 +1,75 @@ +""" +Base protocol for file parsers. + +FileParser is an abstract base class that defines the interface for +parsing files of various formats into text and structured content. +""" + +from abc import ABC +from abc import abstractmethod +from pathlib import Path + +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.base import ParsedContent + + +class FileParser(ABC): + """ + Protocol for file format parsers. + + Implementations should handle parsing specific file formats + (e.g., text, PDF, Office documents) into text content. + """ + + @property + @abstractmethod + def supported_extensions(self) -> list[str]: + """ + Return list of supported file extensions. + + Extensions should include the leading dot and be lowercase. + Example: [".txt", ".md", ".json"] + """ + ... + + @property + @abstractmethod + def supported_mime_types(self) -> list[str]: + """ + Return list of supported MIME types. + + Example: ["text/plain", "text/markdown", "application/json"] + """ + ... + + @abstractmethod + def parse(self, file: LoadedFile) -> ParsedContent: + """ + Parse file content into text and structured data. 
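+
+    Implementations are expected to return a ParsedContent whose text field
+    holds the extracted text; tables, chunks, and images may be left as None
+    when the format does not provide them.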
+ + Args: + file: LoadedFile with raw content + + Returns: + ParsedContent with extracted text and metadata + """ + ... + + def can_parse(self, file: LoadedFile) -> bool: + """ + Check if this parser can handle the file. + + Uses file extension and MIME type to determine compatibility. + + Args: + file: LoadedFile to check + + Returns: + True if this parser can handle the file + """ + ext = Path(file.metadata.name).suffix.lower() + mime = file.metadata.mime_type + + return ext in self.supported_extensions or ( + mime is not None and mime in self.supported_mime_types + ) diff --git a/contributing/samples/rlm/adk_rlm/files/parsers/pdf.py b/contributing/samples/rlm/adk_rlm/files/parsers/pdf.py new file mode 100644 index 0000000000..22f33cf40a --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/parsers/pdf.py @@ -0,0 +1,151 @@ +""" +PDF file parser implementation. + +Extracts text and tables from PDF files using pdfplumber or pypdf. +""" + +from io import BytesIO +from typing import Any + +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.base import ParsedContent +from adk_rlm.files.parsers.base import FileParser + + +class PDFParser(FileParser): + """ + Parse PDF files. + + Uses pdfplumber for better table extraction when available, + falls back to pypdf for basic text extraction. + + Example: + ```python + parser = PDFParser() + if parser.can_parse(loaded_file): + content = parser.parse(loaded_file) + print(content.text) + if content.tables: + print(f"Found {len(content.tables)} table rows") + ``` + """ + + @property + def supported_extensions(self) -> list[str]: + """Return list of supported file extensions.""" + return [".pdf"] + + @property + def supported_mime_types(self) -> list[str]: + """Return list of supported MIME types.""" + return ["application/pdf"] + + def parse(self, file: LoadedFile) -> ParsedContent: + """ + Extract text from PDF. + + Tries pdfplumber first (better for tables), falls back to pypdf. + + Args: + file: LoadedFile with PDF content + + Returns: + ParsedContent with text, optional tables, and page chunks + """ + # Try pdfplumber first (better table extraction) + try: + import pdfplumber + + return self._parse_with_pdfplumber(file) + except ImportError: + pass + + # Fall back to pypdf + try: + import pypdf + + return self._parse_with_pypdf(file) + except ImportError: + pass + + # Neither library available + raise ImportError( + "PDF parsing requires either 'pdfplumber' or 'pypdf'. " + "Install with: pip install pdfplumber or pip install pypdf" + ) + + def _parse_with_pdfplumber(self, file: LoadedFile) -> ParsedContent: + """ + Parse PDF using pdfplumber (better for tables). 
+ + Args: + file: LoadedFile with PDF content + + Returns: + ParsedContent with text, tables, and page chunks + """ + import pdfplumber + + text_parts: list[str] = [] + tables: list[dict[str, Any]] = [] + metadata: dict[str, Any] = {"parser": "pdfplumber"} + + with pdfplumber.open(BytesIO(file.content)) as pdf: + metadata["page_count"] = len(pdf.pages) + + for i, page in enumerate(pdf.pages): + # Extract text + page_text = page.extract_text() or "" + text_parts.append(f"--- Page {i + 1} ---\n{page_text}") + + # Extract tables + page_tables = page.extract_tables() + for table in page_tables: + if table and len(table) > 1: + # Use first row as headers + headers = [ + str(h) if h else f"col_{j}" for j, h in enumerate(table[0]) + ] + for row in table[1:]: + if row: + tables.append(dict(zip(headers, row))) + + metadata["table_count"] = len(tables) + + return ParsedContent( + text="\n\n".join(text_parts), + metadata=metadata, + chunks=text_parts, # Pre-chunked by page + tables=tables if tables else None, + images=None, + ) + + def _parse_with_pypdf(self, file: LoadedFile) -> ParsedContent: + """ + Parse PDF using pypdf (simpler, no table extraction). + + Args: + file: LoadedFile with PDF content + + Returns: + ParsedContent with text and page chunks + """ + import pypdf + + text_parts: list[str] = [] + metadata: dict[str, Any] = {"parser": "pypdf"} + + reader = pypdf.PdfReader(BytesIO(file.content)) + metadata["page_count"] = len(reader.pages) + + for i, page in enumerate(reader.pages): + page_text = page.extract_text() or "" + text_parts.append(f"--- Page {i + 1} ---\n{page_text}") + + return ParsedContent( + text="\n\n".join(text_parts), + metadata=metadata, + chunks=text_parts, + tables=None, + images=None, + ) diff --git a/contributing/samples/rlm/adk_rlm/files/parsers/text.py b/contributing/samples/rlm/adk_rlm/files/parsers/text.py new file mode 100644 index 0000000000..800f427b3a --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/parsers/text.py @@ -0,0 +1,206 @@ +""" +Text file parser implementation. + +Handles plain text files including .txt, .md, .json, .yaml, .csv, etc. +""" + +import csv +from io import StringIO +import json +from pathlib import Path +from typing import Any + +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.base import ParsedContent +from adk_rlm.files.parsers.base import FileParser + + +class TextParser(FileParser): + """ + Parse plain text files. + + Supports various text-based formats including: + - Plain text (.txt) + - Markdown (.md, .markdown) + - JSON (.json) + - YAML (.yaml, .yml) + - CSV/TSV (.csv, .tsv) + - Code files (.py, .js, .ts, etc.) 
+ - Log files (.log) + - XML/HTML (.xml, .html) + """ + + # Text file extensions + TEXT_EXTENSIONS = [ + ".txt", + ".md", + ".markdown", + ".json", + ".yaml", + ".yml", + ".csv", + ".tsv", + ".log", + ".xml", + ".html", + ".htm", + ".rst", + ".py", + ".js", + ".ts", + ".jsx", + ".tsx", + ".java", + ".c", + ".cpp", + ".h", + ".hpp", + ".go", + ".rs", + ".rb", + ".php", + ".sh", + ".bash", + ".zsh", + ".sql", + ".r", + ".scala", + ".kt", + ".swift", + ".css", + ".scss", + ".less", + ".toml", + ".ini", + ".cfg", + ".conf", + ".properties", + ".env", + ] + + # Text MIME types + TEXT_MIME_TYPES = [ + "text/plain", + "text/markdown", + "text/x-markdown", + "application/json", + "text/yaml", + "application/x-yaml", + "text/csv", + "text/tab-separated-values", + "text/html", + "application/xml", + "text/xml", + "text/x-python", + "application/javascript", + "text/javascript", + ] + + @property + def supported_extensions(self) -> list[str]: + """Return list of supported file extensions.""" + return self.TEXT_EXTENSIONS + + @property + def supported_mime_types(self) -> list[str]: + """Return list of supported MIME types.""" + return self.TEXT_MIME_TYPES + + def parse(self, file: LoadedFile) -> ParsedContent: + """ + Parse text file. + + Provides special handling for structured formats like JSON and CSV. + + Args: + file: LoadedFile with raw content + + Returns: + ParsedContent with text and optional structured data + """ + try: + text = file.as_text() + except UnicodeDecodeError: + # Try common encodings + for encoding in ["utf-8", "latin-1", "cp1252", "ascii"]: + try: + text = file.content.decode(encoding) + break + except UnicodeDecodeError: + continue + else: + # Last resort: decode with replacement + text = file.content.decode("utf-8", errors="replace") + + ext = Path(file.metadata.name).suffix.lower() + metadata: dict[str, Any] = {"format": ext, "encoding": "utf-8"} + tables: list[dict[str, Any]] | None = None + + # Special handling for structured formats + if ext == ".json": + text, metadata = self._handle_json(text, metadata) + elif ext in [".csv", ".tsv"]: + text, metadata, tables = self._handle_csv(text, ext, metadata) + elif ext in [".yaml", ".yml"]: + metadata = self._handle_yaml(text, metadata) + + return ParsedContent( + text=text, + metadata=metadata, + chunks=None, + tables=tables, + images=None, + ) + + def _handle_json( + self, text: str, metadata: dict[str, Any] + ) -> tuple[str, dict[str, Any]]: + """Handle JSON files - pretty print for readability.""" + try: + data = json.loads(text) + metadata["json_type"] = type(data).__name__ + if isinstance(data, list): + metadata["item_count"] = len(data) + elif isinstance(data, dict): + metadata["keys"] = list(data.keys())[:20] # First 20 keys + # Pretty print for better readability + text = json.dumps(data, indent=2, ensure_ascii=False, default=str) + except json.JSONDecodeError as e: + metadata["parse_error"] = str(e) + return text, metadata + + def _handle_csv( + self, text: str, ext: str, metadata: dict[str, Any] + ) -> tuple[str, dict[str, Any], list[dict[str, Any]] | None]: + """Handle CSV/TSV files - extract tables.""" + delimiter = "\t" if ext == ".tsv" else "," + tables: list[dict[str, Any]] = [] + + try: + reader = csv.DictReader(StringIO(text), delimiter=delimiter) + for row in reader: + tables.append(dict(row)) + + if tables: + metadata["row_count"] = len(tables) + metadata["columns"] = list(tables[0].keys()) + except Exception as e: + metadata["parse_error"] = str(e) + tables = [] + + return text, metadata, tables if tables 
else None + + def _handle_yaml(self, text: str, metadata: dict[str, Any]) -> dict[str, Any]: + """Handle YAML files - add metadata about structure.""" + try: + import yaml + + data = yaml.safe_load(text) + metadata["yaml_type"] = type(data).__name__ + if isinstance(data, dict): + metadata["keys"] = list(data.keys())[:20] + except ImportError: + metadata["yaml_parse"] = "yaml library not available" + except Exception as e: + metadata["parse_error"] = str(e) + return metadata diff --git a/contributing/samples/rlm/adk_rlm/files/sources/__init__.py b/contributing/samples/rlm/adk_rlm/files/sources/__init__.py new file mode 100644 index 0000000000..64dc1a4643 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/sources/__init__.py @@ -0,0 +1,24 @@ +""" +File source implementations for ADK-RLM. + +This module provides various file source implementations for loading files +from different locations (local filesystem, cloud storage, etc.). +""" + +from adk_rlm.files.sources.base import FileSource +from adk_rlm.files.sources.local import LocalFileSource + +# Optional GCS support (requires google-cloud-storage) +try: + from adk_rlm.files.sources.gcs import GCSFileSource + from adk_rlm.files.sources.gcs import RetryConfig +except ImportError: + GCSFileSource = None # type: ignore + RetryConfig = None # type: ignore + +__all__ = [ + "FileSource", + "LocalFileSource", + "GCSFileSource", + "RetryConfig", +] diff --git a/contributing/samples/rlm/adk_rlm/files/sources/base.py b/contributing/samples/rlm/adk_rlm/files/sources/base.py new file mode 100644 index 0000000000..5d0ee029ad --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/sources/base.py @@ -0,0 +1,104 @@ +""" +Base protocol for file sources. + +FileSource is an abstract base class that defines the interface for +loading files from various sources (local filesystem, cloud storage, etc.). +""" + +from abc import ABC +from abc import abstractmethod +from typing import Iterator + +from adk_rlm.files.base import FileMetadata +from adk_rlm.files.base import LoadedFile + + +class FileSource(ABC): + """ + Protocol for file sources. + + Implementations should handle loading files from a specific source type + (e.g., local filesystem, SharePoint, Google Drive, S3). + """ + + @property + @abstractmethod + def source_type(self) -> str: + """Return source type identifier (e.g., 'local', 'sharepoint').""" + ... + + @abstractmethod + def resolve(self, path: str) -> list[str]: + """ + Resolve a path pattern to concrete file paths. + + Supports glob patterns for sources that allow it. + + Args: + path: File path or pattern (e.g., "*.pdf", "folder/**/*.docx") + + Returns: + List of resolved file paths/URIs + """ + ... + + @abstractmethod + def load(self, path: str) -> LoadedFile: + """ + Load a single file from the source. + + Args: + path: Resolved file path (from resolve()) + + Returns: + LoadedFile with content and metadata + """ + ... + + def get_metadata(self, path: str) -> FileMetadata: + """ + Get metadata for a file without loading full content. + + Override this in subclasses for more efficient metadata-only access + (e.g., HEAD requests for HTTP, stat() for local files). + + Default implementation loads the full file, which is inefficient. + + Args: + path: File path to get metadata for + + Returns: + FileMetadata for the file + """ + return self.load(path).metadata + + def load_many(self, paths: list[str]) -> Iterator[LoadedFile]: + """ + Load multiple files. + + Override for parallel loading in subclasses. 
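+
+    The default implementation loads sequentially, so usage is simply
+    (the paths here are examples):
+
+    ```python
+    for loaded in source.load_many(["a.txt", "b.txt"]):
+      print(loaded.metadata.name, len(loaded.content))
+    ```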
+ + Args: + paths: List of file paths to load + + Yields: + LoadedFile for each path + """ + for path in paths: + yield self.load(path) + + def exists(self, path: str) -> bool: + """ + Check if a file exists at the given path. + + Args: + path: File path to check + + Returns: + True if file exists, False otherwise + """ + try: + resolved = self.resolve(path) + return len(resolved) > 0 + except Exception: + return False diff --git a/contributing/samples/rlm/adk_rlm/files/sources/gcs.py b/contributing/samples/rlm/adk_rlm/files/sources/gcs.py new file mode 100644 index 0000000000..08a7e2e86e --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/sources/gcs.py @@ -0,0 +1,472 @@ +""" +Google Cloud Storage file source implementation. + +Provides file loading from GCS buckets with glob pattern support, +retry logic, and efficient metadata access. +""" + +from concurrent.futures import as_completed +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +import fnmatch +import mimetypes +from pathlib import Path +import tempfile +import time +from typing import Callable +from typing import Iterator +from typing import TypeVar + +from adk_rlm.files.base import FileMetadata +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.sources.base import FileSource + +try: + from google.cloud import storage + from google.cloud.exceptions import Forbidden + from google.cloud.exceptions import NotFound + + HAS_GCS = True +except ImportError: + HAS_GCS = False + storage = None # type: ignore + NotFound = Exception # type: ignore + Forbidden = Exception # type: ignore + +T = TypeVar("T") + + +@dataclass +class RetryConfig: + """Configuration for retry behavior on transient errors.""" + + max_attempts: int = 3 + initial_delay: float = 0.5 + max_delay: float = 30.0 + exponential_base: float = 2.0 + + +class GCSFileSource(FileSource): + """ + Load files from Google Cloud Storage. + + Supports gs:// URIs with glob patterns for batch file resolution. + Uses Application Default Credentials by default. + + Example: + ```python + source = GCSFileSource(bucket="my-bucket") + + # Single file + file = source.load("gs://my-bucket/data/report.pdf") + + # Glob pattern + paths = source.resolve("gs://my-bucket/data/**/*.pdf") + for path in paths: + file = source.load(path) + + # With explicit credentials + source = GCSFileSource( + bucket="my-bucket", + credentials_path="/path/to/service-account.json" + ) + ``` + """ + + def __init__( + self, + bucket: str | None = None, + project: str | None = None, + credentials: "storage.Client | None" = None, + credentials_path: str | None = None, + timeout: float = 60.0, + retry_config: RetryConfig | None = None, + max_concurrent: int = 10, + large_file_threshold: int = 100_000_000, # 100 MB + ): + """ + Initialize GCSFileSource. + + Args: + bucket: Default bucket name (can be overridden in paths) + project: GCP project ID (optional, inferred from credentials) + credentials: Explicit google.auth.credentials.Credentials object + credentials_path: Path to service account JSON file + timeout: Request timeout in seconds + retry_config: Retry configuration for transient errors + max_concurrent: Max parallel downloads in load_many() + large_file_threshold: Files larger than this stream to temp file + """ + if not HAS_GCS: + raise ImportError( + "GCS support requires 'google-cloud-storage'. 
" + "Install with: pip install google-cloud-storage" + ) + + self.default_bucket = bucket + self._project = project + self._credentials = credentials + self._credentials_path = credentials_path + self.timeout = timeout + self.retry_config = retry_config or RetryConfig() + self.max_concurrent = max_concurrent + self.large_file_threshold = large_file_threshold + + # Client is lazily initialized to avoid pickle issues + self._client: "storage.Client | None" = None + + @property + def client(self) -> "storage.Client": + """Lazily initialize and return the GCS client.""" + if self._client is None: + if self._credentials: + self._client = storage.Client( + credentials=self._credentials, project=self._project + ) + elif self._credentials_path: + self._client = storage.Client.from_service_account_json( + self._credentials_path, project=self._project + ) + else: + # Use Application Default Credentials + self._client = storage.Client(project=self._project) + return self._client + + def __getstate__(self): + """Return state for pickling, excluding the client.""" + state = self.__dict__.copy() + # Don't pickle the client - it will be recreated on demand + state["_client"] = None + return state + + def __setstate__(self, state): + """Restore state from pickle.""" + self.__dict__.update(state) + + @property + def source_type(self) -> str: + """Return 'gcs' as the source type.""" + return "gcs" + + def _parse_path(self, path: str) -> tuple[str, str]: + """ + Parse a GCS path into bucket and blob name. + + Args: + path: GCS path (gs://bucket/key or just key for default bucket) + + Returns: + Tuple of (bucket_name, blob_name) + """ + if path.startswith("gs://"): + path = path[5:] + parts = path.split("/", 1) + bucket_name = parts[0] + blob_name = parts[1] if len(parts) > 1 else "" + else: + bucket_name = self.default_bucket + blob_name = path + + if not bucket_name: + raise ValueError("No bucket specified and no default bucket set") + + return bucket_name, blob_name + + def _is_retryable(self, error: Exception) -> bool: + """Check if an error is retryable.""" + error_name = type(error).__name__ + error_str = str(error).lower() + retryable_names = ( + "ServiceUnavailable", + "TooManyRequests", + "InternalError", + "Timeout", + "ConnectionError", + ) + retryable_messages = ( + "serviceunavailable", + "toomanyrequests", + "internalerror", + "timeout", + "connectionerror", + "connection reset", + "connection refused", + ) + return error_name in retryable_names or any( + msg in error_str for msg in retryable_messages + ) + + def _with_retry(self, operation: Callable[[], T], context: str = "") -> T: + """Execute operation with retry logic for transient errors.""" + last_error: Exception | None = None + delay = self.retry_config.initial_delay + + for attempt in range(self.retry_config.max_attempts): + try: + return operation() + except NotFound: + raise FileNotFoundError(f"GCS object not found: {context}") + except Forbidden as e: + raise PermissionError( + f"Access denied to GCS object: {context}. " + f"Check bucket permissions and credentials. Error: {e}" + ) + except Exception as e: + if self._is_retryable(e): + last_error = e + if attempt < self.retry_config.max_attempts - 1: + time.sleep(delay) + delay = min( + delay * self.retry_config.exponential_base, + self.retry_config.max_delay, + ) + else: + raise + + raise RuntimeError( + f"GCS operation failed after {self.retry_config.max_attempts} attempts:" + f" {context}. 
Last error: {last_error}" + ) + + def resolve(self, path: str) -> list[str]: + """ + Resolve GCS path, supporting glob patterns. + + For glob patterns, lists blobs with prefix matching and filters + with fnmatch. Note: listing large buckets can be slow. + + Args: + path: GCS path or glob pattern (e.g., "gs://bucket/data/**/*.pdf") + + Returns: + List of resolved gs:// URIs + """ + bucket_name, pattern = self._parse_path(path) + + # Check for glob patterns + if not any(c in pattern for c in ["*", "?", "["]): + # Not a glob - check if single object exists + bucket = self.client.bucket(bucket_name) + blob = bucket.blob(pattern) + + def check_exists(): + return blob.exists(timeout=self.timeout) + + if self._with_retry(check_exists, f"gs://{bucket_name}/{pattern}"): + return [f"gs://{bucket_name}/{pattern}"] + return [] + + # Extract prefix (everything before first glob char) + prefix_end = len(pattern) + for char in ["*", "?", "["]: + idx = pattern.find(char) + if idx != -1: + prefix_end = min(prefix_end, idx) + + prefix = pattern[:prefix_end] + # Also trim to last / to get directory prefix + if "/" in prefix: + prefix = prefix.rsplit("/", 1)[0] + "/" + else: + prefix = "" + + # List blobs with prefix and filter + bucket = self.client.bucket(bucket_name) + results: list[str] = [] + + def list_blobs(): + return list(bucket.list_blobs(prefix=prefix, timeout=self.timeout)) + + blobs = self._with_retry(list_blobs, f"listing gs://{bucket_name}/{prefix}") + + for blob in blobs: + if fnmatch.fnmatch(blob.name, pattern): + results.append(f"gs://{bucket_name}/{blob.name}") + + return sorted(results) + + def get_metadata(self, path: str) -> FileMetadata: + """ + Get blob metadata without downloading content. + + This is efficient for Level 1 lazy loading - only fetches + metadata, not the blob content. 
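+
+    A small sketch (bucket and object name are illustrative):
+
+    ```python
+    meta = source.get_metadata("gs://my-bucket/reports/q3.pdf")
+    print(meta.size_bytes, meta.mime_type)
+    ```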
+ + Args: + path: GCS path + + Returns: + FileMetadata for the blob + """ + bucket_name, blob_name = self._parse_path(path) + bucket = self.client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + def reload_metadata(): + blob.reload(timeout=self.timeout) + return blob + + blob = self._with_retry(reload_metadata, f"gs://{bucket_name}/{blob_name}") + + # Determine MIME type + mime_type = blob.content_type + if not mime_type: + mime_type, _ = mimetypes.guess_type(blob_name) + + # Parse last modified + last_modified = None + if blob.updated: + last_modified = blob.updated + + return FileMetadata( + name=blob_name.split("/")[-1], + path=f"gs://{bucket_name}/{blob_name}", + source_type=self.source_type, + size_bytes=blob.size or 0, + mime_type=mime_type, + last_modified=last_modified, + extra={ + "bucket": bucket_name, + "blob_name": blob_name, + "content_encoding": blob.content_encoding, + "storage_class": blob.storage_class, + "generation": blob.generation, + "metageneration": blob.metageneration, + "etag": blob.etag, + "md5_hash": blob.md5_hash, + "crc32c": blob.crc32c, + }, + ) + + def _load_direct(self, path: str) -> LoadedFile: + """Load file directly into memory.""" + bucket_name, blob_name = self._parse_path(path) + bucket = self.client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + def download(): + return blob.download_as_bytes(timeout=self.timeout) + + content = self._with_retry(download, f"gs://{bucket_name}/{blob_name}") + + # Reload blob to get metadata after download + try: + blob.reload(timeout=self.timeout) + except Exception: + pass # Metadata fetch is best-effort after download + + # Build metadata + mime_type = blob.content_type + if not mime_type: + mime_type, _ = mimetypes.guess_type(blob_name) + + return LoadedFile( + metadata=FileMetadata( + name=blob_name.split("/")[-1], + path=f"gs://{bucket_name}/{blob_name}", + source_type=self.source_type, + size_bytes=len(content), + mime_type=mime_type, + last_modified=blob.updated, + extra={ + "bucket": bucket_name, + "blob_name": blob_name, + "etag": blob.etag, + }, + ), + content=content, + ) + + def _load_chunked(self, path: str, metadata: FileMetadata) -> LoadedFile: + """Load large file via temp file to manage memory.""" + bucket_name, blob_name = self._parse_path(path) + bucket = self.client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + def download_to_file(): + with tempfile.NamedTemporaryFile(delete=False) as tmp: + blob.download_to_file(tmp, timeout=self.timeout) + tmp.flush() + return tmp.name + + tmp_path = self._with_retry( + download_to_file, f"gs://{bucket_name}/{blob_name}" + ) + + try: + content = Path(tmp_path).read_bytes() + finally: + Path(tmp_path).unlink(missing_ok=True) + + return LoadedFile(metadata=metadata, content=content) + + def load(self, path: str) -> LoadedFile: + """ + Load file from GCS. + + For files larger than large_file_threshold, streams to a + temp file first to manage memory. + + Args: + path: GCS path (gs://bucket/key or key for default bucket) + + Returns: + LoadedFile with content and metadata + """ + # Check size first to decide loading strategy + try: + metadata = self.get_metadata(path) + except FileNotFoundError: + raise + + if metadata.size_bytes > self.large_file_threshold: + return self._load_chunked(path, metadata) + else: + return self._load_direct(path) + + def load_many(self, paths: list[str]) -> Iterator[LoadedFile]: + """ + Load multiple files in parallel. + + Uses ThreadPoolExecutor for concurrent downloads. 
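+
+    Because results are yielded as downloads complete, pair each file with
+    its metadata.path when the original ordering matters, e.g.:
+
+    ```python
+    # `uris` is assumed to be a list of gs:// paths
+    by_path = {f.metadata.path: f for f in source.load_many(uris)}
+    ```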
+ + Args: + paths: List of GCS paths to load + + Yields: + LoadedFile for each path (order not guaranteed) + """ + if len(paths) == 0: + return + + if len(paths) == 1: + yield self.load(paths[0]) + return + + with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: + futures = {executor.submit(self.load, path): path for path in paths} + for future in as_completed(futures): + try: + yield future.result() + except Exception as e: + # Re-raise with path context + path = futures[future] + raise RuntimeError(f"Failed to load {path}: {e}") from e + + def exists(self, path: str) -> bool: + """ + Check if a blob exists. + + Args: + path: GCS path to check + + Returns: + True if blob exists, False otherwise + """ + try: + bucket_name, blob_name = self._parse_path(path) + bucket = self.client.bucket(bucket_name) + blob = bucket.blob(blob_name) + return blob.exists(timeout=self.timeout) + except Exception: + return False diff --git a/contributing/samples/rlm/adk_rlm/files/sources/local.py b/contributing/samples/rlm/adk_rlm/files/sources/local.py new file mode 100644 index 0000000000..c053568c7f --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/files/sources/local.py @@ -0,0 +1,175 @@ +""" +Local filesystem file source implementation. + +Provides file loading from the local filesystem with glob pattern support. +""" + +from datetime import datetime +from glob import glob +import mimetypes +import os +from pathlib import Path + +from adk_rlm.files.base import FileMetadata +from adk_rlm.files.base import LoadedFile +from adk_rlm.files.sources.base import FileSource + + +class LocalFileSource(FileSource): + """ + Load files from the local filesystem. + + Supports glob patterns for resolving multiple files. + + Example: + ```python + source = LocalFileSource(base_path="/path/to/docs") + + # Single file + file = source.load("report.pdf") + + # Glob pattern + paths = source.resolve("**/*.md") + for path in paths: + file = source.load(path) + ``` + """ + + def __init__(self, base_path: str | Path | None = None): + """ + Initialize LocalFileSource. + + Args: + base_path: Base directory for relative paths. + Defaults to current working directory. + """ + if base_path is None: + self.base_path = Path.cwd() + else: + self.base_path = Path(base_path).resolve() + + @property + def source_type(self) -> str: + """Return 'local' as the source type.""" + return "local" + + def resolve(self, path: str) -> list[str]: + """ + Resolve path, supporting glob patterns. + + Args: + path: File path or glob pattern (e.g., "*.pdf", "**/*.md") + + Returns: + List of absolute file paths matching the pattern + """ + # Handle absolute paths + if os.path.isabs(path): + full_path = Path(path) + else: + full_path = self.base_path / path + + # Check for glob patterns + if any(c in str(full_path) for c in ["*", "?", "["]): + matches = glob(str(full_path), recursive=True) + # Filter to only files (not directories) + return sorted([str(m) for m in matches if os.path.isfile(m)]) + + # Single file path + if full_path.exists() and full_path.is_file(): + return [str(full_path)] + + return [] + + def get_metadata(self, path: str) -> FileMetadata: + """ + Get metadata via stat() without reading file content. + + This is more efficient than load() for metadata-only access. 
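+
+    For example, size can be checked before deciding to read the file
+    (the path is hypothetical):
+
+    ```python
+    meta = source.get_metadata("notes/todo.md")
+    if meta.size_bytes < 1_000_000:
+      file = source.load("notes/todo.md")
+    ```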
+ + Args: + path: File path to get metadata for + + Returns: + FileMetadata for the file + """ + file_path = Path(path) + + if not file_path.is_absolute(): + file_path = self.base_path / file_path + + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + stat = file_path.stat() + mime_type, _ = mimetypes.guess_type(str(file_path)) + + return FileMetadata( + name=file_path.name, + path=str(file_path), + source_type=self.source_type, + size_bytes=stat.st_size, + mime_type=mime_type, + last_modified=datetime.fromtimestamp(stat.st_mtime), + extra={ + "mode": stat.st_mode, + "created": datetime.fromtimestamp(stat.st_ctime).isoformat(), + }, + ) + + def load(self, path: str) -> LoadedFile: + """ + Load file from filesystem. + + Args: + path: File path (absolute or relative to base_path) + + Returns: + LoadedFile with content and metadata + """ + file_path = Path(path) + + if not file_path.is_absolute(): + file_path = self.base_path / file_path + + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {path}") + + if not file_path.is_file(): + raise ValueError(f"Path is not a file: {path}") + + stat = file_path.stat() + content = file_path.read_bytes() + + # Detect MIME type + mime_type, _ = mimetypes.guess_type(str(file_path)) + + return LoadedFile( + metadata=FileMetadata( + name=file_path.name, + path=str(file_path), + source_type=self.source_type, + size_bytes=stat.st_size, + mime_type=mime_type, + last_modified=datetime.fromtimestamp(stat.st_mtime), + extra={}, + ), + content=content, + ) + + def exists(self, path: str) -> bool: + """ + Check if a file exists at the given path. + + Args: + path: File path to check + + Returns: + True if file exists, False otherwise + """ + file_path = Path(path) + + if not file_path.is_absolute(): + file_path = self.base_path / file_path + + return file_path.exists() and file_path.is_file() diff --git a/contributing/samples/rlm/adk_rlm/llm.py b/contributing/samples/rlm/adk_rlm/llm.py new file mode 100644 index 0000000000..70e5806876 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/llm.py @@ -0,0 +1,71 @@ +""" +LLM utilities and rate limiting. + +This module provides a global semaphore for limiting concurrent LLM calls +across all RLM components (agents, code executors, batched queries). +""" + +import asyncio +from contextlib import contextmanager +import threading + +# Global semaphore for limiting concurrent LLM calls. +# Uses threading.BoundedSemaphore because LLM calls happen across different +# contexts (sync code executor, async agents, different event loops). +LLM_CONCURRENCY_LIMIT = 30 +_llm_semaphore = threading.BoundedSemaphore(LLM_CONCURRENCY_LIMIT) + + +@contextmanager +def llm_rate_limit(): + """Context manager for rate-limiting LLM calls (sync version). + + Usage: + with llm_rate_limit(): + response = client.models.generate_content(...) + """ + _llm_semaphore.acquire() + try: + yield + finally: + _llm_semaphore.release() + + +async def llm_rate_limit_async(): + """Async context manager for rate-limiting LLM calls. + + Usage: + async with llm_rate_limit_async(): + response = await client.aio.models.generate_content(...) 
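+
+  Note: because this coroutine returns a releaser object rather than being
+  an async context manager itself, the call has to be awaited before
+  entering the context:
+
+      async with await llm_rate_limit_async():
+          ...
+
+  For the conventional form, use the AsyncLLMRateLimiter class below.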
+ """ + # Acquire in a thread to avoid blocking the event loop + await asyncio.to_thread(_llm_semaphore.acquire) + return _AsyncSemaphoreReleaser() + + +class _AsyncSemaphoreReleaser: + """Helper class for async context manager protocol.""" + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + _llm_semaphore.release() + return False + + +class AsyncLLMRateLimiter: + """Async context manager for rate-limiting LLM calls. + + Usage: + async with AsyncLLMRateLimiter(): + response = await client.aio.models.generate_content(...) + """ + + async def __aenter__(self): + await asyncio.to_thread(_llm_semaphore.acquire) + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + _llm_semaphore.release() + return False diff --git a/contributing/samples/rlm/adk_rlm/logging/__init__.py b/contributing/samples/rlm/adk_rlm/logging/__init__.py new file mode 100644 index 0000000000..f9eef47ff7 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/logging/__init__.py @@ -0,0 +1,6 @@ +"""Logging utilities for ADK-RLM.""" + +from adk_rlm.logging.rlm_logger import RLMLogger +from adk_rlm.logging.verbose import VerbosePrinter + +__all__ = ["RLMLogger", "VerbosePrinter"] diff --git a/contributing/samples/rlm/adk_rlm/logging/rlm_logger.py b/contributing/samples/rlm/adk_rlm/logging/rlm_logger.py new file mode 100644 index 0000000000..f5c4ea8b59 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/logging/rlm_logger.py @@ -0,0 +1,207 @@ +""" +Logger for RLM iterations. + +Writes RLMIteration data to JSON-lines files for analysis and debugging. +Compatible with the original RLM visualizer. +""" + +from datetime import datetime +import json +import os +import threading +import uuid + +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata + + +class RLMLogger: + """Logger that writes RLMIteration data to a JSON-lines file.""" + + def __init__(self, log_dir: str, file_name: str = "rlm"): + """ + Initialize the RLM logger. + + Args: + log_dir: Directory to store log files. + file_name: Base name for log files. + """ + # Convert to absolute path to ensure it works from any working directory + # (important for child agents that may run in different directories) + self.log_dir = os.path.abspath(log_dir) + os.makedirs(self.log_dir, exist_ok=True) + + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + run_id = str(uuid.uuid4())[:8] + self.log_file_path = os.path.join( + self.log_dir, f"{file_name}_{timestamp}_{run_id}.jsonl" + ) + + self._iteration_count = 0 + self._metadata_logged = False + self._lock = threading.Lock() + + def log_metadata(self, metadata: RLMMetadata) -> None: + """ + Log RLM metadata as the first entry in the file. + + Args: + metadata: The RLM configuration metadata. 
+ """ + if self._metadata_logged: + return + + entry = { + "type": "metadata", + "timestamp": datetime.now().isoformat(), + **metadata.to_dict(), + } + + # Serialize to string first, then write atomically under lock + # to prevent interleaved writes from concurrent threads + line = json.dumps(entry) + "\n" + with self._lock: + with open(self.log_file_path, "a") as f: + f.write(line) + + self._metadata_logged = True + + def log( + self, + iteration: RLMIteration, + depth: int = 0, + agent_name: str | None = None, + parent_agent: str | None = None, + parent_iteration: int | None = None, + parent_block_index: int | None = None, + parallel_batch_id: str | None = None, + batch_index: int | None = None, + batch_size: int | None = None, + ) -> None: + """ + Log an RLMIteration to the file. + + Args: + iteration: The iteration to log. + depth: The recursion depth (0 = root agent). + agent_name: Name of the agent logging this iteration. + parent_agent: Name of the parent agent that spawned this one. + parent_iteration: The iteration number of the parent that spawned this agent. + parent_block_index: The code block index in the parent that spawned this agent. + parallel_batch_id: UUID of the parallel batch (if part of a batch). + batch_index: Position within the parallel batch (0-indexed). + batch_size: Total number of items in the parallel batch. + """ + # Build entry outside lock, but increment counter and write inside lock + # to prevent race conditions from concurrent threads + entry = { + "type": "iteration", + "iteration": 0, # Placeholder, set under lock + "timestamp": datetime.now().isoformat(), + "depth": depth, + "agent_name": agent_name, + "parent_agent": parent_agent, + **iteration.to_dict(), + } + + # Add optional parent iteration context + if parent_iteration is not None: + entry["parent_iteration"] = parent_iteration + if parent_block_index is not None: + entry["parent_block_index"] = parent_block_index + + # Add optional parallel batch metadata + if parallel_batch_id is not None: + entry["parallel_batch_id"] = parallel_batch_id + if batch_index is not None: + entry["batch_index"] = batch_index + if batch_size is not None: + entry["batch_size"] = batch_size + + # Serialize to string first, then write atomically under lock + # to prevent interleaved writes from concurrent threads + with self._lock: + self._iteration_count += 1 + entry["iteration"] = self._iteration_count + line = json.dumps(entry) + "\n" + with open(self.log_file_path, "a") as f: + f.write(line) + + def log_simple_llm_call( + self, + prompt: str, + response: str, + model: str, + execution_time_ms: float, + depth: int = 0, + agent_name: str | None = None, + parent_iteration: int | None = None, + parent_block_index: int | None = None, + batch_index: int | None = None, + batch_size: int | None = None, + error: str | None = None, + ) -> None: + """ + Log a simple (non-recursive) LLM call. + + This is used when llm_query() or llm_query_batched() is called with + recursive=False, so there's no full RLMIteration to log. + + Args: + prompt: The prompt sent to the LLM. + response: The response received (or error message if call failed). + model: The model used. + execution_time_ms: Execution time in milliseconds. + depth: The recursion depth (0 = root agent). + agent_name: Name of the agent that made the call. + parent_iteration: The iteration number that spawned this call. + parent_block_index: The code block index that spawned this call. + batch_index: Position within a batch (0-indexed), if part of a batch. 
+ batch_size: Total number of items in the batch, if part of a batch. + error: Error message if the call failed. + """ + entry = { + "type": "simple_llm_call", + "timestamp": datetime.now().isoformat(), + "model": model, + "prompt": prompt[:500] if len(prompt) > 500 else prompt, + "prompt_full": prompt, + "response": response[:500] if len(response) > 500 else response, + "response_full": response, + "execution_time_ms": execution_time_ms, + "depth": depth, + "agent_name": agent_name, + "recursive": False, + "success": error is None, + } + + # Add error if present + if error is not None: + entry["error"] = error + + # Add optional context + if parent_iteration is not None: + entry["parent_iteration"] = parent_iteration + if parent_block_index is not None: + entry["parent_block_index"] = parent_block_index + + # Add batch metadata + if batch_index is not None: + entry["batch_index"] = batch_index + if batch_size is not None: + entry["batch_size"] = batch_size + + # Serialize and write atomically under lock + line = json.dumps(entry) + "\n" + with self._lock: + with open(self.log_file_path, "a") as f: + f.write(line) + + @property + def iteration_count(self) -> int: + """Return the number of iterations logged.""" + return self._iteration_count + + def get_log_path(self) -> str: + """Return the path to the log file.""" + return self.log_file_path diff --git a/contributing/samples/rlm/adk_rlm/logging/verbose.py b/contributing/samples/rlm/adk_rlm/logging/verbose.py new file mode 100644 index 0000000000..e062d90341 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/logging/verbose.py @@ -0,0 +1,403 @@ +""" +Verbose printing for RLM using rich. + +Provides console output for debugging and understanding RLM execution. +Uses a "Tokyo Night" inspired color theme. +""" + +from typing import Any + +from adk_rlm.types import CodeBlock +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata +from rich.console import Console +from rich.console import Group +from rich.panel import Panel +from rich.rule import Rule +from rich.style import Style +from rich.table import Table +from rich.text import Text + +# Tokyo Night Color Theme +COLORS = { + "primary": "#7AA2F7", # Soft blue - headers, titles + "secondary": "#BB9AF7", # Soft purple - emphasis + "success": "#9ECE6A", # Soft green - success, code + "warning": "#E0AF68", # Soft amber - warnings + "error": "#F7768E", # Soft red/pink - errors + "text": "#A9B1D6", # Soft gray-blue - regular text + "muted": "#565F89", # Muted gray - less important + "accent": "#7DCFFF", # Bright cyan - accents + "bg_subtle": "#1A1B26", # Dark background + "border": "#3B4261", # Border color + "code_bg": "#24283B", # Code background +} + +# Rich styles +STYLE_PRIMARY = Style(color=COLORS["primary"], bold=True) +STYLE_SECONDARY = Style(color=COLORS["secondary"]) +STYLE_SUCCESS = Style(color=COLORS["success"]) +STYLE_WARNING = Style(color=COLORS["warning"]) +STYLE_ERROR = Style(color=COLORS["error"]) +STYLE_TEXT = Style(color=COLORS["text"]) +STYLE_MUTED = Style(color=COLORS["muted"]) +STYLE_ACCENT = Style(color=COLORS["accent"], bold=True) + + +def _to_str(value: Any) -> str: + """Convert any value to string safely.""" + if isinstance(value, str): + return value + return str(value) + + +class VerbosePrinter: + """ + Rich console printer for RLM verbose output. 
+ + Displays beautiful, structured output showing the RLM's execution: + - Initial configuration panel + - Each iteration with response summaries + - Code execution with results + - Sub-calls to other models + """ + + def __init__(self, enabled: bool = True): + """ + Initialize the verbose printer. + + Args: + enabled: Whether verbose printing is enabled. + """ + self.enabled = enabled + self.console = Console() if enabled else None + self._iteration_count = 0 + + def print_header( + self, + backend: str, + model: str, + environment: str, + max_iterations: int, + max_depth: int, + other_backends: list[str] | None = None, + ) -> None: + """Print the initial RLM configuration header.""" + if not self.enabled: + return + + # Main title + title = Text() + title.append("◆ ", style=STYLE_ACCENT) + title.append("RLM", style=Style(color=COLORS["primary"], bold=True)) + title.append(" ━ Recursive Language Model (ADK)", style=STYLE_MUTED) + + # Configuration table + config_table = Table( + show_header=False, + show_edge=False, + box=None, + padding=(0, 2), + expand=True, + ) + config_table.add_column("key", style=STYLE_MUTED, width=16) + config_table.add_column("value", style=STYLE_TEXT) + config_table.add_column("key2", style=STYLE_MUTED, width=16) + config_table.add_column("value2", style=STYLE_TEXT) + + config_table.add_row( + "Backend", + Text(backend, style=STYLE_SECONDARY), + "Environment", + Text(environment, style=STYLE_SECONDARY), + ) + config_table.add_row( + "Model", + Text(model, style=STYLE_ACCENT), + "Max Iterations", + Text(str(max_iterations), style=STYLE_WARNING), + ) + + if other_backends: + backends_text = Text(", ".join(other_backends), style=STYLE_SECONDARY) + config_table.add_row( + "Sub-models", + backends_text, + "Max Depth", + Text(str(max_depth), style=STYLE_WARNING), + ) + else: + config_table.add_row( + "Max Depth", + Text(str(max_depth), style=STYLE_WARNING), + "", + "", + ) + + panel = Panel( + config_table, + title=title, + title_align="left", + border_style=COLORS["border"], + padding=(1, 2), + ) + + self.console.print() + self.console.print(panel) + self.console.print() + + def print_metadata(self, metadata: RLMMetadata) -> None: + """Print RLM metadata as header.""" + if not self.enabled: + return + + model = metadata.backend_kwargs.get("model_name", "unknown") + other = list(metadata.other_backends) if metadata.other_backends else None + + self.print_header( + backend=metadata.backend, + model=model, + environment=metadata.environment_type, + max_iterations=metadata.max_iterations, + max_depth=metadata.max_depth, + other_backends=other, + ) + + def print_iteration_start(self, iteration: int) -> None: + """Print the start of a new iteration.""" + if not self.enabled: + return + + self._iteration_count = iteration + + rule = Rule( + Text(f" Iteration {iteration} ", style=STYLE_PRIMARY), + style=COLORS["border"], + characters="─", + ) + self.console.print(rule) + + def print_completion( + self, response: Any, iteration_time: float | None = None + ) -> None: + """Print a completion response.""" + if not self.enabled: + return + + # Header with timing + header = Text() + header.append("◇ ", style=STYLE_ACCENT) + header.append("LLM Response", style=STYLE_PRIMARY) + if iteration_time: + header.append(f" ({iteration_time:.2f}s)", style=STYLE_MUTED) + + # Response content + response_str = _to_str(response) + response_text = Text(response_str, style=STYLE_TEXT) + + # Count words roughly + word_count = len(response_str.split()) + footer = Text(f"~{word_count} words", 
style=STYLE_MUTED) + + panel = Panel( + Group(response_text, Text(), footer), + title=header, + title_align="left", + border_style=COLORS["muted"], + padding=(0, 1), + ) + self.console.print(panel) + + def print_code_execution(self, code_block: CodeBlock) -> None: + """Print code execution details.""" + if not self.enabled: + return + + result = code_block.result + + # Header + header = Text() + header.append("▸ ", style=STYLE_SUCCESS) + header.append( + "Code Execution", style=Style(color=COLORS["success"], bold=True) + ) + if result.execution_time: + header.append(f" ({result.execution_time:.3f}s)", style=STYLE_MUTED) + + # Build content + content_parts = [] + + # Code snippet + code_text = Text() + code_text.append("Code:\n", style=STYLE_MUTED) + code_text.append(_to_str(code_block.code), style=STYLE_TEXT) + content_parts.append(code_text) + + # Stdout if present + stdout_str = _to_str(result.stdout) if result.stdout else "" + if stdout_str.strip(): + stdout_text = Text() + stdout_text.append("\nOutput:\n", style=STYLE_MUTED) + stdout_text.append(stdout_str, style=STYLE_SUCCESS) + content_parts.append(stdout_text) + + # Stderr if present (error) + stderr_str = _to_str(result.stderr) if result.stderr else "" + if stderr_str.strip(): + stderr_text = Text() + stderr_text.append("\nError:\n", style=STYLE_MUTED) + stderr_text.append(stderr_str, style=STYLE_ERROR) + content_parts.append(stderr_text) + + # Sub-calls summary + if result.rlm_calls: + calls_text = Text() + calls_text.append( + f"\n↳ {len(result.rlm_calls)} sub-call(s)", style=STYLE_SECONDARY + ) + content_parts.append(calls_text) + + panel = Panel( + Group(*content_parts), + title=header, + title_align="left", + border_style=COLORS["success"], + padding=(0, 1), + ) + self.console.print(panel) + + def print_subcall( + self, + model: str, + prompt_preview: str, + response_preview: str, + execution_time: float | None = None, + ) -> None: + """Print a sub-call to another model.""" + if not self.enabled: + return + + # Header + header = Text() + header.append(" ↳ ", style=STYLE_SECONDARY) + header.append("Sub-call: ", style=STYLE_SECONDARY) + header.append(_to_str(model), style=STYLE_ACCENT) + if execution_time: + header.append(f" ({execution_time:.2f}s)", style=STYLE_MUTED) + + # Content + content = Text() + content.append("Prompt: ", style=STYLE_MUTED) + content.append(_to_str(prompt_preview)[:200], style=STYLE_TEXT) + if len(prompt_preview) > 200: + content.append("...", style=STYLE_MUTED) + content.append("\nResponse: ", style=STYLE_MUTED) + content.append(_to_str(response_preview)[:200], style=STYLE_TEXT) + if len(response_preview) > 200: + content.append("...", style=STYLE_MUTED) + + panel = Panel( + content, + title=header, + title_align="left", + border_style=COLORS["secondary"], + padding=(0, 1), + ) + self.console.print(panel) + + def print_iteration( + self, iteration: RLMIteration, iteration_num: int + ) -> None: + """ + Print a complete iteration including response and code executions. 
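+        For each code block, the executed code, captured output, and any recorded sub-LLM calls are also printed.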
+ """ + if not self.enabled: + return + + # Print iteration header + self.print_iteration_start(iteration_num) + + # Print the LLM response + self.print_completion(iteration.response, iteration.iteration_time) + + # Print each code block execution + for code_block in iteration.code_blocks: + self.print_code_execution(code_block) + + # Print any sub-calls made during this code block + for call in code_block.result.rlm_calls: + self.print_subcall( + model=call.root_model, + prompt_preview=_to_str(call.prompt) if call.prompt else "", + response_preview=_to_str(call.response) if call.response else "", + execution_time=call.execution_time, + ) + + def print_final_answer(self, answer: Any) -> None: + """Print the final answer.""" + if not self.enabled: + return + + # Title + title = Text() + title.append("★ ", style=STYLE_WARNING) + title.append( + "Final Answer", style=Style(color=COLORS["warning"], bold=True) + ) + + # Answer content + answer_text = Text(_to_str(answer), style=STYLE_TEXT) + + panel = Panel( + answer_text, + title=title, + title_align="left", + border_style=COLORS["warning"], + padding=(1, 2), + ) + + self.console.print() + self.console.print(panel) + self.console.print() + + def print_summary( + self, + total_iterations: int, + total_time: float, + usage_summary: dict[str, Any] | None = None, + ) -> None: + """Print a summary at the end of execution.""" + if not self.enabled: + return + + # Summary table + summary_table = Table( + show_header=False, + show_edge=False, + box=None, + padding=(0, 2), + ) + summary_table.add_column("metric", style=STYLE_MUTED) + summary_table.add_column("value", style=STYLE_ACCENT) + + summary_table.add_row("Iterations", str(total_iterations)) + summary_table.add_row("Total Time", f"{total_time:.2f}s") + + if usage_summary: + total_input = sum( + m.get("total_input_tokens", 0) + for m in usage_summary.get("model_usage_summaries", {}).values() + ) + total_output = sum( + m.get("total_output_tokens", 0) + for m in usage_summary.get("model_usage_summaries", {}).values() + ) + if total_input or total_output: + summary_table.add_row("Input Tokens", f"{total_input:,}") + summary_table.add_row("Output Tokens", f"{total_output:,}") + + self.console.print() + self.console.print(Rule(style=COLORS["border"], characters="═")) + self.console.print(summary_table, justify="center") + self.console.print(Rule(style=COLORS["border"], characters="═")) + self.console.print() diff --git a/contributing/samples/rlm/adk_rlm/main.py b/contributing/samples/rlm/adk_rlm/main.py new file mode 100644 index 0000000000..2ac2f0dfd1 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/main.py @@ -0,0 +1,379 @@ +""" +Main entry point and convenience wrapper for ADK-RLM. + +This module provides the RLM class which is the primary interface +for using Recursive Language Models with ADK framework integration. 
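+It also provides the module-level completion() convenience function for simple synchronous usage.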
+""" + +from pathlib import Path +from typing import Any +from typing import AsyncGenerator +from typing import TYPE_CHECKING + +from adk_rlm.agents.rlm_agent import RLMAgent +from adk_rlm.events import RLMEventType +from adk_rlm.files import FileLoader +from adk_rlm.files import FileParser +from adk_rlm.files import FileSource +from adk_rlm.logging.rlm_logger import RLMLogger +from adk_rlm.types import RLMChatCompletion +from google.adk import Runner +from google.adk.events.event import Event +from google.adk.sessions import InMemorySessionService +from google.genai import types + +if TYPE_CHECKING: + from adk_rlm.files import LazyFileCollection + + +class RLM: + """ + Recursive Language Model - main user-facing class. + + This provides a simple interface to the RLM functionality, using + Google ADK framework for agent execution and session management. + + The primary interface is `run_streaming()` which yields ADK Events + for real-time UI updates. For simple synchronous usage, use the + module-level `completion()` convenience function. + + Example: + ```python + from adk_rlm import RLM, completion + + # Streaming API for real-time UI updates + rlm = RLM(model="gemini-3-pro-preview") + + async for event in rlm.run_streaming(context, prompt): + event_type = event.custom_metadata.get("event_type") + if event_type == "rlm.final.answer": + print(event.custom_metadata["answer"]) + + # Or use the convenience function for simple synchronous usage + result = completion( + context="Your long document here...", + prompt="What are the key themes?", + ) + print(result.response) + ``` + """ + + def __init__( + self, + model: str = "gemini-3-pro-preview", + sub_model: str | None = None, + max_iterations: int = 30, + max_depth: int = 5, + custom_system_prompt: str | None = None, + log_dir: str | None = None, + verbose: bool = False, + persistent: bool = False, + # File handling + file_sources: dict[str, FileSource] | None = None, + file_parsers: list[FileParser] | None = None, + base_path: str | Path | None = None, + # Legacy kwargs for compatibility + backend: str | None = None, + backend_kwargs: dict[str, Any] | None = None, + **kwargs, + ): + """ + Initialize the RLM. + + Args: + model: The main model to use (default: gemini-3-pro-preview). + sub_model: The model for recursive sub-calls (defaults to model). + max_iterations: Maximum number of RLM iterations (default: 30). + max_depth: Maximum recursion depth (default: 5). + custom_system_prompt: Custom system prompt (uses default if None). + log_dir: Directory for JSONL logs (None disables logging). + verbose: Whether to print Rich console output. + persistent: Whether to persist REPL state across calls. + file_sources: Dictionary of named file sources for file loading. + file_parsers: List of file parsers for file loading. + base_path: Base path for local file source. + backend: Legacy parameter (ignored, always uses Gemini). + backend_kwargs: Legacy parameter for backend configuration. 
+ """ + # Handle legacy backend_kwargs + if ( + backend_kwargs + and "model_name" in backend_kwargs + and model == "gemini-3-pro-preview" + ): + model = backend_kwargs["model_name"] + + # Create logger if log_dir specified + logger = RLMLogger(log_dir) if log_dir else None + + # Create the underlying agent + self._agent = RLMAgent( + name="rlm_agent", + model=model, + sub_model=sub_model, + max_iterations=max_iterations, + max_depth=max_depth, + custom_system_prompt=custom_system_prompt, + logger=logger, + verbose=verbose, + persistent=persistent, + ) + + # Create session service for ADK Runner + self._session_service = InMemorySessionService() + + # Create ADK Runner + self._runner = Runner( + app_name="adk_rlm", + agent=self._agent, + session_service=self._session_service, + ) + + # Create file loader for file handling + self._file_loader = FileLoader( + sources=file_sources, + parsers=file_parsers, + base_path=base_path, + ) + + # Store config for reference + self.model = model + self.sub_model = sub_model or model + self.max_iterations = max_iterations + self.max_depth = max_depth + self.persistent = persistent + self.verbose = verbose + self._logger = logger + + async def run_streaming( + self, + context: str | dict | list, + prompt: str | None = None, + conversation_history: list[dict[str, str]] | None = None, + ) -> AsyncGenerator[Event, None]: + """ + Run RLM with streaming events. + + Yields ADK Event objects with custom_metadata containing: + - event_type: The type of event (see RLMEventType) + - Additional event-specific data + + This is the primary interface for building UIs on top of RLM. + + Args: + context: The context to analyze. + prompt: Optional user prompt/question about the context. + conversation_history: Optional list of previous conversation messages. + Each message should have 'role' ('user' or 'assistant') and 'content'. + This enables multi-turn conversations where the agent remembers + previous questions and answers. + + Yields: + ADK Event objects with RLM-specific metadata. 
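+        Prior turns can also be passed explicitly via `conversation_history`
+        (a sketch building on the basic pattern in the Example below; the
+        message contents here are illustrative):
+
+        ```python
+        history = [
+            {"role": "user", "content": "Which section covers authentication?"},
+            {"role": "assistant", "content": "Section 4 covers authentication."},
+        ]
+        async for event in rlm.run_streaming(
+            context, "Summarize that section.", conversation_history=history
+        ):
+            ...
+        ```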
+ + Example: + ```python + async for event in rlm.run_streaming(context, prompt): + event_type = event.custom_metadata.get("event_type") + + if event_type == "rlm.iteration.start": + print(f"Starting iteration {event.custom_metadata['iteration']}") + + elif event_type == "rlm.code.end": + if event.custom_metadata.get("output"): + print(event.custom_metadata["output"]) + + elif event_type == "rlm.final.answer": + print(f"Final: {event.custom_metadata['answer']}") + ``` + """ + # Create session with context in state + session = await self._session_service.create_session( + app_name="adk_rlm", + user_id="default_user", + state={ + "rlm_context": context, + "rlm_prompt": prompt, + "rlm_conversation_history": conversation_history, + }, + ) + + # Build user message (the agent reads from session state) + message = types.Content( + role="user", parts=[types.Part(text=prompt or "Analyze the context.")] + ) + + # Run agent and yield events + async for event in self._runner.run_async( + user_id="default_user", + session_id=session.id, + new_message=message, + ): + yield event + + def close(self) -> None: + """Clean up resources (call when done with persistent mode).""" + self._agent.close() + + def __enter__(self) -> "RLM": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + self.close() + return False + + @property + def log_path(self) -> str | None: + """Return the path to the log file if logging is enabled.""" + return self._logger.get_log_path() if self._logger else None + + @property + def file_loader(self) -> FileLoader: + """Access the file loader for direct file operations.""" + return self._file_loader + + @property + def agent(self) -> RLMAgent: + """Access the underlying RLM agent.""" + return self._agent + + @property + def runner(self) -> Runner: + """Access the ADK Runner for advanced usage.""" + return self._runner + + def load_files( + self, files: list[str], lazy: bool = True + ) -> "LazyFileCollection | list": + """ + Load files without running RLM. + + Convenience method for loading files directly. + + Args: + files: List of file paths/URIs/globs + lazy: If True, return LazyFileCollection. If False, return parsed content. + + Returns: + LazyFileCollection if lazy=True, else list of ParsedContent + """ + if lazy: + return self._file_loader.create_lazy_files(files) + else: + return self._file_loader.load_files(files) + + +def completion( + context: str | dict | list | None = None, + prompt: str | None = None, + *, + files: list[str] | None = None, + model: str = "gemini-3-pro-preview", + sub_model: str | None = None, + max_iterations: int = 30, + max_depth: int = 5, + log_dir: str | None = None, + verbose: bool = False, +) -> RLMChatCompletion: + """ + Convenience function for simple synchronous RLM completion. + + This creates a temporary RLM instance, runs `run_streaming()`, and + collects the final answer. For more control, use the RLM class directly. + + Args: + context: The context/data to analyze. + prompt: Optional user prompt/question about the context. + files: List of file paths/URIs/globs to load as context. + model: The main model to use (default: gemini-3-pro-preview). + sub_model: The model for recursive sub-calls (defaults to model). + max_iterations: Maximum number of RLM iterations (default: 30). + max_depth: Maximum recursion depth (default: 5). + log_dir: Directory for JSONL logs (None disables logging). + verbose: Whether to print Rich console output. + + Returns: + RLMChatCompletion with response and metadata. 
+ + Example: + ```python + from adk_rlm import completion + + # Simple usage + result = completion( + context="Your document here...", + prompt="Summarize the key points", + ) + print(result.response) + + # With files + result = completion( + files=["./docs/**/*.md"], + prompt="What are the main themes?", + ) + ``` + """ + import asyncio + import time + + time_start = time.perf_counter() + + # Create RLM instance + rlm = RLM( + model=model, + sub_model=sub_model, + max_iterations=max_iterations, + max_depth=max_depth, + log_dir=log_dir, + verbose=verbose, + ) + + # Build context from files if provided + if files: + file_context = rlm.file_loader.build_context(files, lazy=True) + if context is not None: + ctx = _merge_context(context, file_context) + else: + ctx = file_context + else: + if context is None: + raise ValueError("Either 'context' or 'files' must be provided") + ctx = context + + # Run streaming and collect final answer + async def _run(): + final_answer = None + async for event in rlm.run_streaming(ctx, prompt): + if event.custom_metadata: + event_type = event.custom_metadata.get("event_type") + if event_type == RLMEventType.FINAL_ANSWER.value: + final_answer = event.custom_metadata.get("answer") + return final_answer + + try: + final_answer = asyncio.run(_run()) + finally: + rlm.close() + + time_end = time.perf_counter() + + return RLMChatCompletion( + root_model=model, + prompt=context or str(files), + response=final_answer or "", + usage_summary=None, + execution_time=time_end - time_start, + ) + + +def _merge_context( + context: str | dict | list, + file_context: dict, +) -> dict: + """Merge direct context with file context.""" + if isinstance(context, str): + return {"user_context": context, **file_context} + elif isinstance(context, dict): + return {**context, **file_context} + else: + return {"user_context": context, **file_context} diff --git a/contributing/samples/rlm/adk_rlm/prompts.py b/contributing/samples/rlm/adk_rlm/prompts.py new file mode 100644 index 0000000000..e5541e9eb9 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/prompts.py @@ -0,0 +1,346 @@ +""" +System and user prompts for ADK-RLM. + +These prompts guide the model's behavior in the REPL environment. +""" + +import textwrap + +from adk_rlm.types import QueryMetadata + +# System prompt for the REPL environment +RLM_SYSTEM_PROMPT = textwrap.dedent( + """You are tasked with answering a query with associated context. You can access, transform, and analyze this context interactively in a REPL environment that can recursively query sub-LLMs, which you are strongly encouraged to use as much as possible. You will be queried iteratively until you provide a final answer. + +The REPL environment is initialized with: +1. A `context` variable that contains extremely important information about your query. You should check the content of the `context` variable to understand what you are working with. Make sure you look through it sufficiently as you answer your query. +2. A `llm_query(prompt, context=None, model=None, recursive=True)` function that allows you to query an LLM inside your REPL environment. By default (recursive=True), this creates a nested RLM execution that can itself execute code and make further llm_query calls - enabling deep recursive reasoning. Set recursive=False for simple LLM calls without code execution. The optional `context` parameter lets you pass objects (files, collections, dicts) directly to the child agent - the child receives this as its `context` variable. +3. 
A `llm_query_batched(prompts, contexts=None, model=None, recursive=False)` function for concurrent queries. IMPORTANT: Keep recursive=False (the default) for extraction, summarization, and Q&A tasks - embed file.content in your prompts. Only use recursive=True when children genuinely need to execute code. Results are returned in the same order as prompts. +4. The ability to use `print()` statements to view the output of your REPL code and continue your reasoning. + +IMPORTANT: The `llm_query` function with recursive=True creates a full RLM execution at the next depth level. This means the sub-LLM can analyze your prompt, write and execute code, and recursively call its own llm_query. This is powerful for hierarchical decomposition of complex problems. + +WHEN TO USE CHILD AGENTS (recursive=True): +- Complex sub-analyses: When a sub-problem requires its own multi-step reasoning with code execution +- Hierarchical decomposition: When breaking a large problem into sub-problems that each need iteration +- Delegation with large context: When passing substantial context to a child that needs to explore it programmatically + +WHEN TO USE SIMPLE LLM CALLS (recursive=False): +- Parallel processing: Use `llm_query_batched` with recursive=False to process many chunks concurrently +- Summarization: Condensing text into a shorter form +- Extraction: Pulling specific information from provided text +- Classification: Categorizing or labeling content +- Simple Q&A: Questions answerable directly from the provided context without code execution +- Aggregation: Combining multiple answers into a final result + +Most tasks should use recursive=False. Only use recursive=True when the sub-task itself requires writing and executing code to explore the problem. + +WHEN TO AVOID llm_query ENTIRELY: +- Direct lookups: When you can find the answer by reading the context directly +- Small contexts: When the context fits easily in your window and doesn't need chunking +- Simple computations: When Python code alone can compute the answer +- Pattern matching: When regex or string operations can extract what you need + +IMPORTANT: Default to the simplest approach. Only spawn child agents when the sub-task genuinely requires autonomous multi-step reasoning with code execution. + +IMPORTANT: Don't turn off your brain. Just because you can spawn a child agent doesn't mean you should. And if you do spawn a child agent, make sure you give it a good prompt and context. + +WORKING WITH FILES (LazyFile / LazyFileCollection): +When your context contains files, you have two approaches: + +APPROACH 1 - Simple extraction (recursive=False, PREFERRED for most tasks): +Embed file.content directly in your prompt. This is fast and efficient for summarization, extraction, and Q&A. +Make sure to properly format the string! +```repl +files = list(context['files']) +prompts = [f"Summarize this document titled '{{f.name}}':\\n\\n{{f.content}}" for f in files] +results = llm_query_batched(prompts, recursive=False) # Fast parallel processing +``` + +APPROACH 2 - Complex analysis (recursive=True, use sparingly): +Pass the file object via context= when the child needs to write code to explore the file.
+```repl +# Only use this when the child genuinely needs to run code +result = llm_query("Analyze this document programmatically", context=file, recursive=True) +``` + +KEY FILE PROPERTIES: +- `file.name` - the filename (string, e.g., "report.md") +- `file.content` - the actual text content (string, loads the file) + +CRITICAL WARNING - Common Mistake: +- WRONG: Using file.name in prompts without file.content (passes just the filename string!) +- WRONG: contexts=[f.name for f in files] (passes list of filename strings!) +- CORRECT: Embed file.content in prompt, OR pass file object via context= + +You will only be able to see truncated outputs from the REPL environment, so you should use the query LLM function on variables you want to analyze. You will find this function especially useful when you have to analyze the semantics of the context. Use these variables as buffers to build up your final answer. +Make sure to explicitly look through the entire context in REPL before answering your query. An example strategy is to first look at the context and figure out a chunking strategy, then break up the context into smart chunks, and query an LLM per chunk with a particular question and save the answers to a buffer, then query an LLM with all the buffers to produce your final answer. + +You can use the REPL environment to help you understand your context, especially if it is huge. Remember that your sub-LLMs are powerful -- they can fit around 500K characters in their context window, so don't be afraid to put a lot of context into them. For example, a viable strategy is to feed 10 documents per sub-LLM query. Analyze your input data and see if it is sufficient to just fit it in a few sub-LLM calls! + +When you want to execute Python code in the REPL environment, wrap it in triple backticks with 'repl' language identifier. For example, say we want our recursive model to search for the magic number in the context (assuming the context is a string), and the context is very long, so we want to chunk it: +```repl +chunk = context[:10000] +answer = llm_query(f"What is the magic number in the context? Here is the chunk: {{chunk}}") +print(answer) +``` + + +As an example, when the context isn't that long (e.g. <100M characters), a simple but viable strategy is, based on the context chunk lengths, to combine them and query an LLM over chunks. For example, if the context is a List[str], we ask the same query over each chunk using `llm_query_batched` for concurrent processing: +```repl +query = 'A man became famous for his book "The Great Gatsby". How many jobs did he have?' +# Suppose our context is ~1M chars, and we want each sub-LLM query to be ~0.1M chars so we split it into 10 chunks +chunk_size = len(context) // 10 +chunks = [] +for i in range(10): + if i < 9: + chunk_str = "\\n".join(context[i*chunk_size:(i+1)*chunk_size]) + else: + chunk_str = "\\n".join(context[i*chunk_size:]) + chunks.append(chunk_str) + +# Use batched query for concurrent processing - much faster than sequential calls! +prompts = [f"Try to answer the following query: {{query}}. Here are the documents:\\n{{chunk}}. Only answer if you are confident in your answer based on the evidence."
for chunk in chunks] +answers = llm_query_batched(prompts, recursive=False) +for i, answer in enumerate(answers): + print(f"I got the answer from chunk {{i}}: {{answer}}") +final_answer = llm_query(f"Aggregating all the answers per chunk, answer the original query about total number of jobs: {{query}}\\n\\nAnswers:\\n" + "\\n".join(answers)) +``` + +As another example, after analyzing the context and realizing its separated by Markdown headers, we can maintain state through buffers by chunking the context by headers, and iteratively querying an LLM over it: +```repl +# After finding out the context is separated by Markdown headers, we can chunk, summarize, and answer +import re +sections = re.split(r'### (.+)', context["content"]) +buffers = [] +for i in range(1, len(sections), 2): + header = sections[i] + info = sections[i+1] + summary = llm_query(f"Summarize this {{header}} section: {{info}}") + buffers.append(f"{{header}}: {{summary}}") +final_answer = llm_query(f"Based on these summaries, answer the original query: {{query}}\\n\\nSummaries:\\n" + "\\n".join(buffers)) +``` +In the next step, we can return FINAL_VAR(final_answer). + +RECURSIVE CHILD AGENTS - When They Actually Help: +Use recursive=True ONLY when the child must DISCOVER its approach by examining the data - not when you can specify the extraction upfront. + +WRONG approach (can be flattened to recursive=False): +```repl +# DON'T DO THIS - the extraction task is fully specified upfront +result = llm_query( + "Extract all API endpoints and their auth status from this file", + context=file, + recursive=True # WASTEFUL - child doesn't need to write code +) +``` + +RIGHT approach - when child must explore to find its approach: +Task: "These 50 config files should follow a common schema. Find schema violations and explain each one." + +```repl +# STEP 1: Sample the data to understand structure +files = list(context['files']) +samples = [f.content[:2000] for f in files[:5]] +print(f"Examining {len(files)} config files") +print(f"Sample structure:\n{samples[0][:500]}") +``` + +```repl +# STEP 2: Delegate schema inference to a child that must EXPLORE +# The child doesn't know the schema - it must: +# - Read multiple files to infer common patterns +# - Write code to parse and compare structures +# - Iteratively refine its understanding +schema_analysis = llm_query( + "Infer the common schema from these config files. Write code to: " + "1) Parse each file's structure, 2) Find fields that appear in >80% of files, " + "3) Detect type patterns (string vs number vs array). Return the inferred schema.", + context={'files': files}, + recursive=True # Child must iterate: parse → compare → refine +) +print(f"Inferred schema: {schema_analysis[:500]}") +``` + +```repl +# STEP 3: Delegate violation detection - child must write validation code +# Each file may violate differently - child needs to: +# - Apply the schema programmatically +# - Trace WHY each violation occurred +# - Handle edge cases it discovers +violations = llm_query( + f"Using this schema:\n{schema_analysis}\n\n" + "Validate each file and explain violations. 
Write code to check each field " + "and trace the cause of mismatches.", + context={'files': files}, + recursive=True # Child writes custom validation logic +) +print(violations) +``` + +Why recursive=True was needed: +- Schema wasn't known upfront - child had to DISCOVER it by examining files +- Validation logic depended on what schema was found +- Each violation type needed different investigation code +- Children maintained state across iterations (schema → validation → explanation) + +The test: If you can write the full extraction prompt without seeing the data, +use recursive=False. If the child must look at data to decide what to do, use recursive=True + +COMPLETE FILE PROCESSING WORKFLOW: +When analyzing multiple files, follow this pattern to avoid iteration explosion: + +```repl +# STEP 1: Examine what you have +files = list(context['files']) +print(f"Processing {{len(files)}} files: {{[f.name for f in files[:5]]}}...") + +# STEP 2: Process files in parallel with recursive=False (IMPORTANT!) +# Embed file.content in prompts - this is fast and avoids spawning recursive agents +prompts = [ + f"Analyze '{{f.name}}' for key themes, risks, and opportunities:\\n\\n{{f.content}}" + for f in files +] +results = llm_query_batched(prompts, recursive=False) # NOT recursive=True! + +# STEP 3: Collect results +for f, result in zip(files, results): + print(f"{{f.name}}: {{result[:200]}}...") + +# STEP 4: Aggregate AT THIS LEVEL - do not delegate aggregation! +combined = "\\n\\n".join(f"## {{f.name}}\\n{{r}}" for f, r in zip(files, results)) +final = llm_query( + f"Synthesize these analyses into a comprehensive answer:\\n\\n{{combined}}", + recursive=False +) +print(final) +``` + +CRITICAL - YOU MUST AGGREGATE: +After spawning child queries, YOU must collect and synthesize their results. +Do NOT expect children to produce your final answer - they return to you, not to the user. +Always end with FINAL() or FINAL_VAR() after aggregating. + +CONTEXT VALIDATION (if you are a child agent): +If your context seems unexpectedly small or looks like filenames instead of content, something went wrong. +```repl +# First thing: check what you actually received +print(f"Context type: {{type(context)}}") +print(f"Context length: {{len(str(context))}}") +if isinstance(context, str) and len(context) < 500: + print(f"WARNING: Context is very small - may be filename instead of content: {{context}}") +``` +If you only received a filename string (like "document.md") instead of actual content, inform the parent that you cannot proceed without the file content. + +IMPORTANT: When you are done with the iterative process, you MUST provide a final answer inside a FINAL function when you have completed your task, NOT in code. Do not use these tags unless you have completed your task. You have two options: +1. Use FINAL(your final answer here) to provide the answer directly +2. Use FINAL_VAR(variable_name) to return a variable you have created in the REPL environment as your final output + +Think step by step carefully, plan, and execute this plan immediately in your response -- do not just say "I will do this" or "I will do that". Output to the REPL environment and recursive LLMs as much as possible. Remember to explicitly answer the original query in your final answer. 
+""" +) + +USER_PROMPT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the prompt.\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:""" + +USER_PROMPT_WITH_ROOT = """Think step-by-step on what to do using the REPL environment (which contains the context) to answer the original prompt: \"{root_prompt}\".\n\nContinue using the REPL environment, which has the `context` variable, and querying sub-LLMs by writing to ```repl``` tags, and determine your answer. Your next action:""" + + +def build_rlm_system_prompt( + system_prompt: str, + query_metadata: QueryMetadata, +) -> list[dict[str, str]]: + """ + Build the initial system prompt for the REPL environment. + + Args: + system_prompt: The base system prompt. + query_metadata: Metadata about the query context. + + Returns: + List of message dictionaries. + """ + context_lengths = query_metadata.context_lengths + context_total_length = query_metadata.context_total_length + context_type = query_metadata.context_type + + # Truncate if too many chunks + if len(context_lengths) > 100: + others = len(context_lengths) - 100 + context_lengths_str = str(context_lengths[:100]) + f"... [{others} others]" + else: + context_lengths_str = str(context_lengths) + + metadata_prompt = ( + f"Your context is a {context_type} with {context_total_length} total" + " characters, and is broken up into chunks of char lengths:" + f" {context_lengths_str}." + ) + + return [ + {"role": "system", "content": system_prompt}, + {"role": "assistant", "content": metadata_prompt}, + ] + + +def build_user_prompt( + root_prompt: str | None = None, + iteration: int = 0, + context_count: int = 1, + history_count: int = 0, +) -> dict[str, str]: + """ + Build the user prompt for an iteration. + + Args: + root_prompt: Optional root prompt from the user. + iteration: Current iteration number. + context_count: Number of contexts loaded. + history_count: Number of conversation histories stored. + + Returns: + A message dictionary with the user prompt. + """ + if iteration == 0: + safeguard = ( + "You have not interacted with the REPL environment or seen your prompt" + " / context yet. Your next action should be to look through and figure" + " out how to answer the prompt, so don't just provide a final answer" + " yet.\n\n" + ) + prompt = safeguard + ( + USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) + if root_prompt + else USER_PROMPT + ) + else: + prompt = ( + "The history before is your previous interactions with the REPL" + " environment. " + + ( + USER_PROMPT_WITH_ROOT.format(root_prompt=root_prompt) + if root_prompt + else USER_PROMPT + ) + ) + + # Inform model about multiple contexts if present + if context_count > 1: + prompt += ( + f"\n\nNote: You have {context_count} contexts available (context_0" + f" through context_{context_count - 1})." + ) + + # Inform model about prior conversation histories if present + if history_count > 0: + if history_count == 1: + prompt += ( + "\n\nNote: You have 1 prior conversation history available in the" + " `history` variable." + ) + else: + prompt += ( + f"\n\nNote: You have {history_count} prior conversation histories" + f" available (history_0 through history_{history_count - 1})." 
+ ) + + return {"role": "user", "content": prompt} diff --git a/contributing/samples/rlm/adk_rlm/repl/__init__.py b/contributing/samples/rlm/adk_rlm/repl/__init__.py new file mode 100644 index 0000000000..11b2f0aea6 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/repl/__init__.py @@ -0,0 +1,6 @@ +"""REPL environment implementations.""" + +from adk_rlm.repl.local_repl import LocalREPL +from adk_rlm.repl.safe_builtins import SAFE_BUILTINS + +__all__ = ["LocalREPL", "SAFE_BUILTINS"] diff --git a/contributing/samples/rlm/adk_rlm/repl/local_repl.py b/contributing/samples/rlm/adk_rlm/repl/local_repl.py new file mode 100644 index 0000000000..ad01971750 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/repl/local_repl.py @@ -0,0 +1,346 @@ +""" +Local REPL environment for ADK-RLM. + +This module provides a sandboxed Python REPL environment that can +execute code with access to context data and LLM query functions. +""" + +from collections.abc import Callable +from contextlib import contextmanager +import copy +import io +import json +import os +import shutil +import sys +import tempfile +import threading +import time +from typing import Any +import uuid + +from adk_rlm.repl.safe_builtins import SAFE_BUILTINS +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMChatCompletion + + +class LocalREPL: + """ + Local REPL environment with persistent Python namespace. + Executes code in a sandboxed namespace with access to context data. + """ + + def __init__( + self, + llm_query_fn: Callable[[str, str | None], str] | None = None, + llm_query_batched_fn: ( + Callable[[list[str], str | None], list[str]] | None + ) = None, + context_payload: dict | list | str | None = None, + setup_code: str | None = None, + ): + """ + Initialize the LocalREPL environment. + + Args: + llm_query_fn: Function to query an LLM with a prompt. + llm_query_batched_fn: Function to query an LLM with multiple prompts. + context_payload: Initial context to load into the environment. + setup_code: Optional Python code to execute during setup. 
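+        Example (a minimal sketch; the context string is illustrative):
+            repl = LocalREPL(context_payload="a very long document ...")
+            result = repl.execute_code("print(len(context))")
+            print(result.stdout)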
+ """ + self.llm_query_fn = llm_query_fn + self.llm_query_batched_fn = llm_query_batched_fn + + self.original_cwd = os.getcwd() + self.temp_dir = tempfile.mkdtemp(prefix=f"repl_env_{uuid.uuid4()}_") + self._lock = threading.Lock() + self._context_count: int = 0 + self._history_count: int = 0 + + # Track LLM calls made during code execution + self._pending_llm_calls: list[RLMChatCompletion] = [] + + # Setup globals, locals + self._setup() + + # Load context if provided + if context_payload is not None: + self.load_context(context_payload) + + # Run setup code if provided + if setup_code: + self.execute_code(setup_code) + + def _setup(self) -> None: + """Setup the environment with sandboxed globals and helper functions.""" + self.globals: dict[str, Any] = { + "__builtins__": SAFE_BUILTINS.copy(), + "__name__": "__main__", + } + self.locals: dict[str, Any] = {} + + # Add helper functions + self.globals["FINAL_VAR"] = self._final_var + self.globals["llm_query"] = self._llm_query + self.globals["llm_query_batched"] = self._llm_query_batched + + def _final_var(self, variable_name: str) -> str: + """Return the value of a variable as a final answer.""" + variable_name = variable_name.strip().strip("\"'") + if variable_name in self.locals: + return str(self.locals[variable_name]) + return f"Error: Variable '{variable_name}' not found" + + def _llm_query( + self, + prompt: str, + context: Any = None, + model: str | None = None, + recursive: bool = True, + ) -> str: + """Query the LLM with a single prompt. + + Args: + prompt: The prompt to send to the LLM. + context: Optional context object to pass to the child agent. + model: Optional model override. + recursive: If True and depth allows, use recursive RLM execution. + """ + if self.llm_query_fn is None: + return "Error: No LLM query function configured" + + try: + result = self.llm_query_fn( + prompt, context=context, model=model, recursive=recursive + ) + + # Track this LLM call if it returns an RLMChatCompletion + if isinstance(result, tuple) and len(result) == 2: + response, completion = result + if isinstance(completion, RLMChatCompletion): + self._pending_llm_calls.append(completion) + return response + + return result + except Exception as e: + return f"Error: LLM query failed - {e}" + + def _llm_query_batched( + self, + prompts: list[str], + contexts: list[Any] | None = None, + model: str | None = None, + recursive: bool = False, + ) -> list[str]: + """Query the LLM with multiple prompts concurrently. + + Args: + prompts: List of prompts to send. + contexts: Optional list of context objects (same length as prompts). + model: Optional model override. + recursive: If True, use recursive RLM execution for each prompt. + Default is False for performance. 
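+
+        Returns:
+            A list of response strings, in the same order as `prompts`.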
+ """ + if self.llm_query_batched_fn is None: + if contexts is not None: + return [ + self._llm_query(p, context=c, model=model, recursive=recursive) + for p, c in zip(prompts, contexts) + ] + return [ + self._llm_query(p, model=model, recursive=recursive) for p in prompts + ] + + try: + results = self.llm_query_batched_fn( + prompts, contexts=contexts, model=model, recursive=recursive + ) + + # Handle tuple results with completions + if results and isinstance(results[0], tuple): + responses = [] + for item in results: + if len(item) == 2: + response, completion = item + if isinstance(completion, RLMChatCompletion): + self._pending_llm_calls.append(completion) + responses.append(response) + else: + responses.append(str(item)) + return responses + + return results + except Exception as e: + return [f"Error: LLM query failed - {e}"] * len(prompts) + + def load_context(self, context_payload: dict | list | str) -> None: + """Load context into the environment as context_0 (and 'context' alias).""" + self.add_context(context_payload, 0) + + def add_context( + self, context_payload: dict | list | str, context_index: int | None = None + ) -> int: + """ + Add a context with versioned variable name. + + Args: + context_payload: The context data to add. + context_index: Optional explicit index. If None, auto-increments. + + Returns: + The context index used. + """ + if context_index is None: + context_index = self._context_count + + var_name = f"context_{context_index}" + + if isinstance(context_payload, str): + context_path = os.path.join(self.temp_dir, f"context_{context_index}.txt") + with open(context_path, "w") as f: + f.write(context_payload) + self.execute_code( + f"with open(r'{context_path}', 'r') as f:\n {var_name} = f.read()" + ) + else: + # Try JSON serialization first for simple data structures + try: + context_path = os.path.join( + self.temp_dir, f"context_{context_index}.json" + ) + with open(context_path, "w") as f: + json.dump(context_payload, f) + self.execute_code( + f"import json\nwith open(r'{context_path}', 'r') as f:\n " + f" {var_name} = json.load(f)" + ) + except (TypeError, ValueError): + # For complex objects (e.g., LazyFileCollection), inject directly + self.locals[var_name] = context_payload + + # Alias context_0 as 'context' for backward compatibility + if context_index == 0: + if var_name in self.locals: + self.locals["context"] = self.locals[var_name] + else: + self.execute_code(f"context = {var_name}") + + self._context_count = max(self._context_count, context_index + 1) + return context_index + + def get_context_count(self) -> int: + """Return the number of contexts loaded.""" + return self._context_count + + def add_history( + self, + message_history: list[dict[str, Any]], + history_index: int | None = None, + ) -> int: + """ + Store a conversation's message history as a versioned variable. + + Args: + message_history: The list of message dicts from a completion call. + history_index: Optional explicit index. If None, auto-increments. + + Returns: + The history index used. 
+ """ + if history_index is None: + history_index = self._history_count + + var_name = f"history_{history_index}" + + # Store deep copy to avoid reference issues + self.locals[var_name] = copy.deepcopy(message_history) + + # Alias history_0 as 'history' for convenience + if history_index == 0: + self.locals["history"] = self.locals[var_name] + + self._history_count = max(self._history_count, history_index + 1) + return history_index + + def get_history_count(self) -> int: + """Return the number of conversation histories stored.""" + return self._history_count + + @contextmanager + def _capture_output(self): + """Thread-safe context manager to capture stdout/stderr.""" + with self._lock: + old_stdout, old_stderr = sys.stdout, sys.stderr + stdout_buf, stderr_buf = io.StringIO(), io.StringIO() + try: + sys.stdout, sys.stderr = stdout_buf, stderr_buf + yield stdout_buf, stderr_buf + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + @contextmanager + def _temp_cwd(self): + """Temporarily change to temp directory for execution.""" + old_cwd = os.getcwd() + try: + os.chdir(self.temp_dir) + yield + finally: + os.chdir(old_cwd) + + def execute_code(self, code: str) -> REPLResult: + """Execute code in the persistent namespace and return result.""" + start_time = time.perf_counter() + + # Clear pending LLM calls from previous execution + self._pending_llm_calls = [] + + with self._capture_output() as (stdout_buf, stderr_buf), self._temp_cwd(): + try: + combined = {**self.globals, **self.locals} + exec(code, combined, combined) + + # Update locals with new variables + for key, value in combined.items(): + if key not in self.globals and not key.startswith("_"): + self.locals[key] = value + + stdout = stdout_buf.getvalue() + stderr = stderr_buf.getvalue() + except Exception as e: + stdout = stdout_buf.getvalue() + stderr = stderr_buf.getvalue() + f"\n{type(e).__name__}: {e}" + + return REPLResult( + stdout=stdout, + stderr=stderr, + locals=self.locals.copy(), + execution_time=time.perf_counter() - start_time, + rlm_calls=self._pending_llm_calls.copy(), + ) + + def reset(self) -> None: + """Reset the REPL environment to initial state.""" + self._setup() + self._context_count = 0 + self._history_count = 0 + self._pending_llm_calls = [] + + def cleanup(self) -> None: + """Clean up temp directory and reset state.""" + try: + shutil.rmtree(self.temp_dir) + except Exception: + pass + self.globals.clear() + self.locals.clear() + + def __enter__(self) -> "LocalREPL": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> bool: + self.cleanup() + return False + + def __del__(self) -> None: + self.cleanup() diff --git a/contributing/samples/rlm/adk_rlm/repl/safe_builtins.py b/contributing/samples/rlm/adk_rlm/repl/safe_builtins.py new file mode 100644 index 0000000000..cdc4a87d38 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/repl/safe_builtins.py @@ -0,0 +1,111 @@ +""" +Safe builtins for the REPL environment. + +This module provides a sandboxed set of Python builtins that blocks +dangerous operations like eval, exec, and input while allowing +standard utility functions. 
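+Blocked names are mapped to None, so attempting to call them raises a clear TypeError at runtime.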
+""" + +SAFE_BUILTINS: dict = { + # Core types and functions + "print": print, + "len": len, + "str": str, + "int": int, + "float": float, + "list": list, + "dict": dict, + "set": set, + "tuple": tuple, + "bool": bool, + "type": type, + "isinstance": isinstance, + "issubclass": issubclass, + # Iteration + "enumerate": enumerate, + "zip": zip, + "map": map, + "filter": filter, + "sorted": sorted, + "reversed": reversed, + "range": range, + # Math + "min": min, + "max": max, + "sum": sum, + "abs": abs, + "round": round, + "any": any, + "all": all, + "pow": pow, + "divmod": divmod, + # String + "chr": chr, + "ord": ord, + "hex": hex, + "bin": bin, + "oct": oct, + "repr": repr, + "ascii": ascii, + "format": format, + # Object + "hash": hash, + "id": id, + "iter": iter, + "next": next, + "slice": slice, + "callable": callable, + "hasattr": hasattr, + "getattr": getattr, + "setattr": setattr, + "delattr": delattr, + "dir": dir, + "vars": vars, + # Types + "bytes": bytes, + "bytearray": bytearray, + "memoryview": memoryview, + "complex": complex, + "object": object, + "super": super, + "property": property, + "staticmethod": staticmethod, + "classmethod": classmethod, + # Imports (controlled) + "__import__": __import__, + "open": open, + # Exceptions + "Exception": Exception, + "BaseException": BaseException, + "ValueError": ValueError, + "TypeError": TypeError, + "KeyError": KeyError, + "IndexError": IndexError, + "AttributeError": AttributeError, + "FileNotFoundError": FileNotFoundError, + "OSError": OSError, + "IOError": IOError, + "RuntimeError": RuntimeError, + "NameError": NameError, + "ImportError": ImportError, + "StopIteration": StopIteration, + "AssertionError": AssertionError, + "NotImplementedError": NotImplementedError, + "ArithmeticError": ArithmeticError, + "LookupError": LookupError, + "Warning": Warning, + "ZeroDivisionError": ZeroDivisionError, + "OverflowError": OverflowError, + "FloatingPointError": FloatingPointError, + "KeyboardInterrupt": KeyboardInterrupt, + "SystemExit": SystemExit, + # BLOCKED - set to None to raise clear errors + "input": None, + "eval": None, + "exec": None, + "compile": None, + "globals": None, + "locals": None, + "breakpoint": None, + "__debug__": True, +} diff --git a/contributing/samples/rlm/adk_rlm/templates/index.html b/contributing/samples/rlm/adk_rlm/templates/index.html new file mode 100644 index 0000000000..6c7421cf8f --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/templates/index.html @@ -0,0 +1,2234 @@ + + + + + + ADK-RLM + + + + + + +
+
+ + + Connecting... +
+
+ + arrow_forward + Set context sources + + +
+
+ + + +
+ +
+ +
+
No sessions yet
+
+
+ +
+ + + +
+
+ psychology +

Recursive Language Model

+

Ask a question to start. The RLM will iteratively reason
through the problem using code execution.

+
+
+ + + +
+
+
+ +
+ +
+
+
+ +
+
+ timeline + Event Log + 0 events + +
+
+
+ hourglass_empty +

Events will appear here
as the RLM processes

+
+
+
+
+ + + + + + + + + + diff --git a/contributing/samples/rlm/adk_rlm/tools/__init__.py b/contributing/samples/rlm/adk_rlm/tools/__init__.py new file mode 100644 index 0000000000..676fe037c2 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/tools/__init__.py @@ -0,0 +1 @@ +"""Tools for ADK-RLM.""" diff --git a/contributing/samples/rlm/adk_rlm/types.py b/contributing/samples/rlm/adk_rlm/types.py new file mode 100644 index 0000000000..e3911369db --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/types.py @@ -0,0 +1,344 @@ +""" +Data types for ADK-RLM. + +These types are designed to match the original RLM implementation for +compatibility with the visualizer. +""" + +from dataclasses import dataclass +from dataclasses import field +from types import ModuleType +from typing import Any + + +def _serialize_value(value: Any) -> Any: + """Convert a value to a JSON-serializable representation.""" + if value is None or isinstance(value, (bool, int, float, str)): + return value + if isinstance(value, ModuleType): + return f"" + if isinstance(value, (list, tuple)): + return [_serialize_value(v) for v in value] + if isinstance(value, dict): + return {str(k): _serialize_value(v) for k, v in value.items()} + if callable(value): + return ( + f"<{type(value).__name__} '{getattr(value, '__name__', repr(value))}'>" + ) + try: + return repr(value) + except Exception: + return f"<{type(value).__name__}>" + + +@dataclass +class ModelUsageSummary: + """Usage summary for a single model.""" + + total_calls: int + total_input_tokens: int + total_output_tokens: int + + def to_dict(self) -> dict[str, Any]: + return { + "total_calls": self.total_calls, + "total_input_tokens": self.total_input_tokens, + "total_output_tokens": self.total_output_tokens, + } + + @classmethod + def from_dict(cls, data: dict) -> "ModelUsageSummary": + return cls( + total_calls=data.get("total_calls", 0), + total_input_tokens=data.get("total_input_tokens", 0), + total_output_tokens=data.get("total_output_tokens", 0), + ) + + +@dataclass +class UsageSummary: + """Aggregated usage summary across all models.""" + + model_usage_summaries: dict[str, ModelUsageSummary] = field( + default_factory=dict + ) + + def to_dict(self) -> dict[str, Any]: + return { + "model_usage_summaries": { + model: usage.to_dict() + for model, usage in self.model_usage_summaries.items() + }, + } + + @classmethod + def from_dict(cls, data: dict) -> "UsageSummary": + return cls( + model_usage_summaries={ + model: ModelUsageSummary.from_dict(usage) + for model, usage in data.get("model_usage_summaries", {}).items() + }, + ) + + @property + def total_calls(self) -> int: + return sum(m.total_calls for m in self.model_usage_summaries.values()) + + @property + def total_input_tokens(self) -> int: + return sum( + m.total_input_tokens for m in self.model_usage_summaries.values() + ) + + @property + def total_output_tokens(self) -> int: + return sum( + m.total_output_tokens for m in self.model_usage_summaries.values() + ) + + +@dataclass +class RLMChatCompletion: + """Record of a single LLM call made from within the environment.""" + + root_model: str + prompt: str | dict[str, Any] + response: str + usage_summary: UsageSummary + execution_time: float + + def to_dict(self) -> dict[str, Any]: + return { + "root_model": self.root_model, + "prompt": self.prompt, + "response": self.response, + "usage_summary": self.usage_summary.to_dict(), + "execution_time": self.execution_time, + } + + @classmethod + def from_dict(cls, data: dict) -> "RLMChatCompletion": + return cls( + 
root_model=data.get("root_model", ""), + prompt=data.get("prompt", ""), + response=data.get("response", ""), + usage_summary=UsageSummary.from_dict(data.get("usage_summary", {})), + execution_time=data.get("execution_time", 0.0), + ) + + +@dataclass +class REPLResult: + """Result from executing code in the REPL environment.""" + + stdout: str + stderr: str + locals: dict[str, Any] + execution_time: float + rlm_calls: list[RLMChatCompletion] = field(default_factory=list) + + def __str__(self) -> str: + return ( + f"REPLResult(stdout={self.stdout!r}, stderr={self.stderr!r}," + f" locals={list(self.locals.keys())}," + f" execution_time={self.execution_time:.3f}s," + f" rlm_calls={len(self.rlm_calls)})" + ) + + def to_dict(self) -> dict[str, Any]: + return { + "stdout": self.stdout, + "stderr": self.stderr, + "locals": {k: _serialize_value(v) for k, v in self.locals.items()}, + "execution_time": self.execution_time, + "rlm_calls": [call.to_dict() for call in self.rlm_calls], + } + + @classmethod + def from_dict(cls, data: dict) -> "REPLResult": + return cls( + stdout=data.get("stdout", ""), + stderr=data.get("stderr", ""), + locals=data.get("locals", {}), + execution_time=data.get("execution_time", 0.0), + rlm_calls=[ + RLMChatCompletion.from_dict(c) for c in data.get("rlm_calls", []) + ], + ) + + +@dataclass +class CodeBlock: + """A code block extracted from an LLM response with its execution result.""" + + code: str + result: REPLResult + + def to_dict(self) -> dict[str, Any]: + return { + "code": self.code, + "result": self.result.to_dict(), + } + + @classmethod + def from_dict(cls, data: dict) -> "CodeBlock": + return cls( + code=data.get("code", ""), + result=REPLResult.from_dict(data.get("result", {})), + ) + + +@dataclass +class RLMIteration: + """A single iteration of the RLM loop.""" + + prompt: str | dict[str, Any] + response: str + code_blocks: list[CodeBlock] + final_answer: str | None = None + iteration_time: float | None = None + + def to_dict(self) -> dict[str, Any]: + return { + "prompt": self.prompt, + "response": self.response, + "code_blocks": [cb.to_dict() for cb in self.code_blocks], + "final_answer": self.final_answer, + "iteration_time": self.iteration_time, + } + + @classmethod + def from_dict(cls, data: dict) -> "RLMIteration": + return cls( + prompt=data.get("prompt", ""), + response=data.get("response", ""), + code_blocks=[ + CodeBlock.from_dict(cb) for cb in data.get("code_blocks", []) + ], + final_answer=data.get("final_answer"), + iteration_time=data.get("iteration_time"), + ) + + +@dataclass +class RLMMetadata: + """Metadata about the RLM configuration.""" + + root_model: str + max_depth: int + max_iterations: int + backend: str + backend_kwargs: dict[str, Any] + environment_type: str + environment_kwargs: dict[str, Any] + other_backends: list[str] | None = None + + def to_dict(self) -> dict[str, Any]: + return { + "root_model": self.root_model, + "max_depth": self.max_depth, + "max_iterations": self.max_iterations, + "backend": self.backend, + "backend_kwargs": { + k: _serialize_value(v) for k, v in self.backend_kwargs.items() + }, + "environment_type": self.environment_type, + "environment_kwargs": { + k: _serialize_value(v) for k, v in self.environment_kwargs.items() + }, + "other_backends": self.other_backends, + } + + @classmethod + def from_dict(cls, data: dict) -> "RLMMetadata": + return cls( + root_model=data.get("root_model", ""), + max_depth=data.get("max_depth", 5), + max_iterations=data.get("max_iterations", 30), + backend=data.get("backend", ""), + 
backend_kwargs=data.get("backend_kwargs", {}), + environment_type=data.get("environment_type", ""), + environment_kwargs=data.get("environment_kwargs", {}), + other_backends=data.get("other_backends"), + ) + + +@dataclass +class QueryMetadata: + """Metadata about the query context.""" + + context_lengths: list[int] + context_total_length: int + context_type: str + + def __init__( + self, prompt: str | list[str] | dict[Any, Any] | list[dict[Any, Any]] + ): + # Handle LazyFile and LazyFileCollection types + # Import here to avoid circular imports + try: + from adk_rlm.files.lazy import LazyFile + from adk_rlm.files.lazy import LazyFileCollection + + if isinstance(prompt, LazyFile): + # Get file size without loading content if possible + self.context_type = "lazy_file" + try: + self.context_lengths = [prompt.size_bytes or 0] + except Exception: + self.context_lengths = [0] + self.context_total_length = sum(self.context_lengths) + return + elif isinstance(prompt, LazyFileCollection): + self.context_type = "lazy_file_collection" + self.context_lengths = [] + for f in prompt: + try: + self.context_lengths.append(f.size_bytes or 0) + except Exception: + self.context_lengths.append(0) + self.context_total_length = sum(self.context_lengths) + return + except ImportError: + pass + + if isinstance(prompt, str): + self.context_lengths = [len(prompt)] + self.context_type = "str" + elif isinstance(prompt, dict): + self.context_type = "dict" + self.context_lengths = [] + for chunk in prompt.values(): + if isinstance(chunk, str): + self.context_lengths.append(len(chunk)) + else: + try: + import json + + self.context_lengths.append(len(json.dumps(chunk, default=str))) + except Exception: + self.context_lengths.append(len(repr(chunk))) + elif isinstance(prompt, list): + self.context_type = "list" + if len(prompt) == 0: + self.context_lengths = [0] + elif isinstance(prompt[0], dict): + if "content" in prompt[0]: + self.context_lengths = [ + len(str(chunk.get("content", ""))) for chunk in prompt + ] + else: + self.context_lengths = [] + for chunk in prompt: + try: + import json + + self.context_lengths.append(len(json.dumps(chunk, default=str))) + except Exception: + self.context_lengths.append(len(repr(chunk))) + else: + self.context_lengths = [len(str(chunk)) for chunk in prompt] + else: + raise ValueError(f"Invalid prompt type: {type(prompt)}") + + self.context_total_length = sum(self.context_lengths) diff --git a/contributing/samples/rlm/adk_rlm/usage.py b/contributing/samples/rlm/adk_rlm/usage.py new file mode 100644 index 0000000000..7356dea615 --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/usage.py @@ -0,0 +1,103 @@ +""" +Usage tracking for ADK-RLM. + +Tracks token usage across multiple models during RLM execution. +""" + +from collections import defaultdict + +from adk_rlm.types import ModelUsageSummary +from adk_rlm.types import UsageSummary + + +class UsageTracker: + """Tracks token usage across multiple models.""" + + def __init__(self): + """Initialize the usage tracker.""" + self._calls: dict[str, int] = defaultdict(int) + self._input_tokens: dict[str, int] = defaultdict(int) + self._output_tokens: dict[str, int] = defaultdict(int) + + def add( + self, + model: str, + input_tokens: int = 0, + output_tokens: int = 0, + ) -> None: + """ + Add usage for a model call. + + Args: + model: The model name. + input_tokens: Number of input tokens used. + output_tokens: Number of output tokens used. 
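+
+        Example (a minimal sketch; the token counts are illustrative):
+            tracker = UsageTracker()
+            tracker.add("gemini-3-flash-preview", input_tokens=1200, output_tokens=300)
+            print(tracker.get_summary().total_input_tokens)  # 1200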
+ """ + self._calls[model] += 1 + self._input_tokens[model] += input_tokens + self._output_tokens[model] += output_tokens + + def add_from_response(self, model: str, usage_metadata) -> None: + """ + Add usage from a Gemini response's usage_metadata. + + Args: + model: The model name. + usage_metadata: The usage_metadata from a Gemini response. + """ + if usage_metadata is None: + self._calls[model] += 1 + return + + input_tokens = getattr(usage_metadata, "prompt_token_count", 0) or 0 + output_tokens = getattr(usage_metadata, "candidates_token_count", 0) or 0 + self.add(model, input_tokens, output_tokens) + + def get_summary(self) -> UsageSummary: + """ + Get the aggregated usage summary. + + Returns: + UsageSummary with per-model usage data. + """ + model_summaries = {} + for model in self._calls: + model_summaries[model] = ModelUsageSummary( + total_calls=self._calls[model], + total_input_tokens=self._input_tokens[model], + total_output_tokens=self._output_tokens[model], + ) + return UsageSummary(model_usage_summaries=model_summaries) + + def reset(self) -> None: + """Reset all usage tracking.""" + self._calls.clear() + self._input_tokens.clear() + self._output_tokens.clear() + + def merge(self, other: "UsageTracker") -> None: + """ + Merge usage from another tracker into this one. + + Args: + other: Another UsageTracker to merge in. + """ + for model in other._calls: + self._calls[model] += other._calls[model] + self._input_tokens[model] += other._input_tokens[model] + self._output_tokens[model] += other._output_tokens[model] + + @property + def total_calls(self) -> int: + """Return total number of calls across all models.""" + return sum(self._calls.values()) + + @property + def total_input_tokens(self) -> int: + """Return total input tokens across all models.""" + return sum(self._input_tokens.values()) + + @property + def total_output_tokens(self) -> int: + """Return total output tokens across all models.""" + return sum(self._output_tokens.values()) diff --git a/contributing/samples/rlm/adk_rlm/web.py b/contributing/samples/rlm/adk_rlm/web.py new file mode 100644 index 0000000000..b2c8fab08f --- /dev/null +++ b/contributing/samples/rlm/adk_rlm/web.py @@ -0,0 +1,781 @@ +""" +Web interface for ADK-RLM using FastAPI and Jinja2. + +Provides a browser-based UI with real-time streaming events via WebSocket. +Events are displayed in an expandable log with iteration lineage. +Sessions are persisted using ADK's DatabaseSessionService. 
+""" + +from contextlib import asynccontextmanager +from datetime import datetime +import logging +import os +from pathlib import Path +import time +from typing import Any +import uuid + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("adk_rlm.web") + +from adk_rlm import RLM +from adk_rlm import RLMEventType +from fastapi import FastAPI +from fastapi import Request +from fastapi import WebSocket +from fastapi import WebSocketDisconnect +from fastapi.responses import HTMLResponse +from fastapi.templating import Jinja2Templates +from google.adk.sessions import DatabaseSessionService +from google.adk.sessions import Session + +# Template directory +TEMPLATE_DIR = Path(__file__).parent / "templates" +TEMPLATE_DIR.mkdir(exist_ok=True) + +# Default configuration (can be overridden via environment or create_app) +DEFAULT_DB_URL = os.environ.get( + "RLM_DB_URL", "sqlite+aiosqlite:///./sessions.db" +) +DEFAULT_MODEL = os.environ.get("RLM_MODEL", "gemini-3-pro-preview") +DEFAULT_SUB_MODEL = os.environ.get("RLM_SUB_MODEL") +DEFAULT_MAX_ITERATIONS = int(os.environ.get("RLM_MAX_ITERATIONS", "30")) +DEFAULT_LOG_DIR = os.environ.get("RLM_LOG_DIR", "./logs") + +# Module-level config that persists across imports +_config = { + "db_url": DEFAULT_DB_URL, + "model": DEFAULT_MODEL, + "sub_model": DEFAULT_SUB_MODEL, + "max_iterations": DEFAULT_MAX_ITERATIONS, + "log_dir": DEFAULT_LOG_DIR, +} + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialize session service on startup.""" + global session_service + + # Use module-level config + db_url = _config["db_url"] + + if session_service is None: + logger.info(f"Initializing DatabaseSessionService with: {db_url}") + session_service = DatabaseSessionService(db_url=db_url) + + # Warm up the database connection by doing a simple query + # This ensures the first WebSocket connection doesn't have to wait + logger.info("Warming up database connection...") + try: + await session_service.list_sessions( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + ) + logger.info("Database warmup complete") + except Exception as e: + logger.warning(f"Database warmup failed (this is OK for first run): {e}") + + yield + + # Cleanup on shutdown + for rlm in active_rlm.values(): + rlm.close() + active_rlm.clear() + + +app = FastAPI(title="ADK-RLM Web Interface", lifespan=lifespan) + +# Setup templates +templates = Jinja2Templates(directory=str(TEMPLATE_DIR)) + +# Global session service (initialized in create_app or main) +session_service: DatabaseSessionService | None = None + +# App name for ADK sessions +APP_NAME = "adk_rlm_web" +DEFAULT_USER_ID = "default_user" + +# Store active RLM instances per session +active_rlm: dict[str, RLM] = {} + + +# Tokyo Night Color Theme (matching CLI) +COLORS = { + "primary": "#7AA2F7", + "secondary": "#BB9AF7", + "success": "#9ECE6A", + "warning": "#E0AF68", + "error": "#F7768E", + "text": "#A9B1D6", + "muted": "#565F89", + "accent": "#7DCFFF", + "border": "#3B4261", + "bg": "#1A1B26", + "bg_dark": "#16161E", + "bg_highlight": "#292E42", +} + + +def get_or_create_rlm(session: Session) -> RLM: + """Get or create an RLM instance for a session.""" + if session.id not in active_rlm: + # Get config from session state + model = session.state.get("model", "gemini-3-pro-preview") + sub_model = session.state.get("sub_model") + max_iterations = session.state.get("max_iterations", 30) + log_dir = session.state.get("log_dir", "./logs") + + active_rlm[session.id] = RLM( + model=model, + sub_model=sub_model, 
+ max_iterations=max_iterations, + persistent=True, + log_dir=log_dir, + ) + + # Register GCS source for gs:// URIs + try: + from adk_rlm.files.sources.gcs import GCSFileSource + + gcs_source = GCSFileSource() + active_rlm[session.id].file_loader.register_source("gcs", gcs_source) + except ImportError: + logger.warning( + "GCS support not available (google-cloud-storage not installed)" + ) + return active_rlm[session.id] + + +def close_rlm(session_id: str): + """Close and remove an RLM instance.""" + if session_id in active_rlm: + active_rlm[session_id].close() + del active_rlm[session_id] + + +def get_event_icon(event_type: str) -> str: + """Get icon for event type.""" + icons = { + RLMEventType.RUN_START.value: "play_arrow", + RLMEventType.RUN_END.value: "stop", + RLMEventType.RUN_ERROR.value: "error", + RLMEventType.ITERATION_START.value: "loop", + RLMEventType.ITERATION_END.value: "check_circle", + RLMEventType.LLM_CALL_START.value: "psychology", + RLMEventType.LLM_CALL_END.value: "psychology", + RLMEventType.CODE_FOUND.value: "code", + RLMEventType.CODE_EXEC_START.value: "terminal", + RLMEventType.CODE_EXEC_END.value: "terminal", + RLMEventType.SUB_LLM_START.value: "call_split", + RLMEventType.SUB_LLM_END.value: "call_merge", + RLMEventType.FINAL_DETECTED.value: "star", + RLMEventType.FINAL_ANSWER.value: "check", + } + return icons.get(event_type, "circle") + + +def get_event_color(event_type: str) -> str: + """Get color for event type.""" + colors = { + RLMEventType.RUN_START.value: COLORS["primary"], + RLMEventType.RUN_END.value: COLORS["success"], + RLMEventType.RUN_ERROR.value: COLORS["error"], + RLMEventType.ITERATION_START.value: COLORS["primary"], + RLMEventType.ITERATION_END.value: COLORS["muted"], + RLMEventType.LLM_CALL_START.value: COLORS["secondary"], + RLMEventType.LLM_CALL_END.value: COLORS["secondary"], + RLMEventType.CODE_FOUND.value: COLORS["success"], + RLMEventType.CODE_EXEC_START.value: COLORS["accent"], + RLMEventType.CODE_EXEC_END.value: COLORS["accent"], + RLMEventType.SUB_LLM_START.value: COLORS["warning"], + RLMEventType.SUB_LLM_END.value: COLORS["warning"], + RLMEventType.FINAL_DETECTED.value: COLORS["warning"], + RLMEventType.FINAL_ANSWER.value: COLORS["warning"], + } + return colors.get(event_type, COLORS["text"]) + + +def format_event_label(event_type: str) -> str: + """Format event type for display.""" + label = event_type.replace("rlm.", "").replace(".", " ").title() + return label + + +def format_event_for_ui( + event_data: dict, event_id: int, start_time: float, iteration: int +) -> dict: + """Format an event for the UI.""" + event_type = event_data.get("event_type", "") + return { + "id": event_id, + "type": "event", + "event_type": event_type, + "iteration": iteration, + "timestamp": time.time() - start_time, + "icon": get_event_icon(event_type), + "color": get_event_color(event_type), + "label": format_event_label(event_type), + "metadata": { + k: v + for k, v in event_data.items() + if k not in ("event_type",) and v is not None + }, + } + + +async def update_session_state( + session: Session, state_updates: dict[str, Any] +) -> Session: + """ + Update session state and persist to database. + + This is a convenience wrapper that updates state via a no-op event. 
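+
+  Example (sketch):
+      await update_session_state(session, {"title": "My analysis session"})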
+ """ + from google.adk.events import Event + from google.adk.events import EventActions + + # Update in-memory state + session.state.update(state_updates) + + # Create a state-update event to persist changes + event = Event( + author="system", + timestamp=time.time(), + actions=EventActions(state_delta=state_updates), + ) + + # Persist via append_event + await session_service.append_event(session, event) + + return session + + +@app.get("/", response_class=HTMLResponse) +async def index(request: Request): + """Render the main page.""" + return templates.TemplateResponse( + "index.html", + { + "request": request, + "colors": COLORS, + }, + ) + + +@app.get("/health") +async def health(): + """Health check endpoint to verify session service.""" + try: + if session_service is None: + return {"status": "error", "message": "session_service is None"} + + # Try a simple operation + logger.info("Health check: testing session service...") + sessions = await session_service.list_sessions( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + ) + logger.info(f"Health check: got {len(sessions.sessions)} sessions") + return { + "status": "ok", + "session_service": str(type(session_service)), + "session_count": len(sessions.sessions), + } + except Exception as e: + logger.exception(f"Health check failed: {e}") + return {"status": "error", "message": str(e)} + + +@app.websocket("/ws/{session_id}") +async def websocket_endpoint(websocket: WebSocket, session_id: str): + """WebSocket endpoint for real-time streaming.""" + logger.info(f"WebSocket connection request for session: {session_id}") + await websocket.accept() + logger.info(f"WebSocket accepted for session: {session_id}") + + # Get or create session + try: + session = await session_service.get_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=session_id, + ) + logger.info(f"Got session: {session}") + + if session is None: + # Create new session with default state + logger.info(f"Creating new session: {session_id}") + session = await session_service.create_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=session_id, + state={ + "title": f"Session {datetime.now().strftime('%Y-%m-%d %H:%M')}", + "model": "gemini-3-pro-preview", + "sub_model": None, + "max_iterations": 30, + "files": [], + "conversation": [], # List of {role, content, timestamp} + "ui_events": [], # Formatted events for UI + }, + ) + logger.info(f"Created session: {session.id}") + except Exception as e: + logger.exception(f"Error getting/creating session: {e}") + await websocket.close(code=1011, reason=str(e)) + return + + try: + while True: + data = await websocket.receive_json() + action = data.get("action") + + if action == "query": + prompt = data.get("prompt", "") + await run_query(websocket, session, prompt) + + elif action == "add_files": + patterns = data.get("patterns", []) + await add_files(websocket, session, patterns) + + elif action == "clear": + # Clear conversation and events + await update_session_state( + session, + { + "conversation": [], + "ui_events": [], + "files": [], + }, + ) + close_rlm(session.id) + await websocket.send_json({ + "type": "status", + "message": "Session cleared", + }) + + elif action == "config": + updates = {} + if "model" in data: + updates["model"] = data["model"] + if "sub_model" in data: + updates["sub_model"] = data["sub_model"] + if "max_iterations" in data: + updates["max_iterations"] = data["max_iterations"] + if "title" in data: + updates["title"] = data["title"] + + if updates: + await 
update_session_state(session, updates) + # Recreate RLM with new config + close_rlm(session.id) + + await websocket.send_json({ + "type": "status", + "message": "Configuration updated", + }) + + elif action == "get_status": + # Refresh session from DB + session = await session_service.get_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=session.id, + ) + await websocket.send_json({ + "type": "status_response", + "session_id": session.id, + "title": session.state.get("title", "Untitled"), + "model": session.state.get("model", "gemini-3-pro-preview"), + "sub_model": ( + session.state.get("sub_model") + or session.state.get("model", "gemini-3-pro-preview") + ), + "max_iterations": session.state.get("max_iterations", 30), + "files": session.state.get("files", []), + "conversation": session.state.get("conversation", []), + "events": session.state.get("ui_events", []), + }) + + elif action == "load_session": + new_session_id = data.get("session_id") + if new_session_id and new_session_id != session.id: + new_session = await session_service.get_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=new_session_id, + ) + if new_session: + session = new_session + await websocket.send_json({ + "type": "session_loaded", + "session_id": session.id, + "title": session.state.get("title", "Untitled"), + "model": session.state.get("model", "gemini-3-pro-preview"), + "sub_model": ( + session.state.get("sub_model") or session.state.get("model") + ), + "max_iterations": session.state.get("max_iterations", 30), + "files": session.state.get("files", []), + "conversation": session.state.get("conversation", []), + "events": session.state.get("ui_events", []), + }) + else: + await websocket.send_json({ + "type": "error", + "message": f"Session {new_session_id} not found", + }) + + elif action == "new_session": + # Create new session + new_session_id = str(uuid.uuid4()) + session = await session_service.create_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=new_session_id, + state={ + "title": f"Session {datetime.now().strftime('%Y-%m-%d %H:%M')}", + "model": "gemini-3-pro-preview", + "sub_model": None, + "max_iterations": 30, + "files": [], + "conversation": [], + "ui_events": [], + }, + ) + await websocket.send_json({ + "type": "session_created", + "session_id": session.id, + "title": session.state.get("title"), + }) + + elif action == "list_sessions": + response = await session_service.list_sessions( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + ) + sessions_data = [] + for s in response.sessions: + conv = s.state.get("conversation", []) + sessions_data.append({ + "session_id": s.id, + "title": s.state.get("title", "Untitled"), + "updated_at": ( + datetime.fromtimestamp(s.last_update_time).isoformat() + if s.last_update_time + else None + ), + "message_count": len(conv), + }) + # Sort by updated_at descending + sessions_data.sort( + key=lambda x: x.get("updated_at") or "", reverse=True + ) + await websocket.send_json({ + "type": "sessions_list", + "sessions": sessions_data, + }) + + elif action == "delete_session": + del_session_id = data.get("session_id") + if del_session_id: + close_rlm(del_session_id) + await session_service.delete_session( + app_name=APP_NAME, + user_id=DEFAULT_USER_ID, + session_id=del_session_id, + ) + await websocket.send_json({ + "type": "session_deleted", + "session_id": del_session_id, + "success": True, + }) + + except WebSocketDisconnect: + pass # Session already persisted + + +async def add_files( + websocket: WebSocket, 
session: Session, patterns: list[str] +): + """Add files to the session.""" + rlm = get_or_create_rlm(session) + try: + resolved = rlm.file_loader.create_lazy_files(patterns) + if len(resolved) == 0: + await websocket.send_json({ + "type": "error", + "message": f"No files found matching: {' '.join(patterns)}", + }) + else: + # Update session state + current_files = session.state.get("files", []) + current_files.extend(patterns) + await update_session_state(session, {"files": current_files}) + + await websocket.send_json({ + "type": "files_added", + "patterns": patterns, + "count": len(resolved), + "names": resolved.names[:10], + "total": len(resolved), + }) + except Exception as e: + await websocket.send_json({ + "type": "error", + "message": f"Could not resolve files: {e}", + }) + + +async def run_query(websocket: WebSocket, session: Session, prompt: str): + """Run an RLM query and stream events.""" + rlm = get_or_create_rlm(session) + + # Add user message to conversation + conversation = list(session.state.get("conversation", [])) + conversation.append({ + "role": "user", + "content": prompt, + "timestamp": datetime.now().isoformat(), + }) + + # Extract conversation history for the agent (exclude current message) + # Only include role and content, not timestamp + conversation_history = None + if len(conversation) > 1: + conversation_history = [ + {"role": msg["role"], "content": msg["content"]} + for msg in conversation[:-1] + ] + + # Build file context + files = session.state.get("files", []) + if files: + try: + file_ctx = rlm.file_loader.build_context(files, lazy=True) + file_count = file_ctx.get("file_count", 0) + if file_count == 0: + await websocket.send_json({ + "type": "error", + "message": f"No files found matching patterns: {' '.join(files)}", + }) + return + ctx = file_ctx + except Exception as e: + await websocket.send_json({ + "type": "error", + "message": f"Failed to load files: {e}", + }) + return + else: + ctx = { + "info": "No files loaded. The user is asking a question.", + } + + # Send query start + start_time = time.time() + await websocket.send_json({ + "type": "query_start", + "prompt": prompt, + }) + + try: + event_id = 0 + current_iteration = 0 + final_answer = None + ui_events = [] + + async for event in rlm.run_streaming(ctx, prompt, conversation_history): + if not event.custom_metadata: + continue + + event_type = event.custom_metadata.get("event_type") + if not event_type: + continue + + if event_type == RLMEventType.ITERATION_START.value: + current_iteration = event.custom_metadata.get("iteration", 0) + + # Format event for UI + ui_event = format_event_for_ui( + event.custom_metadata, + event_id, + start_time, + current_iteration, + ) + ui_events.append(ui_event) + + # Check for final answer + if event.custom_metadata.get("answer"): + final_answer = event.custom_metadata["answer"] + + await websocket.send_json(ui_event) + event_id += 1 + + # Add assistant message to conversation + title = session.state.get("title", "") + if final_answer: + conversation.append({ + "role": "assistant", + "content": final_answer, + "timestamp": datetime.now().isoformat(), + }) + + # Auto-generate title from first exchange + if title.startswith("Session ") and len(conversation) == 2: + first_msg = conversation[0]["content"] + title = first_msg[:50] + ("..." 
if len(first_msg) > 50 else "") + + # Update session state + await update_session_state( + session, + { + "conversation": conversation, + "ui_events": ui_events, + "title": title, + }, + ) + + # Send completion + elapsed = time.time() - start_time + await websocket.send_json({ + "type": "query_complete", + "elapsed_seconds": elapsed, + "total_events": event_id, + "final_answer": final_answer, + "title": session.state.get("title"), + }) + + except Exception as e: + await websocket.send_json({ + "type": "error", + "message": str(e), + }) + + +def create_app( + model: str = "gemini-3-pro-preview", + sub_model: str | None = None, + max_iterations: int = 30, + log_dir: str | None = None, + db_url: str = "sqlite+aiosqlite:///./sessions.db", +) -> FastAPI: + """Create a configured FastAPI app.""" + # Update module-level config + _config["db_url"] = db_url + _config["model"] = model + _config["sub_model"] = sub_model + _config["max_iterations"] = max_iterations + _config["log_dir"] = log_dir + + # Also store in app.state for easy access + app.state.default_model = model + app.state.default_sub_model = sub_model + app.state.default_max_iterations = max_iterations + app.state.default_log_dir = log_dir + return app + + +def main(): + """Run the web server.""" + import argparse + + import uvicorn + + parser = argparse.ArgumentParser( + description="ADK-RLM Web Interface", + ) + parser.add_argument( + "--host", + type=str, + default="127.0.0.1", + help="Host to bind to (default: 127.0.0.1)", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port to bind to (default: 8000)", + ) + parser.add_argument( + "--model", + "-m", + type=str, + default="gemini-3-pro-preview", + help="Default model (default: gemini-3-pro-preview)", + ) + parser.add_argument( + "--sub-model", + "-s", + type=str, + help="Default sub-model (defaults to main model)", + ) + parser.add_argument( + "--max-iterations", + "-i", + type=int, + default=30, + help="Default max iterations (default: 30)", + ) + parser.add_argument( + "--log-dir", + "-l", + type=str, + default="./logs", + help="Directory for JSONL logs (default: ./logs)", + ) + parser.add_argument( + "--db-url", + type=str, + default="sqlite+aiosqlite:///./sessions.db", + help=( + "SQLAlchemy database URL for sessions (default:" + " sqlite+aiosqlite:///./sessions.db)" + ), + ) + parser.add_argument( + "--reload", + action="store_true", + help="Enable auto-reload for development", + ) + + args = parser.parse_args() + + if args.reload: + # When using reload, set environment variables so config persists + # across module reimports + os.environ["RLM_DB_URL"] = args.db_url + os.environ["RLM_MODEL"] = args.model + os.environ["RLM_MAX_ITERATIONS"] = str(args.max_iterations) + if args.log_dir: + os.environ["RLM_LOG_DIR"] = args.log_dir + if args.sub_model: + os.environ["RLM_SUB_MODEL"] = args.sub_model + + uvicorn.run( + "adk_rlm.web:app", + host=args.host, + port=args.port, + reload=True, + ) + else: + # When not using reload, configure app directly + configured_app = create_app( + model=args.model, + sub_model=args.sub_model, + max_iterations=args.max_iterations, + log_dir=args.log_dir, + db_url=args.db_url, + ) + + uvicorn.run( + configured_app, + host=args.host, + port=args.port, + ) + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/deployment/cloudbuild.yaml b/contributing/samples/rlm/deployment/cloudbuild.yaml new file mode 100644 index 0000000000..bc4deb39cd --- /dev/null +++ 
b/contributing/samples/rlm/deployment/cloudbuild.yaml @@ -0,0 +1,54 @@ +steps: + # Pull the latest image to use as cache + - name: 'gcr.io/cloud-builders/docker' + entrypoint: 'bash' + args: + - '-c' + - 'docker pull ${_REGION}-docker.pkg.dev/$PROJECT_ID/adk-rlm/web:latest || true' + + # Build the container image with cache + - name: 'gcr.io/cloud-builders/docker' + args: + - 'build' + - '--cache-from=${_REGION}-docker.pkg.dev/$PROJECT_ID/adk-rlm/web:latest' + - '-t' + - '${_REGION}-docker.pkg.dev/$PROJECT_ID/adk-rlm/web:${BUILD_ID}' + - '-t' + - '${_REGION}-docker.pkg.dev/$PROJECT_ID/adk-rlm/web:latest' + - '.' + + # Push the container image to Artifact Registry + - name: 'gcr.io/cloud-builders/docker' + args: + - 'push' + - '--all-tags' + - '${_REGION}-docker.pkg.dev/$PROJECT_ID/adk-rlm/web' + + # Deploy to Cloud Run with IAP enabled (beta feature) + - name: 'gcr.io/google.com/cloudsdktool/cloud-sdk' + entrypoint: 'gcloud' + args: + - 'beta' + - 'run' + - 'deploy' + - '${_SERVICE_NAME}' + - '--image=${_REGION}-docker.pkg.dev/$PROJECT_ID/adk-rlm/web:${BUILD_ID}' + - '--region=${_REGION}' + - '--platform=managed' + - '--min-instances=1' + - '--max-instances=10' + - '--memory=2Gi' + - '--cpu=2' + - '--timeout=3600' + - '--concurrency=80' + - '--no-allow-unauthenticated' + - '--iap' + - '--service-account=adk-rlm-runner@$PROJECT_ID.iam.gserviceaccount.com' + - '--set-env-vars=RLM_MODEL=gemini-3-pro-preview,RLM_MAX_ITERATIONS=30' + +substitutions: + _REGION: us-central1 + _SERVICE_NAME: adk-rlm + +options: + machineType: 'E2_HIGHCPU_8' diff --git a/contributing/samples/rlm/deployment/deploy.sh b/contributing/samples/rlm/deployment/deploy.sh new file mode 100755 index 0000000000..109a530ad5 --- /dev/null +++ b/contributing/samples/rlm/deployment/deploy.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Deploy ADK-RLM to Cloud Run with IAP (using Cloud Run's native IAP support) + +set -e + +# Configuration +PROJECT_ID="${PROJECT_ID:-$(gcloud config get-value project)}" +REGION="${REGION:-us-central1}" +SERVICE_NAME="${SERVICE_NAME:-adk-rlm}" + +echo "=== ADK-RLM Cloud Run Deployment ===" +echo "Project: $PROJECT_ID" +echo "Region: $REGION" +echo "Service: $SERVICE_NAME" +echo "" + +# Ensure we're in the project root +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +cd "$PROJECT_ROOT" + +echo "Building and deploying from: $(pwd)" +echo "" + +# Submit build to Cloud Build +echo "Submitting build to Cloud Build..." +gcloud builds submit \ + --config=deployment/cloudbuild.yaml \ + --substitutions="_REGION=${REGION},_SERVICE_NAME=${SERVICE_NAME}" \ + --project="$PROJECT_ID" \ + . 
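+
+# cloudbuild.yaml builds the container image, pushes it to Artifact Registry,
+# and deploys to Cloud Run with IAP enabled and public access disabled.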
+ +# Get the service URL +SERVICE_URL=$(gcloud run services describe "$SERVICE_NAME" \ + --region="$REGION" \ + --project="$PROJECT_ID" \ + --format="value(status.url)") + +echo "" +echo "=== Deployment Complete ===" +echo "" +echo "Cloud Run service deployed with:" +echo " - Minimum instances: 1" +echo " - IAP enabled (requires Google authentication)" +echo " - Public IAM access disabled" +echo "" +echo "Service URL: $SERVICE_URL" +echo "" +echo "To grant access to users, run:" +echo " gcloud beta run services add-iam-policy-binding $SERVICE_NAME \\" +echo " --region=$REGION \\" +echo " --member='user:email@example.com' \\" +echo " --role='roles/run.invoker' \\" +echo " --project=$PROJECT_ID" diff --git a/contributing/samples/rlm/deployment/setup-gcp.sh b/contributing/samples/rlm/deployment/setup-gcp.sh new file mode 100755 index 0000000000..b0857214f2 --- /dev/null +++ b/contributing/samples/rlm/deployment/setup-gcp.sh @@ -0,0 +1,131 @@ +#!/bin/bash +# Setup GCP project for ADK-RLM deployment +# This script creates a new project and enables required services + +set -e + +# Configuration - modify these as needed +PROJECT_ID="${PROJECT_ID:-adk-rlm-$(date +%s)}" +BILLING_ACCOUNT="${BILLING_ACCOUNT:-}" +REGION="${REGION:-us-central1}" + +echo "=== ADK-RLM GCP Project Setup ===" +echo "Project ID: $PROJECT_ID" +echo "Region: $REGION" +echo "" + +# Check if billing account is set +if [ -z "$BILLING_ACCOUNT" ]; then + echo "Available billing accounts:" + gcloud billing accounts list + echo "" + echo "Set BILLING_ACCOUNT environment variable and re-run:" + echo " export BILLING_ACCOUNT=" + echo " ./setup-gcp.sh" + exit 1 +fi + +echo "Billing Account: $BILLING_ACCOUNT" +echo "" + +# Create the project +echo "Creating project $PROJECT_ID..." +gcloud projects create "$PROJECT_ID" --name="ADK-RLM" || { + echo "Project may already exist, continuing..." +} + +# Link billing account +echo "Linking billing account..." +gcloud billing projects link "$PROJECT_ID" --billing-account="$BILLING_ACCOUNT" + +# Set as current project +echo "Setting current project..." +gcloud config set project "$PROJECT_ID" + +# Enable required APIs +echo "Enabling required APIs..." +gcloud services enable \ + cloudbuild.googleapis.com \ + run.googleapis.com \ + artifactregistry.googleapis.com \ + iap.googleapis.com \ + aiplatform.googleapis.com \ + --project="$PROJECT_ID" + +# Create Artifact Registry repository for container images +echo "Creating Artifact Registry repository..." +gcloud artifacts repositories create adk-rlm \ + --repository-format=docker \ + --location="$REGION" \ + --description="ADK-RLM container images" \ + --project="$PROJECT_ID" || { + echo "Repository may already exist, continuing..." +} + +# Get project number for service accounts +PROJECT_NUMBER=$(gcloud projects describe "$PROJECT_ID" --format='value(projectNumber)') +CLOUD_BUILD_SA="${PROJECT_NUMBER}@cloudbuild.gserviceaccount.com" +COMPUTE_SA="${PROJECT_NUMBER}-compute@developer.gserviceaccount.com" + +# Grant permissions to Cloud Build service account +echo "Granting Cloud Build permissions..." 
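+# Roles granted: run.admin, iam.serviceAccountUser, artifactregistry.writer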
+gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${CLOUD_BUILD_SA}" \ + --role="roles/run.admin" + +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${CLOUD_BUILD_SA}" \ + --role="roles/iam.serviceAccountUser" + +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${CLOUD_BUILD_SA}" \ + --role="roles/artifactregistry.writer" + +# Grant permissions to default compute service account (used by Cloud Build) +echo "Granting default compute service account permissions..." +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${COMPUTE_SA}" \ + --role="roles/artifactregistry.writer" + +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${COMPUTE_SA}" \ + --role="roles/run.admin" + +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${COMPUTE_SA}" \ + --role="roles/iam.serviceAccountUser" + +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${COMPUTE_SA}" \ + --role="roles/logging.logWriter" + +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${COMPUTE_SA}" \ + --role="roles/storage.objectViewer" + +# Create service account for Cloud Run +echo "Creating Cloud Run service account..." +gcloud iam service-accounts create adk-rlm-runner \ + --display-name="ADK-RLM Cloud Run Service Account" \ + --project="$PROJECT_ID" || { + echo "Service account may already exist, continuing..." +} + +# Grant the service account permissions for Vertex AI +RUNNER_SA="adk-rlm-runner@${PROJECT_ID}.iam.gserviceaccount.com" +gcloud projects add-iam-policy-binding "$PROJECT_ID" \ + --member="serviceAccount:${RUNNER_SA}" \ + --role="roles/aiplatform.user" + +echo "" +echo "=== Setup Complete ===" +echo "" +echo "Project ID: $PROJECT_ID" +echo "Region: $REGION" +echo "" +echo "To deploy, run:" +echo " export PROJECT_ID=$PROJECT_ID" +echo " export REGION=$REGION" +echo " ./deploy.sh" +echo "" +echo "Save these values in your environment or .env file." diff --git a/contributing/samples/rlm/examples/basic_usage.py b/contributing/samples/rlm/examples/basic_usage.py new file mode 100644 index 0000000000..ad52e116d8 --- /dev/null +++ b/contributing/samples/rlm/examples/basic_usage.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 +""" +Basic usage example for ADK-RLM. + +This example demonstrates how to use the completion function to analyze +a simple context and get a response. +""" + +from adk_rlm import completion + + +def main(): + # Simple context with a question + context = """ + The following is a list of famous mathematicians and their contributions: + + 1. Euclid (c. 300 BC) - Known as the "Father of Geometry", wrote "Elements" + 2. Isaac Newton (1643-1727) - Developed calculus and laws of motion + 3. Carl Friedrich Gauss (1777-1855) - Made contributions to number theory, statistics + 4. Leonhard Euler (1707-1783) - Prolific mathematician, introduced notation like e and i + 5. 
Srinivasa Ramanujan (1887-1920) - Made groundbreaking contributions to mathematical analysis + """ + + # Use the convenience function with verbose output + result = completion( + context=context, + prompt="Who is known as the Father of Geometry and when did they live?", + model="gemini-3-flash-preview", + max_iterations=10, + verbose=True, + ) + + print("\n" + "=" * 50) + print("FINAL RESULT:") + print("=" * 50) + print(result.response) + print(f"\nExecution time: {result.execution_time:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/examples/lazy_local_filesystem.py b/contributing/samples/rlm/examples/lazy_local_filesystem.py new file mode 100644 index 0000000000..13174612c1 --- /dev/null +++ b/contributing/samples/rlm/examples/lazy_local_filesystem.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +""" +Example demonstrating file loading with the completion function. +""" + +from adk_rlm import completion + +result = completion( + files=["./plans/**/*"], + prompt=( + "What is this project about? Write a detailed summary of the project" + " plan(s)." + ), +) + +print(result.response) diff --git a/contributing/samples/rlm/examples/long_context.py b/contributing/samples/rlm/examples/long_context.py new file mode 100644 index 0000000000..a38447f846 --- /dev/null +++ b/contributing/samples/rlm/examples/long_context.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +""" +Long context example for ADK-RLM. + +This example demonstrates how RLM handles long documents by +chunking and using recursive LLM calls. +""" + +from adk_rlm import completion + + +def generate_long_document() -> str: + """Generate a synthetic long document for testing.""" + sections = [] + + topics = [ + ( + "Introduction to Machine Learning", + "Machine learning is a subset of artificial intelligence...", + ), + ( + "Supervised Learning", + "In supervised learning, models learn from labeled data...", + ), + ( + "Unsupervised Learning", + "Unsupervised learning finds patterns in unlabeled data...", + ), + ( + "Deep Learning", + "Deep learning uses neural networks with multiple layers...", + ), + ( + "Natural Language Processing", + "NLP enables computers to understand human language...", + ), + ( + "Computer Vision", + "Computer vision allows machines to interpret visual data...", + ), + ( + "Reinforcement Learning", + "RL agents learn by interacting with environments...", + ), + ( + "Ethics in AI", + "AI ethics considers fairness, transparency, and accountability...", + ), + ] + + for title, intro in topics: + section = f""" +### {title} + +{intro} + +This section covers the fundamental concepts and applications of {title.lower()}. +The field has seen significant advances in recent years, with applications +ranging from healthcare to finance to entertainment. + +Key concepts include: +- Theoretical foundations +- Practical implementations +- Real-world applications +- Current challenges and limitations +- Future directions + +Many researchers and practitioners have contributed to this area, +making it one of the most active fields in computer science. +""" + sections.append(section) + + return "\n\n".join(sections) + + +def main(): + # Generate a long document + document = generate_long_document() + print(f"Document length: {len(document)} characters") + + # Use completion with logging enabled + result = completion( + context=document, + prompt=( + "What are all the main topics covered in this document? List them" + " with a brief summary of each." 
+ ), + model="gemini-3-pro-preview", + sub_model="gemini-3-flash-preview", + max_iterations=15, + log_dir="./logs", + verbose=True, + ) + + print("\n" + "=" * 50) + print("FINAL RESULT:") + print("=" * 50) + print(result.response) + print(f"\nExecution time: {result.execution_time:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/examples/multi_turn.py b/contributing/samples/rlm/examples/multi_turn.py new file mode 100644 index 0000000000..7084b82070 --- /dev/null +++ b/contributing/samples/rlm/examples/multi_turn.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Multi-turn conversation example for ADK-RLM. + +This example demonstrates how to use persistent mode for +multi-turn conversations where context accumulates. + +For multi-turn, we use the RLM class directly with run_streaming() +since persistent mode requires maintaining state across calls. +""" + +import asyncio + +from adk_rlm import RLM +from adk_rlm import RLMEventType + + +async def run_query(rlm: RLM, context: str, prompt: str) -> str: + """Run a query and return the final answer.""" + final_answer = None + async for event in rlm.run_streaming(context, prompt): + if event.custom_metadata: + event_type = event.custom_metadata.get("event_type") + if event_type == RLMEventType.FINAL_ANSWER.value: + final_answer = event.custom_metadata.get("answer") + return final_answer or "" + + +async def main(): + # Create an RLM instance with persistence enabled + rlm = RLM( + model="gemini-3-flash-preview", + verbose=True, + persistent=True, # Enable multi-turn persistence + max_iterations=10, + ) + + try: + # First turn: introduce some data + print("\n" + "=" * 50) + print("TURN 1: Introduce data") + print("=" * 50) + + result1 = await run_query( + rlm, + context=""" + Employee records: + - Alice: Software Engineer, 5 years experience, salary $120,000 + - Bob: Data Scientist, 3 years experience, salary $110,000 + - Carol: Product Manager, 7 years experience, salary $140,000 + - Dave: DevOps Engineer, 2 years experience, salary $95,000 + """, + prompt=( + "Store this employee data in a variable called 'employees' as a" + " list of dictionaries." + ), + ) + print( + f"Result: {result1[:200]}..." + if len(result1) > 200 + else f"Result: {result1}" + ) + + # Second turn: ask about the data + print("\n" + "=" * 50) + print("TURN 2: Query the data") + print("=" * 50) + + result2 = await run_query( + rlm, + context="Calculate the average salary.", + prompt="What is the average salary of all employees?", + ) + print(f"Result: {result2}") + + # Third turn: more analysis + print("\n" + "=" * 50) + print("TURN 3: Further analysis") + print("=" * 50) + + result3 = await run_query( + rlm, + context="Find the most experienced employee.", + prompt="Who has the most years of experience and what is their role?", + ) + print(f"Result: {result3}") + + print("\n" + "=" * 50) + print("MULTI-TURN COMPLETE") + print("=" * 50) + + finally: + rlm.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/contributing/samples/rlm/examples/quickstart.py b/contributing/samples/rlm/examples/quickstart.py new file mode 100644 index 0000000000..337aa4f892 --- /dev/null +++ b/contributing/samples/rlm/examples/quickstart.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +""" +ADK-RLM Quickstart - Minimal example to get started. + +This is the simplest possible RLM example. It demonstrates: +1. Using the completion convenience function +2. Running a computation that uses the REPL +3. 
Getting the final answer + +Run with: python examples/quickstart.py +""" + +from adk_rlm import completion + + +def main(): + # Use the convenience function for simple synchronous completion + result = completion( + context="Calculate the sum of the first 100 positive integers.", + prompt="What is the sum? Use the REPL to compute it.", + model="gemini-3-flash-preview", + ) + + # The expected answer is 5050 (Gauss's formula: n*(n+1)/2 = 100*101/2 = 5050) + print(f"Answer: {result.response}") + print(f"Execution time: {result.execution_time:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/image.png b/contributing/samples/rlm/image.png new file mode 100644 index 0000000000..3a474f4857 Binary files /dev/null and b/contributing/samples/rlm/image.png differ diff --git a/contributing/samples/rlm/pyproject.toml b/contributing/samples/rlm/pyproject.toml new file mode 100644 index 0000000000..1956f0a41b --- /dev/null +++ b/contributing/samples/rlm/pyproject.toml @@ -0,0 +1,73 @@ +[project] +name = "adk-rlm" +version = "0.1.0" +description = "Recursive Language Models implemented with Google ADK" +readme = "README.md" +requires-python = ">=3.11" +dependencies = [ + "google-adk>=1.0.0", + "google-genai>=1.0.0", + "rich>=13.0.0", + "python-dotenv>=1.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.24.0", + "pytest-timeout>=2.3.0", + "pytest-playwright>=0.4.0", + "playwright>=1.40.0", + "ruff>=0.4.0", +] + +# Web interface dependencies +web = [ + "fastapi>=0.109.0", + "uvicorn[standard]>=0.27.0", + "jinja2>=3.1.0", +] + +# File handling dependencies +files = [ + "pdfplumber>=0.10.0", + "pyyaml>=6.0.0", +] + +# Individual file format support +pdf = ["pdfplumber>=0.10.0"] + +# Cloud storage sources +gcs = ["google-cloud-storage>=2.14.0"] + +# All optional features +all = [ + "pdfplumber>=0.10.0", + "pyyaml>=6.0.0", + "fastapi>=0.109.0", + "uvicorn[standard]>=0.27.0", + "jinja2>=3.1.0", + "google-cloud-storage>=2.14.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +testpaths = ["tests"] +markers = [ + "e2e: mark test as end-to-end (requires real LLM)", + "ui: mark test as UI test (requires playwright)", + "e2e_web: mark test as web E2E test (requires server)", +] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "W"] +ignore = ["E501"] diff --git a/contributing/samples/rlm/scripts/analyze_log.py b/contributing/samples/rlm/scripts/analyze_log.py new file mode 100755 index 0000000000..3a9adbc0e4 --- /dev/null +++ b/contributing/samples/rlm/scripts/analyze_log.py @@ -0,0 +1,645 @@ +#!/usr/bin/env python3 +""" +Analyze RLM JSONL logs for quick insights. 
+ +Usage: + python scripts/analyze_log.py [LOG_FILE] [OPTIONS] + +Examples: + # Analyze the most recent log + python scripts/analyze_log.py + + # Analyze a specific log + python scripts/analyze_log.py logs/rlm_2026-01-22_*.jsonl + + # Show only the summary + python scripts/analyze_log.py --summary + + # Show the iteration tree + python scripts/analyze_log.py --tree + + # Show all code blocks + python scripts/analyze_log.py --code + + # Show final answer only + python scripts/analyze_log.py --final + + # Filter by depth + python scripts/analyze_log.py --depth 0 + + # Show LLM responses (truncated) + python scripts/analyze_log.py --responses + + # Show simple LLM calls (non-recursive llm_query calls) + python scripts/analyze_log.py --simple + + # Show only failed simple LLM calls + python scripts/analyze_log.py --simple --failed + + # Export to markdown + python scripts/analyze_log.py --export report.md +""" + +import argparse +from collections import defaultdict +from datetime import datetime +import json +from pathlib import Path +import sys + + +def load_log(log_path: Path) -> list[dict]: + """Load JSONL log, skipping malformed lines.""" + entries = [] + with open(log_path) as f: + for i, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + try: + entries.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Warning: Skipping malformed line {i}: {e}", file=sys.stderr) + return entries + + +def find_latest_log(log_dir: Path = Path("logs")) -> Path | None: + """Find the most recent log file.""" + logs = sorted(log_dir.glob("rlm_*.jsonl"), key=lambda p: p.stat().st_mtime) + return logs[-1] if logs else None + + +def get_metadata(entries: list[dict]) -> dict | None: + """Extract metadata entry.""" + for e in entries: + if e.get("type") == "metadata": + return e + return None + + +def get_iterations(entries: list[dict]) -> list[dict]: + """Get all iteration entries.""" + return [e for e in entries if e.get("type") == "iteration"] + + +def get_simple_llm_calls(entries: list[dict]) -> list[dict]: + """Get all simple_llm_call entries (non-recursive llm_query calls).""" + return [e for e in entries if e.get("type") == "simple_llm_call"] + + +def print_summary(entries: list[dict], log_path: Path): + """Print a summary of the run.""" + meta = get_metadata(entries) + iters = get_iterations(entries) + simple_calls = get_simple_llm_calls(entries) + + print("=" * 70) + print(f"RLM Log Analysis: {log_path.name}") + print("=" * 70) + + if meta: + print(f"\nModel: {meta.get('root_model', 'unknown')}") + print(f"Max Iterations: {meta.get('max_iterations', 'unknown')}") + print(f"Max Depth: {meta.get('max_depth', 'unknown')}") + print(f"Timestamp: {meta.get('timestamp', 'unknown')}") + + print(f"\nTotal Iterations: {len(iters)}") + + # Count by depth + depth_counts = defaultdict(int) + for it in iters: + depth_counts[it.get("depth", 0)] += 1 + + print("\nIterations by Depth:") + for depth in sorted(depth_counts.keys()): + print(f" Depth {depth}: {depth_counts[depth]} iterations") + + # Simple LLM calls summary + if simple_calls: + success_count = sum(1 for c in simple_calls if c.get("success", True)) + failed_count = len(simple_calls) - success_count + total_time_ms = sum(c.get("execution_time_ms", 0) for c in simple_calls) + + print(f"\nSimple LLM Calls: {len(simple_calls)}") + print(f" Successful: {success_count}") + if failed_count > 0: + print(f" Failed: {failed_count}") + print(f" Total Time: {total_time_ms/1000:.1f}s") + + # Count by depth + simple_depth_counts = 
defaultdict(int) + for c in simple_calls: + simple_depth_counts[c.get("depth", 0)] += 1 + + if len(simple_depth_counts) > 1: + print(" By Depth:") + for depth in sorted(simple_depth_counts.keys()): + print(f" Depth {depth}: {simple_depth_counts[depth]} calls") + + # Find final answer + final = None + for it in reversed(iters): + if it.get("final_answer"): + final = it.get("final_answer") + break + + if final: + print(f"\nFinal Answer Found: Yes ({len(final)} chars)") + else: + print("\nFinal Answer Found: No (run may still be in progress)") + + # Total time + times = [ + it.get("iteration_time", 0) for it in iters if it.get("iteration_time") + ] + if times: + print(f"\nTotal Iteration Time: {sum(times):.1f}s") + print(f"Avg Iteration Time: {sum(times)/len(times):.1f}s") + + +def _print_simple_calls_summary(calls: list[dict], indent: str) -> None: + """Print a grouped summary of simple LLM calls for an iteration. + + Args: + calls: List of simple_llm_call entries for this iteration. + indent: Indentation string to align with parent iteration. + """ + total = len(calls) + success = sum(1 for c in calls if c.get("success", True)) + failed = total - success + total_time_ms = sum(c.get("execution_time_ms", 0) for c in calls) + + # Check if this is a batch + batch_sizes = {c.get("batch_size") for c in calls if c.get("batch_size")} + is_batch = len(batch_sizes) == 1 and batch_sizes.pop() == total + + # Build status string + if failed > 0: + status = f"{success} ok, {failed} failed" + else: + status = "ok" if total == 1 else f"{total} ok" + + # Build description + if is_batch: + desc = f"batch[{total}]" + elif total == 1: + desc = "llm_query" + else: + desc = f"llm_query x{total}" + + time_str = ( + f"{total_time_ms/1000:.1f}s" + if total_time_ms >= 1000 + else f"{total_time_ms:.0f}ms" + ) + + # Print with arrow to show it's a sub-call + print(f"{indent} └─ {desc} ({time_str}) [{status}]") + + +def print_tree(entries: list[dict], show_simple: bool = True): + """Print the iteration tree showing agent hierarchy. + + Args: + entries: Log entries to display. + show_simple: If True, show simple LLM calls grouped after each iteration. 
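+
+  Example output (illustrative):
+      [ 1] rlm (2.3s) [code]
+       └─ llm_query x3 (1.2s) [3 ok]
+      [ 2] rlm (1.1s) [FINAL]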
+ """ + iters = get_iterations(entries) + simple_calls = get_simple_llm_calls(entries) if show_simple else [] + + # Group simple calls by (depth, parent_iteration) + simple_by_iter: dict[tuple[int, int], list[dict]] = defaultdict(list) + for call in simple_calls: + key = (call.get("depth", 0), call.get("parent_iteration", 0)) + simple_by_iter[key].append(call) + + print("\nIteration Tree:") + print("-" * 50) + + for it in iters: + depth = it.get("depth", 0) + agent = it.get("agent_name", "unknown") + iteration = it.get("iteration", 0) + time_s = it.get("iteration_time") or 0 + has_code = bool(it.get("code_blocks")) + has_final = bool(it.get("final_answer")) + + indent = " " * depth + code_marker = " [code]" if has_code else "" + final_marker = " [FINAL]" if has_final else "" + + # Truncate agent name for display + agent_short = agent.replace("rlm_agent", "rlm") + + time_str = f"({time_s:.1f}s)" if time_s else "" + print( + f"{indent}[{iteration:2d}] {agent_short}" + f" {time_str}{code_marker}{final_marker}" + ) + + # Show simple LLM calls for this iteration + if show_simple: + key = (depth, iteration) + calls = simple_by_iter.get(key, []) + if calls: + _print_simple_calls_summary(calls, indent) + + +def print_code_blocks(entries: list[dict], depth_filter: int | None = None): + """Print all code blocks from the log.""" + iters = get_iterations(entries) + + print("\nCode Blocks:") + print("=" * 70) + + for it in iters: + depth = it.get("depth", 0) + if depth_filter is not None and depth != depth_filter: + continue + + code_blocks = it.get("code_blocks", []) + if not code_blocks: + continue + + iteration = it.get("iteration", "?") + agent = it.get("agent_name", "unknown") + + for i, block in enumerate(code_blocks): + print( + f"\n--- Iteration {iteration} (depth={depth}, {agent}) Block" + f" {i+1} ---" + ) + print(f"Code:\n{block.get('code', '')}") + output = block.get("output", "") + if output: + # Truncate long outputs + if len(output) > 1000: + output = ( + output[:1000] + f"\n... (truncated, {len(output)} chars total)" + ) + print(f"\nOutput:\n{output}") + error = block.get("error", "") + if error: + print(f"\nError:\n{error}") + + +def print_responses( + entries: list[dict], depth_filter: int | None = None, max_len: int = 500 +): + """Print LLM responses (truncated).""" + iters = get_iterations(entries) + + print("\nLLM Responses:") + print("=" * 70) + + for it in iters: + depth = it.get("depth", 0) + if depth_filter is not None and depth != depth_filter: + continue + + iteration = it.get("iteration", "?") + agent = it.get("agent_name", "unknown") + response = it.get("response", "") + + if not response: + continue + + print(f"\n--- Iteration {iteration} (depth={depth}, {agent}) ---") + if len(response) > max_len: + print( + response[:max_len] + f"\n... 
(truncated, {len(response)} chars total)" + ) + else: + print(response) + + +def print_final_answer(entries: list[dict]): + """Print the final answer.""" + iters = get_iterations(entries) + + for it in reversed(iters): + if it.get("final_answer"): + print("\nFinal Answer:") + print("=" * 70) + print(it["final_answer"]) + return + + print("\nNo final answer found (run may still be in progress)") + + +def print_simple_llm_calls( + entries: list[dict], + depth_filter: int | None = None, + max_len: int = 300, + show_failed_only: bool = False, +): + """Print simple LLM calls (non-recursive llm_query calls).""" + simple_calls = get_simple_llm_calls(entries) + + if not simple_calls: + print("\nNo simple LLM calls found.") + return + + print("\nSimple LLM Calls (recursive=False):") + print("=" * 70) + + for i, call in enumerate(simple_calls): + depth = call.get("depth", 0) + if depth_filter is not None and depth != depth_filter: + continue + + success = call.get("success", True) + if show_failed_only and success: + continue + + agent = call.get("agent_name", "unknown") + model = call.get("model", "unknown") + time_ms = call.get("execution_time_ms", 0) + parent_iter = call.get("parent_iteration", "?") + parent_block = call.get("parent_block_index", "?") + batch_idx = call.get("batch_index") + batch_size = call.get("batch_size") + + # Header + status = "OK" if success else "FAILED" + batch_info = f" [batch {batch_idx+1}/{batch_size}]" if batch_size else "" + print( + f"\n--- Call {i+1} ({status}) depth={depth} iter={parent_iter}" + f" block={parent_block}{batch_info} ---" + ) + print(f"Agent: {agent} | Model: {model} | Time: {time_ms:.0f}ms") + + # Prompt + prompt = call.get("prompt", call.get("prompt_full", "")) + if prompt: + if len(prompt) > max_len: + prompt = ( + prompt[:max_len] + + f"... ({len(call.get('prompt_full', prompt))} chars)" + ) + print(f"\nPrompt:\n{prompt}") + + # Response or error + if not success: + error = call.get("error", "Unknown error") + print(f"\nError: {error}") + else: + response = call.get("response", call.get("response_full", "")) + if response: + if len(response) > max_len: + response = ( + response[:max_len] + + f"... 
({len(call.get('response_full', response))} chars)" + ) + print(f"\nResponse:\n{response}") + + +def export_markdown(entries: list[dict], output_path: Path, log_path: Path): + """Export the log to a markdown report.""" + meta = get_metadata(entries) + iters = get_iterations(entries) + simple_calls = get_simple_llm_calls(entries) + + lines = [] + lines.append(f"# RLM Run Report: {log_path.name}\n") + + # Metadata + if meta: + lines.append("## Configuration\n") + lines.append(f"- **Model:** {meta.get('root_model', 'unknown')}") + lines.append( + f"- **Max Iterations:** {meta.get('max_iterations', 'unknown')}" + ) + lines.append(f"- **Max Depth:** {meta.get('max_depth', 'unknown')}") + lines.append(f"- **Timestamp:** {meta.get('timestamp', 'unknown')}") + lines.append("") + + # Summary stats + lines.append("## Summary\n") + lines.append(f"- **Total Iterations:** {len(iters)}") + + depth_counts = defaultdict(int) + for it in iters: + depth_counts[it.get("depth", 0)] += 1 + + for depth in sorted(depth_counts.keys()): + lines.append(f"- **Depth {depth}:** {depth_counts[depth]} iterations") + + if simple_calls: + success_count = sum(1 for c in simple_calls if c.get("success", True)) + lines.append( + f"- **Simple LLM Calls:** {len(simple_calls)} ({success_count}" + " successful)" + ) + + lines.append("") + + # Iterations + lines.append("## Iterations\n") + + for it in iters: + depth = it.get("depth", 0) + iteration = it.get("iteration", "?") + agent = it.get("agent_name", "unknown") + response = it.get("response", "") + code_blocks = it.get("code_blocks", []) + final = it.get("final_answer") + + lines.append(f"### Iteration {iteration} (Depth {depth})\n") + lines.append(f"**Agent:** `{agent}`\n") + + if response: + lines.append("**Response:**\n") + lines.append( + f"```\n{response[:2000]}{'...' if len(response) > 2000 else ''}\n```\n" + ) + + for i, block in enumerate(code_blocks): + lines.append(f"**Code Block {i+1}:**\n") + lines.append(f"```python\n{block.get('code', '')}\n```\n") + output = block.get("output", "") + if output: + lines.append( + f"**Output:**\n```\n{output[:1000]}{'...' if len(output) > 1000 else ''}\n```\n" + ) + + if final: + lines.append(f"**FINAL ANSWER:**\n\n{final}\n") + + lines.append("---\n") + + # Simple LLM Calls section + if simple_calls: + lines.append("## Simple LLM Calls\n") + lines.append( + "These are non-recursive `llm_query()` calls made during code" + " execution.\n" + ) + + for i, call in enumerate(simple_calls): + depth = call.get("depth", 0) + success = call.get("success", True) + agent = call.get("agent_name", "unknown") + model = call.get("model", "unknown") + time_ms = call.get("execution_time_ms", 0) + parent_iter = call.get("parent_iteration", "?") + batch_idx = call.get("batch_index") + batch_size = call.get("batch_size") + + status = "OK" if success else "FAILED" + batch_info = f" (batch {batch_idx+1}/{batch_size})" if batch_size else "" + + lines.append(f"### Call {i+1} - {status}{batch_info}\n") + lines.append(f"- **Agent:** `{agent}`") + lines.append(f"- **Model:** {model}") + lines.append(f"- **Depth:** {depth}") + lines.append(f"- **Parent Iteration:** {parent_iter}") + lines.append(f"- **Time:** {time_ms:.0f}ms") + lines.append("") + + prompt = call.get("prompt_full", call.get("prompt", "")) + if prompt: + lines.append("**Prompt:**\n") + lines.append( + f"```\n{prompt[:1000]}{'...' 
if len(prompt) > 1000 else ''}\n```\n" + ) + + if not success: + error = call.get("error", "Unknown error") + lines.append(f"**Error:** {error}\n") + else: + response = call.get("response_full", call.get("response", "")) + if response: + lines.append("**Response:**\n") + lines.append( + f"```\n{response[:1000]}{'...' if len(response) > 1000 else ''}\n```\n" + ) + + lines.append("---\n") + + with open(output_path, "w") as f: + f.write("\n".join(lines)) + + print(f"Exported report to {output_path}") + + +def main(): + parser = argparse.ArgumentParser( + description="Analyze RLM JSONL logs", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "log_file", nargs="?", help="Path to log file (default: latest)" + ) + parser.add_argument( + "--summary", "-s", action="store_true", help="Show summary only" + ) + parser.add_argument( + "--tree", "-t", action="store_true", help="Show iteration tree" + ) + parser.add_argument( + "--code", "-c", action="store_true", help="Show code blocks" + ) + parser.add_argument( + "--responses", "-r", action="store_true", help="Show LLM responses" + ) + parser.add_argument( + "--final", "-f", action="store_true", help="Show final answer only" + ) + parser.add_argument( + "--simple", + action="store_true", + help="Show simple LLM calls (recursive=False)", + ) + parser.add_argument( + "--failed", + action="store_true", + help="With --simple, show only failed calls", + ) + parser.add_argument( + "--no-simple-tree", + action="store_true", + help="Hide simple LLM calls from iteration tree", + ) + parser.add_argument("--depth", "-d", type=int, help="Filter by depth") + parser.add_argument( + "--export", "-e", type=str, help="Export to markdown file" + ) + parser.add_argument( + "--list", "-l", action="store_true", help="List available log files" + ) + + args = parser.parse_args() + + # List logs + if args.list: + log_dir = Path("logs") + logs = sorted(log_dir.glob("rlm_*.jsonl"), key=lambda p: p.stat().st_mtime) + print("Available log files:") + for log in logs[-10:]: # Last 10 + size = log.stat().st_size + size_str = ( + f"{size/1024:.1f}KB" + if size < 1024 * 1024 + else f"{size/1024/1024:.1f}MB" + ) + print(f" {log.name} ({size_str})") + return + + # Find log file + if args.log_file: + log_path = Path(args.log_file) + else: + log_path = find_latest_log() + if not log_path: + print("No log files found in logs/", file=sys.stderr) + sys.exit(1) + + if not log_path.exists(): + print(f"Log file not found: {log_path}", file=sys.stderr) + sys.exit(1) + + # Load entries + entries = load_log(log_path) + if not entries: + print("No valid entries found in log", file=sys.stderr) + sys.exit(1) + + # Export mode + if args.export: + export_markdown(entries, Path(args.export), log_path) + return + + # Determine what to show + show_all = not any([ + args.summary, + args.tree, + args.code, + args.responses, + args.final, + args.simple, + ]) + + if show_all or args.summary: + print_summary(entries, log_path) + + if show_all or args.tree: + show_simple_in_tree = not getattr(args, "no_simple_tree", False) + print_tree(entries, show_simple=show_simple_in_tree) + + if args.code: + print_code_blocks(entries, args.depth) + + if args.responses: + print_responses(entries, args.depth) + + if args.simple: + print_simple_llm_calls(entries, args.depth, show_failed_only=args.failed) + + if args.final: + print_final_answer(entries) + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/scripts/compare_runs.py 
b/contributing/samples/rlm/scripts/compare_runs.py new file mode 100755 index 0000000000..2d8bad0072 --- /dev/null +++ b/contributing/samples/rlm/scripts/compare_runs.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +Compare multiple RLM runs side-by-side. + +Usage: + python scripts/compare_runs.py LOG1 LOG2 [LOG3 ...] + python scripts/compare_runs.py --latest 5 + +Examples: + python scripts/compare_runs.py logs/run1.jsonl logs/run2.jsonl + python scripts/compare_runs.py --latest 3 +""" + +import argparse +from collections import defaultdict +from dataclasses import dataclass +import json +from pathlib import Path +import sys + + +@dataclass +class RunStats: + path: Path + model: str + total_iterations: int + depth_distribution: dict[int, int] + has_final: bool + final_length: int + total_time: float + max_depth_used: int + file_count: int | None + + +def load_log(log_path: Path) -> list[dict]: + """Load JSONL log.""" + entries = [] + with open(log_path) as f: + for line in f: + line = line.strip() + if line: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + pass + return entries + + +def find_latest_logs(log_dir: Path = Path("logs"), n: int = 5) -> list[Path]: + """Find the n most recent log files.""" + logs = sorted(log_dir.glob("rlm_*.jsonl"), key=lambda p: p.stat().st_mtime) + return logs[-n:] if logs else [] + + +def analyze_run(log_path: Path) -> RunStats: + """Analyze a single run.""" + entries = load_log(log_path) + + meta = next((e for e in entries if e.get("type") == "metadata"), {}) + iterations = [e for e in entries if e.get("type") == "iteration"] + + # Depth distribution + depth_dist = defaultdict(int) + for it in iterations: + depth_dist[it.get("depth", 0)] += 1 + + # Find final answer + final = None + for it in reversed(iterations): + if it.get("final_answer"): + final = it.get("final_answer") + break + + # Total time + times = [ + it.get("iteration_time", 0) + for it in iterations + if it.get("iteration_time") + ] + total_time = sum(times) + + # File count (from first iteration's context inspection if available) + file_count = None + for it in iterations: + for block in it.get("code_blocks", []): + output = block.get("output", "") + if "file_count" in output or "files" in output: + import re + + match = re.search(r"file_count['\"]?:\s*(\d+)", output) + if match: + file_count = int(match.group(1)) + break + match = re.search(r"(\d+)\s*files?", output) + if match: + file_count = int(match.group(1)) + break + if file_count: + break + + return RunStats( + path=log_path, + model=meta.get("root_model", "unknown"), + total_iterations=len(iterations), + depth_distribution=dict(depth_dist), + has_final=final is not None, + final_length=len(final) if final else 0, + total_time=total_time, + max_depth_used=max(depth_dist.keys()) if depth_dist else 0, + file_count=file_count, + ) + + +def print_comparison(runs: list[RunStats]): + """Print comparison table.""" + print("\n" + "=" * 90) + print("RLM Run Comparison") + print("=" * 90) + + # Header + col_width = max(20, max(len(r.path.stem[:20]) for r in runs) + 2) + header = "Metric".ljust(25) + "".join( + r.path.stem[:20].ljust(col_width) for r in runs + ) + print(f"\n{header}") + print("-" * len(header)) + + # Model + row = "Model".ljust(25) + for r in runs: + row += r.model[:18].ljust(col_width) + print(row) + + # Total iterations + row = "Total Iterations".ljust(25) + for r in runs: + row += str(r.total_iterations).ljust(col_width) + print(row) + + # Max depth + row = "Max Depth Used".ljust(25) + for r in 
runs: + row += str(r.max_depth_used).ljust(col_width) + print(row) + + # Depth breakdown + all_depths = set() + for r in runs: + all_depths.update(r.depth_distribution.keys()) + + for depth in sorted(all_depths): + row = f" Depth {depth}".ljust(25) + for r in runs: + count = r.depth_distribution.get(depth, 0) + row += str(count).ljust(col_width) + print(row) + + # File count + row = "Files Processed".ljust(25) + for r in runs: + val = str(r.file_count) if r.file_count else "?" + row += val.ljust(col_width) + print(row) + + # Total time + row = "Total Time (s)".ljust(25) + for r in runs: + val = f"{r.total_time:.1f}" if r.total_time else "?" + row += val.ljust(col_width) + print(row) + + # Has final + row = "Has Final Answer".ljust(25) + for r in runs: + val = "✓" if r.has_final else "✗" + row += val.ljust(col_width) + print(row) + + # Final length + row = "Final Length (chars)".ljust(25) + for r in runs: + row += str(r.final_length).ljust(col_width) + print(row) + + # Efficiency metrics + print("\n" + "-" * len(header)) + print("Efficiency Metrics:") + + row = "Iters per File".ljust(25) + for r in runs: + if r.file_count and r.file_count > 0: + val = f"{r.total_iterations / r.file_count:.1f}" + else: + val = "?" + row += val.ljust(col_width) + print(row) + + row = "Time per Iter (s)".ljust(25) + for r in runs: + if r.total_iterations > 0 and r.total_time > 0: + val = f"{r.total_time / r.total_iterations:.1f}" + else: + val = "?" + row += val.ljust(col_width) + print(row) + + row = "Depth Ratio (d0/total)".ljust(25) + for r in runs: + d0 = r.depth_distribution.get(0, 0) + if r.total_iterations > 0: + val = f"{d0 / r.total_iterations:.2f}" + else: + val = "?" + row += val.ljust(col_width) + print(row) + + # Analysis + print("\n" + "=" * 90) + print("Analysis:") + print("=" * 90) + + # Best/worst iterations + if len(runs) > 1: + sorted_by_iters = sorted(runs, key=lambda r: r.total_iterations) + print( + f"\nFewest iterations: {sorted_by_iters[0].path.stem}" + f" ({sorted_by_iters[0].total_iterations})" + ) + print( + f"Most iterations: {sorted_by_iters[-1].path.stem}" + f" ({sorted_by_iters[-1].total_iterations})" + ) + + # Check for explosion + for r in runs: + if r.total_iterations > 100 and r.max_depth_used >= 3: + print( + f"\n⚠ {r.path.stem}: Possible iteration explosion (depth" + f" {r.max_depth_used}, {r.total_iterations} iters)" + ) + + # Check for missing finals + missing = [r for r in runs if not r.has_final] + if missing: + print(f"\n⚠ Runs without final answer: {[r.path.stem for r in missing]}") + + +def main(): + parser = argparse.ArgumentParser(description="Compare RLM runs") + parser.add_argument("log_files", nargs="*", help="Log files to compare") + parser.add_argument( + "--latest", "-l", type=int, help="Compare N most recent runs" + ) + args = parser.parse_args() + + if args.latest: + log_paths = find_latest_logs(n=args.latest) + if not log_paths: + print("No log files found", file=sys.stderr) + sys.exit(1) + elif args.log_files: + log_paths = [Path(f) for f in args.log_files] + else: + parser.print_help() + sys.exit(1) + + # Analyze all runs + runs = [] + for path in log_paths: + if not path.exists(): + print(f"Warning: {path} not found, skipping", file=sys.stderr) + continue + try: + runs.append(analyze_run(path)) + except Exception as e: + print(f"Warning: Error analyzing {path}: {e}", file=sys.stderr) + + if not runs: + print("No valid runs to compare", file=sys.stderr) + sys.exit(1) + + print_comparison(runs) + + +if __name__ == "__main__": + main() diff --git 
a/contributing/samples/rlm/scripts/diagnose_run.py b/contributing/samples/rlm/scripts/diagnose_run.py new file mode 100755 index 0000000000..304fa8b81f --- /dev/null +++ b/contributing/samples/rlm/scripts/diagnose_run.py @@ -0,0 +1,410 @@ +#!/usr/bin/env python3 +""" +Diagnose issues in RLM runs - detect context problems, iteration explosions, etc. + +Usage: + python scripts/diagnose_run.py [LOG_FILE] + +Examples: + python scripts/diagnose_run.py + python scripts/diagnose_run.py logs/rlm_2026-01-22_*.jsonl +""" + +import argparse +from collections import defaultdict +from dataclasses import dataclass +import json +from pathlib import Path +import re +import sys + + +@dataclass +class Issue: + severity: str # "error", "warning", "info" + category: str + message: str + iteration: int | None = None + depth: int | None = None + agent: str | None = None + + +def load_log(log_path: Path) -> list[dict]: + """Load JSONL log.""" + entries = [] + with open(log_path) as f: + for line in f: + line = line.strip() + if line: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + pass + return entries + + +def find_latest_log(log_dir: Path = Path("logs")) -> Path | None: + """Find the most recent log file.""" + logs = sorted(log_dir.glob("rlm_*.jsonl"), key=lambda p: p.stat().st_mtime) + return logs[-1] if logs else None + + +def check_context_issues(entries: list[dict]) -> list[Issue]: + """Check for context propagation problems.""" + issues = [] + iterations = [e for e in entries if e.get("type") == "iteration"] + + for it in iterations: + depth = it.get("depth", 0) + iteration = it.get("iteration", 0) + agent = it.get("agent_name", "") + code_blocks = it.get("code_blocks", []) + + for block in code_blocks: + code = block.get("code", "") + output = block.get("output", "") + + # Check for small context warnings in output + if output: + # Pattern: context is very small but agent expected documents + small_ctx_patterns = [ + r"Context length: (\d+)", + r"total length \((\d+) characters\)", + r"(\d+) chars", + ] + for pattern in small_ctx_patterns: + match = re.search(pattern, output) + if match: + size = int(match.group(1)) + if size < 1000 and depth > 0: + issues.append( + Issue( + severity="warning", + category="context_size", + message=( + f"Small context ({size} chars) at depth {depth} - may" + " have received filenames instead of content" + ), + iteration=iteration, + depth=depth, + agent=agent, + ) + ) + + # Check for filename list patterns in context inspection + if "context[:5]" in code or "context[:3]" in code: + if output and ".md" in output and "Tegus" not in output: + # Looks like filenames, not content + issues.append( + Issue( + severity="error", + category="context_type", + message=( + "Context appears to be filenames (strings) instead of" + " file objects" + ), + iteration=iteration, + depth=depth, + agent=agent, + ) + ) + + # Check for file not found errors + if "No such file or directory" in str( + output + ) or "FileNotFoundError" in str(output): + issues.append( + Issue( + severity="error", + category="file_error", + message="File not found error during execution", + iteration=iteration, + depth=depth, + agent=agent, + ) + ) + + return issues + + +def check_iteration_explosion(entries: list[dict]) -> list[Issue]: + """Check for iteration explosion patterns.""" + issues = [] + iterations = [e for e in entries if e.get("type") == "iteration"] + meta = next((e for e in entries if e.get("type") == "metadata"), {}) + + total = len(iterations) + max_iterations = 
meta.get("max_iterations", 30) + + # Count by depth + depth_counts = defaultdict(int) + for it in iterations: + depth_counts[it.get("depth", 0)] += 1 + + # Check for explosion + if total > max_iterations * 3: + issues.append( + Issue( + severity="error", + category="explosion", + message=( + f"Iteration explosion: {total} iterations (expected" + f" ~{max_iterations})" + ), + ) + ) + + # Check for deep recursion explosion + for depth, count in depth_counts.items(): + if depth >= 2 and count > 50: + issues.append( + Issue( + severity="warning", + category="deep_recursion", + message=( + f"High iteration count at depth {depth}: {count} iterations" + ), + ) + ) + + # Check for ratio imbalance + if depth_counts.get(0, 0) < 5 and sum(depth_counts.values()) > 100: + issues.append( + Issue( + severity="warning", + category="ratio_imbalance", + message=( + f"Root agent only had {depth_counts.get(0, 0)} iterations but" + f" spawned {sum(depth_counts.values())} total - aggregation may" + " be missing" + ), + ) + ) + + return issues + + +def check_redundant_work(entries: list[dict]) -> list[Issue]: + """Check for redundant/repeated work.""" + issues = [] + iterations = [e for e in entries if e.get("type") == "iteration"] + + # Track prompts by similarity + prompt_hashes = defaultdict(list) + + for it in iterations: + code_blocks = it.get("code_blocks", []) + for block in code_blocks: + code = block.get("code", "") + # Look for llm_query calls + if "llm_query" in code: + # Extract prompt pattern (simplified) + prompt_match = re.search(r'prompt\s*=\s*["\'](.{50,100})', code) + if prompt_match: + prompt_key = prompt_match.group(1)[:50] + prompt_hashes[prompt_key].append(it.get("iteration", 0)) + + # Find duplicates + for prompt_key, iters in prompt_hashes.items(): + if len(iters) > 3: + issues.append( + Issue( + severity="warning", + category="redundant_work", + message=( + f"Similar prompt pattern used {len(iters)} times:" + f" '{prompt_key[:40]}...'" + ), + ) + ) + + return issues + + +def check_final_answer(entries: list[dict]) -> list[Issue]: + """Check final answer quality.""" + issues = [] + iterations = [e for e in entries if e.get("type") == "iteration"] + + # Find final answer + final = None + final_depth = None + for it in reversed(iterations): + if it.get("final_answer"): + final = it.get("final_answer") + final_depth = it.get("depth", 0) + break + + if not final: + issues.append( + Issue( + severity="error", + category="no_answer", + message="No final answer found - run may have failed or timed out", + ) + ) + else: + # Check for error patterns in final answer + error_patterns = [ + "No such file or directory", + "was not included", + "please paste", + "I cannot", + "error occurred", + ] + for pattern in error_patterns: + if pattern.lower() in final.lower(): + issues.append( + Issue( + severity="error", + category="answer_error", + message=f"Final answer contains error pattern: '{pattern}'", + ) + ) + + # Check if final came from deep recursion + if final_depth and final_depth >= 3: + issues.append( + Issue( + severity="warning", + category="deep_final", + message=( + f"Final answer came from depth {final_depth} - may be" + " incomplete synthesis" + ), + ) + ) + + # Check answer length + if len(final) < 100: + issues.append( + Issue( + severity="warning", + category="short_answer", + message=f"Final answer is very short ({len(final)} chars)", + ) + ) + + return issues + + +def print_diagnosis(issues: list[Issue], entries: list[dict], log_path: Path): + """Print diagnosis report.""" + meta = 
next((e for e in entries if e.get("type") == "metadata"), {}) + iterations = [e for e in entries if e.get("type") == "iteration"] + + print("=" * 70) + print(f"RLM Run Diagnosis: {log_path.name}") + print("=" * 70) + + # Quick stats + print(f"\nModel: {meta.get('root_model', 'unknown')}") + print(f"Total Iterations: {len(iterations)}") + + depth_counts = defaultdict(int) + for it in iterations: + depth_counts[it.get("depth", 0)] += 1 + print(f"Depth Distribution: {dict(sorted(depth_counts.items()))}") + + # Issues + errors = [i for i in issues if i.severity == "error"] + warnings = [i for i in issues if i.severity == "warning"] + infos = [i for i in issues if i.severity == "info"] + + print(f"\n{'='*70}") + print( + f"Issues Found: {len(errors)} errors, {len(warnings)} warnings," + f" {len(infos)} info" + ) + print("=" * 70) + + if errors: + print("\n[ERRORS]") + for issue in errors: + loc = "" + if issue.iteration: + loc = f" (iter {issue.iteration}, depth {issue.depth})" + print(f" ✗ [{issue.category}]{loc}: {issue.message}") + + if warnings: + print("\n[WARNINGS]") + for issue in warnings: + loc = "" + if issue.iteration: + loc = f" (iter {issue.iteration}, depth {issue.depth})" + print(f" ⚠ [{issue.category}]{loc}: {issue.message}") + + if infos: + print("\n[INFO]") + for issue in infos: + print(f" ℹ [{issue.category}]: {issue.message}") + + if not issues: + print("\n✓ No issues detected - run looks healthy!") + + # Recommendations + if issues: + print(f"\n{'='*70}") + print("Recommendations:") + print("=" * 70) + + categories = set(i.category for i in issues) + + if "context_type" in categories or "context_size" in categories: + print( + " • Ensure file objects (not filenames) are passed via context=" + " parameter" + ) + print( + " • Use: llm_query(prompt, context=file_obj) not llm_query(prompt +" + " filename)" + ) + + if "explosion" in categories or "deep_recursion" in categories: + print( + " • Use llm_query_batched with recursive=False for parallel" + " extraction" + ) + print(" • Aggregate results at calling level, don't spawn more children") + + if "ratio_imbalance" in categories: + print( + " • Root agent should run more iterations to aggregate child results" + ) + print( + " • Check that llm_query_batched results are being collected and" + " synthesized" + ) + + if "redundant_work" in categories: + print(" • Consider caching or deduplicating file analysis") + print(" • Use batch queries instead of individual file queries") + + +def main(): + parser = argparse.ArgumentParser(description="Diagnose RLM run issues") + parser.add_argument("log_file", nargs="?", help="Path to log file") + args = parser.parse_args() + + if args.log_file: + log_path = Path(args.log_file) + else: + log_path = find_latest_log() + if not log_path: + print("No log files found", file=sys.stderr) + sys.exit(1) + + entries = load_log(log_path) + + # Run all checks + issues = [] + issues.extend(check_context_issues(entries)) + issues.extend(check_iteration_explosion(entries)) + issues.extend(check_redundant_work(entries)) + issues.extend(check_final_answer(entries)) + + print_diagnosis(issues, entries, log_path) + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/scripts/extract_insights.py b/contributing/samples/rlm/scripts/extract_insights.py new file mode 100755 index 0000000000..a54bdf3582 --- /dev/null +++ b/contributing/samples/rlm/scripts/extract_insights.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +""" +Extract unique insights and findings from RLM run iterations. 
+
+This script finds substantive outputs from code executions and llm_query calls,
+useful for understanding what the RLM actually learned from the data.
+
+Usage:
+  python scripts/extract_insights.py [LOG_FILE] [OPTIONS]
+
+Examples:
+  python scripts/extract_insights.py
+  python scripts/extract_insights.py --min-length 200
+  python scripts/extract_insights.py --depth 2 --format md
+"""
+
+import argparse
+from collections import defaultdict
+from dataclasses import dataclass
+import json
+from pathlib import Path
+import re
+import sys
+
+
+@dataclass
+class Insight:
+  iteration: int
+  depth: int
+  agent: str
+  content: str
+  source: str  # "code_output", "llm_response", "final_answer"
+  topic: str | None  # extracted topic if identifiable
+
+
+def load_log(log_path: Path) -> list[dict]:
+  """Load JSONL log."""
+  entries = []
+  with open(log_path) as f:
+    for line in f:
+      line = line.strip()
+      if line:
+        try:
+          entries.append(json.loads(line))
+        except json.JSONDecodeError:
+          pass
+  return entries
+
+
+def find_latest_log(log_dir: Path = Path("logs")) -> Path | None:
+  """Find the most recent log file."""
+  logs = sorted(log_dir.glob("rlm_*.jsonl"), key=lambda p: p.stat().st_mtime)
+  return logs[-1] if logs else None
+
+
+def is_substantive_output(text: str, min_length: int = 100) -> bool:
+  """Check if output contains substantive content (not just debugging)."""
+  if not text or len(text) < min_length:
+    return False
+
+  # Skip debugging output
+  debug_patterns = [
+      r"^Context type:",
+      r"^dict_keys\(",
+  ]
+  for pattern in debug_patterns:
+    if re.match(pattern, text.strip()):
+      return False
+
+  return True
+
+
+def extract_topic(text: str) -> str | None:
+  """Try to extract the main topic from text."""
+  # Look for company names
+  companies = [
+      "Tyler Technologies",
+      "Tyler",
+      "Mark43",
+      "Accela",
+      "OpenGov",
+      "CentralSquare",
+      "Motorola",
+      "Hexagon",
+      "Axon",
+      "Microsoft",
+      "Oracle",
+      "Workday",
+      "Granicus",
+      "CivicPlus",
+  ]
+  for company in companies:
+    if company.lower() in text.lower():
+      return company
+
+  # Look for topic headers
+  topic_patterns = [
+      r"regarding\s+(\w+(?:\s+\w+){0,2})",
+      r"about\s+(\w+(?:\s+\w+){0,2})",
+      r"analyzing\s+(\w+(?:\s+\w+){0,2})",
+  ]
+  for pattern in topic_patterns:
+    match = re.search(pattern, text, re.IGNORECASE)
+    if match:
+      return match.group(1)
+
+  return None
+
+
+def extract_insights(
+    entries: list[dict], min_length: int = 100
+) -> list[Insight]:
+  """Extract substantive insights from the run."""
+  insights = []
+  iterations = [e for e in entries if e.get("type") == "iteration"]
+  seen_content = set()  # Dedupe
+
+  for it in iterations:
+    depth = it.get("depth", 0)
+    iteration = it.get("iteration", 0)
+    agent = it.get("agent_name", "")
+
+    # Check code block outputs
+    for block in it.get("code_blocks", []):
+      output = block.get("output", "")
+      if output and is_substantive_output(output, min_length):
+        # Hash for deduplication
+        content_hash = hash(output[:200])
+        if content_hash not in seen_content:
+          seen_content.add(content_hash)
+          insights.append(
+              Insight(
+                  iteration=iteration,
+                  depth=depth,
+                  agent=agent,
+                  content=output,
+                  source="code_output",
+                  topic=extract_topic(output),
+              )
+          )
+
+    # Check final answers
+    final = it.get("final_answer")
+    if final and len(final) >= min_length:
+      content_hash = hash(final[:200])
+      if content_hash not in seen_content:
+        seen_content.add(content_hash)
+        insights.append(
+            Insight(
+                iteration=iteration,
+                depth=depth,
+                agent=agent,
+                content=final,
+                source="final_answer",
+                topic=extract_topic(final),
+            )
+        )
+
+  return insights
+
+
+def print_insights(
+    insights: list[Insight],
+    depth_filter: int | None =
None, + format: str = "text", + max_content: int = 500, +): + """Print extracted insights.""" + filtered = insights + if depth_filter is not None: + filtered = [i for i in filtered if i.depth == depth_filter] + + if format == "md": + print_insights_markdown(filtered, max_content) + else: + print_insights_text(filtered, max_content) + + +def print_insights_text(insights: list[Insight], max_content: int): + """Print insights in text format.""" + print(f"\nExtracted {len(insights)} substantive insights:") + print("=" * 70) + + # Group by topic + by_topic = defaultdict(list) + for i in insights: + topic = i.topic or "General" + by_topic[topic].append(i) + + for topic, topic_insights in sorted(by_topic.items()): + print(f"\n### {topic} ({len(topic_insights)} insights)") + print("-" * 40) + + for insight in topic_insights[:5]: # Limit per topic + print( + f"\n[Iter {insight.iteration}, depth {insight.depth}]" + f" ({insight.source})" + ) + content = insight.content + if len(content) > max_content: + content = content[:max_content] + "..." + print(content) + + if len(topic_insights) > 5: + print(f"\n ... and {len(topic_insights) - 5} more") + + +def print_insights_markdown(insights: list[Insight], max_content: int): + """Print insights in markdown format.""" + print("# Extracted Insights\n") + + by_topic = defaultdict(list) + for i in insights: + topic = i.topic or "General" + by_topic[topic].append(i) + + for topic, topic_insights in sorted(by_topic.items()): + print(f"## {topic}\n") + + for insight in topic_insights: + print(f"### Iteration {insight.iteration} (depth {insight.depth})\n") + print(f"*Source: {insight.source}*\n") + content = insight.content + if len(content) > max_content: + content = content[:max_content] + "..." + print(f"```\n{content}\n```\n") + + +def print_summary(insights: list[Insight]): + """Print summary statistics.""" + print("\nInsights Summary:") + print("=" * 50) + + # By source + by_source = defaultdict(int) + for i in insights: + by_source[i.source] += 1 + print("\nBy Source:") + for s, count in sorted(by_source.items()): + print(f" {s}: {count}") + + # By depth + by_depth = defaultdict(int) + for i in insights: + by_depth[i.depth] += 1 + print("\nBy Depth:") + for d, count in sorted(by_depth.items()): + print(f" Depth {d}: {count}") + + # By topic + by_topic = defaultdict(int) + for i in insights: + by_topic[i.topic or "Unclassified"] += 1 + print("\nBy Topic:") + for t, count in sorted(by_topic.items(), key=lambda x: -x[1])[:10]: + print(f" {t}: {count}") + + # Average content length + avg_len = ( + sum(len(i.content) for i in insights) / len(insights) if insights else 0 + ) + print(f"\nAverage insight length: {avg_len:.0f} chars") + + +def main(): + parser = argparse.ArgumentParser(description="Extract insights from RLM logs") + parser.add_argument("log_file", nargs="?", help="Path to log file") + parser.add_argument( + "--min-length", + "-m", + type=int, + default=100, + help="Minimum content length (default: 100)", + ) + parser.add_argument("--depth", "-d", type=int, help="Filter by depth") + parser.add_argument( + "--format", + "-f", + choices=["text", "md"], + default="text", + help="Output format", + ) + parser.add_argument( + "--max-content", + type=int, + default=500, + help="Max content to show per insight", + ) + parser.add_argument( + "--summary", "-s", action="store_true", help="Show summary only" + ) + args = parser.parse_args() + + if args.log_file: + log_path = Path(args.log_file) + else: + log_path = find_latest_log() + if not log_path: + 
print("No log files found", file=sys.stderr) + sys.exit(1) + + print(f"Analyzing: {log_path.name}") + + entries = load_log(log_path) + insights = extract_insights(entries, args.min_length) + + if args.summary: + print_summary(insights) + else: + print_insights(insights, args.depth, args.format, args.max_content) + print_summary(insights) + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/scripts/run_query.py b/contributing/samples/rlm/scripts/run_query.py new file mode 100755 index 0000000000..5ba5c3f51b --- /dev/null +++ b/contributing/samples/rlm/scripts/run_query.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +""" +Run an RLM query and save results. + +Usage: + python scripts/run_query.py "Your question here" [OPTIONS] + +Examples: + # Simple query with files + python scripts/run_query.py "Summarize the key themes" --files "./docs/**/*.md" + + # Query with verbose output + python scripts/run_query.py "What are the main risks?" --files "./corpora/**/*.md" -v + + # Use a specific model + python scripts/run_query.py "Analyze this" --files "*.txt" --model gemini-3-flash-preview + + # Save output to file + python scripts/run_query.py "Compare X and Y" --files "./data/*.md" --output results.md + + # Run in background (for long queries) + python scripts/run_query.py "Complex analysis" --files "./large_corpus/**/*.md" --background +""" + +import argparse +import asyncio +from datetime import datetime +from pathlib import Path +import sys +import time + + +def run_query( + prompt: str, + files: list[str], + model: str = "gemini-3-pro-preview", + max_iterations: int = 30, + max_depth: int = 5, + verbose: bool = False, + output: str | None = None, +) -> str: + """Run an RLM query and return the result.""" + from adk_rlm import completion + + start = time.perf_counter() + + result = completion( + files=files, + prompt=prompt, + model=model, + max_iterations=max_iterations, + max_depth=max_depth, + log_dir="./logs", + verbose=verbose, + ) + + elapsed = time.perf_counter() - start + + # Build output text + output_text = f"""# RLM Query Result + +**Prompt:** {prompt} + +**Files:** {', '.join(files)} + +**Model:** {model} + +**Execution Time:** {elapsed:.1f}s + +--- + +## Answer + +{result.response} + +--- + +*Generated at {datetime.now().isoformat()}* +""" + + # Save if output path specified + if output: + Path(output).write_text(output_text) + print(f"Results saved to {output}") + + return result.response + + +def main(): + parser = argparse.ArgumentParser( + description="Run an RLM query", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument("prompt", help="The question/prompt to ask") + parser.add_argument( + "--files", "-f", nargs="+", required=True, help="File patterns to load" + ) + parser.add_argument( + "--model", "-m", default="gemini-3-pro-preview", help="Model to use" + ) + parser.add_argument( + "--max-iterations", "-i", type=int, default=30, help="Max iterations" + ) + parser.add_argument( + "--max-depth", "-d", type=int, default=5, help="Max recursion depth" + ) + parser.add_argument( + "--verbose", "-v", action="store_true", help="Show verbose output" + ) + parser.add_argument("--output", "-o", help="Save results to file") + parser.add_argument( + "--background", "-b", action="store_true", help="Run in background" + ) + + args = parser.parse_args() + + if args.background: + # Fork to background + import subprocess + + cmd = [ + sys.executable, + __file__, + args.prompt, + "--files", + *args.files, + "--model", + 
args.model, + "--max-iterations", + str(args.max_iterations), + "--max-depth", + str(args.max_depth), + ] + if args.verbose: + cmd.append("--verbose") + + # Generate output file if not specified + output = ( + args.output or f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.md" + ) + cmd.extend(["--output", output]) + + log_file = f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log" + + with open(log_file, "w") as f: + proc = subprocess.Popen( + cmd, + stdout=f, + stderr=subprocess.STDOUT, + start_new_session=True, + ) + + print(f"Started background process (PID: {proc.pid})") + print(f"Output will be saved to: {output}") + print(f"Logs: {log_file}") + print(f"\nMonitor with: tail -f {log_file}") + print(f"Check results: cat {output}") + return + + result = run_query( + prompt=args.prompt, + files=args.files, + model=args.model, + max_iterations=args.max_iterations, + max_depth=args.max_depth, + verbose=args.verbose, + output=args.output, + ) + + if not args.output: + print("\n" + "=" * 70) + print("RESULT:") + print("=" * 70) + print(result) + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/scripts/show_llm_calls.py b/contributing/samples/rlm/scripts/show_llm_calls.py new file mode 100755 index 0000000000..6c2182b325 --- /dev/null +++ b/contributing/samples/rlm/scripts/show_llm_calls.py @@ -0,0 +1,263 @@ +#!/usr/bin/env python3 +""" +Extract and display all llm_query/llm_query_batched calls from an RLM run. + +Usage: + python scripts/show_llm_calls.py [LOG_FILE] [OPTIONS] + +Examples: + python scripts/show_llm_calls.py + python scripts/show_llm_calls.py --depth 0 + python scripts/show_llm_calls.py --batched-only + python scripts/show_llm_calls.py --stats +""" + +import argparse +from collections import defaultdict +from dataclasses import dataclass +import json +from pathlib import Path +import re +import sys + + +@dataclass +class LLMCall: + iteration: int + depth: int + agent: str + call_type: str # "llm_query" or "llm_query_batched" + prompt_preview: str + has_context: bool + recursive: bool | None + model: str | None + full_code: str + + +def load_log(log_path: Path) -> list[dict]: + """Load JSONL log.""" + entries = [] + with open(log_path) as f: + for line in f: + line = line.strip() + if line: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + pass + return entries + + +def find_latest_log(log_dir: Path = Path("logs")) -> Path | None: + """Find the most recent log file.""" + logs = sorted(log_dir.glob("rlm_*.jsonl"), key=lambda p: p.stat().st_mtime) + return logs[-1] if logs else None + + +def extract_llm_calls(entries: list[dict]) -> list[LLMCall]: + """Extract all llm_query calls from code blocks.""" + calls = [] + iterations = [e for e in entries if e.get("type") == "iteration"] + + for it in iterations: + depth = it.get("depth", 0) + iteration = it.get("iteration", 0) + agent = it.get("agent_name", "") + code_blocks = it.get("code_blocks", []) + + for block in code_blocks: + code = block.get("code", "") + + # Find llm_query calls + patterns = [ + (r"llm_query_batched\s*\(", "llm_query_batched"), + (r"llm_query\s*\(", "llm_query"), + ] + + for pattern, call_type in patterns: + if re.search(pattern, code): + # Extract prompt preview + prompt_match = re.search( + r'(?:prompt|prompts)\s*=\s*(?:f?["\'](.{20,80})|"""(.{20,80})|\[\s*f?["\'](.{20,80}))', + code, + re.DOTALL, + ) + prompt_preview = "" + if prompt_match: + prompt_preview = next((g for g in prompt_match.groups() if g), "")[ + :60 + ] + + # Check for context 
parameter + has_context = "context=" in code and "context=None" not in code + + # Check recursive parameter + recursive = None + if "recursive=True" in code: + recursive = True + elif "recursive=False" in code: + recursive = False + + # Check model parameter + model = None + model_match = re.search(r'model\s*=\s*["\']([^"\']+)["\']', code) + if model_match: + model = model_match.group(1) + + calls.append( + LLMCall( + iteration=iteration, + depth=depth, + agent=agent, + call_type=call_type, + prompt_preview=prompt_preview.replace("\n", " ")[:60], + has_context=has_context, + recursive=recursive, + model=model, + full_code=code, + ) + ) + + return calls + + +def print_calls( + calls: list[LLMCall], + depth_filter: int | None = None, + batched_only: bool = False, + show_code: bool = False, +): + """Print the extracted calls.""" + filtered = calls + if depth_filter is not None: + filtered = [c for c in filtered if c.depth == depth_filter] + if batched_only: + filtered = [c for c in filtered if c.call_type == "llm_query_batched"] + + print(f"\nFound {len(filtered)} llm_query calls:") + print("=" * 80) + + for call in filtered: + rec_str = "" + if call.recursive is True: + rec_str = " [RECURSIVE]" + elif call.recursive is False: + rec_str = " [simple]" + + ctx_str = " +ctx" if call.has_context else "" + model_str = f" ({call.model})" if call.model else "" + + print(f"\n[Iter {call.iteration:3d}] depth={call.depth} {call.agent}") + print(f" {call.call_type}{rec_str}{ctx_str}{model_str}") + if call.prompt_preview: + print(f' Prompt: "{call.prompt_preview}..."') + + if show_code: + print(f" Code:\n " + call.full_code.replace("\n", "\n ")) + + +def print_stats(calls: list[LLMCall]): + """Print statistics about llm_query usage.""" + print("\nLLM Call Statistics:") + print("=" * 60) + + # By type + by_type = defaultdict(int) + for c in calls: + by_type[c.call_type] += 1 + print("\nBy Call Type:") + for t, count in sorted(by_type.items()): + print(f" {t}: {count}") + + # By depth + by_depth = defaultdict(int) + for c in calls: + by_depth[c.depth] += 1 + print("\nBy Depth:") + for d, count in sorted(by_depth.items()): + print(f" Depth {d}: {count}") + + # Recursive vs simple + recursive_count = sum(1 for c in calls if c.recursive is True) + simple_count = sum(1 for c in calls if c.recursive is False) + unspecified = len(calls) - recursive_count - simple_count + print("\nRecursive vs Simple:") + print(f" recursive=True: {recursive_count}") + print(f" recursive=False: {simple_count}") + print(f" unspecified: {unspecified}") + + # With context + with_ctx = sum(1 for c in calls if c.has_context) + print(f"\nWith context= parameter: {with_ctx}/{len(calls)}") + + # Model usage + models = defaultdict(int) + for c in calls: + models[c.model or "(default)"] += 1 + print("\nModels Used:") + for m, count in sorted(models.items(), key=lambda x: -x[1]): + print(f" {m}: {count}") + + # Recommendations + print("\n" + "=" * 60) + print("Observations:") + + if recursive_count > simple_count: + print(" ⚠ More recursive calls than simple - may cause explosion") + print(" Consider using recursive=False for extraction/summarization") + + if with_ctx < len(calls) * 0.5: + print(" ⚠ Most calls don't use context= parameter") + print(" Pass file objects via context= to properly delegate") + + batched = sum(1 for c in calls if c.call_type == "llm_query_batched") + if batched == 0 and len(calls) > 10: + print(" ℹ No batched calls found") + print(" Consider llm_query_batched for parallel processing") + + +def main(): + parser = 
argparse.ArgumentParser( + description="Show llm_query calls from RLM logs" + ) + parser.add_argument("log_file", nargs="?", help="Path to log file") + parser.add_argument("--depth", "-d", type=int, help="Filter by depth") + parser.add_argument( + "--batched-only", + "-b", + action="store_true", + help="Show only batched calls", + ) + parser.add_argument( + "--stats", "-s", action="store_true", help="Show statistics only" + ) + parser.add_argument( + "--code", "-c", action="store_true", help="Show full code blocks" + ) + args = parser.parse_args() + + if args.log_file: + log_path = Path(args.log_file) + else: + log_path = find_latest_log() + if not log_path: + print("No log files found", file=sys.stderr) + sys.exit(1) + + print(f"Analyzing: {log_path.name}") + + entries = load_log(log_path) + calls = extract_llm_calls(entries) + + if args.stats: + print_stats(calls) + else: + print_calls(calls, args.depth, args.batched_only, args.code) + + if len(calls) > 5: + print("\n" + "-" * 40) + print("Tip: Use --stats for usage statistics") + + +if __name__ == "__main__": + main() diff --git a/contributing/samples/rlm/tests/__init__.py b/contributing/samples/rlm/tests/__init__.py new file mode 100644 index 0000000000..8a1f309c87 --- /dev/null +++ b/contributing/samples/rlm/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for ADK-RLM.""" diff --git a/contributing/samples/rlm/tests/conftest.py b/contributing/samples/rlm/tests/conftest.py new file mode 100644 index 0000000000..0ea5e5248f --- /dev/null +++ b/contributing/samples/rlm/tests/conftest.py @@ -0,0 +1,103 @@ +""" +Pytest configuration and fixtures for ADK-RLM tests. +""" + +import os +from pathlib import Path + +import pytest + +# Skip E2E tests unless explicitly enabled +E2E_ENABLED = os.getenv("RLM_E2E_TESTS", "false").lower() == "true" + + +def pytest_configure(config): + """Configure custom pytest markers.""" + config.addinivalue_line( + "markers", "e2e: mark test as end-to-end (requires real LLM)" + ) + + +def pytest_collection_modifyitems(config, items): + """Skip E2E tests unless explicitly enabled.""" + if not E2E_ENABLED: + skip_e2e = pytest.mark.skip( + reason="E2E tests disabled (set RLM_E2E_TESTS=true)" + ) + for item in items: + if "e2e" in item.keywords: + item.add_marker(skip_e2e) + + +@pytest.fixture +def mock_llm_query(): + """Create a mock llm_query function for testing.""" + + def _mock_llm_query( + prompt: str, + context=None, + model: str | None = None, + recursive: bool = True, + ) -> str: + return f"Mock response for: {prompt[:50]}..." + + return _mock_llm_query + + +@pytest.fixture +def mock_llm_query_batched(mock_llm_query): + """Create a mock llm_query_batched function for testing.""" + + def _mock_batched( + prompts: list[str], + contexts=None, + model: str | None = None, + recursive: bool = False, + max_concurrent: int = 3, + ) -> list[str]: + return [mock_llm_query(p, model=model) for p in prompts] + + return _mock_batched + + +@pytest.fixture +def temp_log_dir(tmp_path): + """Create a temporary directory for log files.""" + log_dir = tmp_path / "logs" + log_dir.mkdir() + return str(log_dir) + + +@pytest.fixture +def sample_context(): + """Return a sample context string for testing.""" + return ( + "This is a sample context with some text. It contains important" + " information about testing. The magic number is 42." 
+ ) + + +@pytest.fixture +def sample_context_dict(): + """Return a sample context dict for testing.""" + return { + "title": "Test Document", + "content": "This is the main content of the document.", + "metadata": {"author": "Test Author", "date": "2024-01-01"}, + } + + +@pytest.fixture +def sample_context_list(): + """Return a sample context list for testing.""" + return [ + "First chunk of text with some information.", + "Second chunk of text with more details.", + "Third chunk of text with conclusions.", + ] + + +@pytest.fixture +def fixtures_dir(): + """Return the path to the test fixtures directory.""" + return Path(__file__).parent / "fixtures" diff --git a/contributing/samples/rlm/tests/e2e/__init__.py b/contributing/samples/rlm/tests/e2e/__init__.py new file mode 100644 index 0000000000..9c52933a1a --- /dev/null +++ b/contributing/samples/rlm/tests/e2e/__init__.py @@ -0,0 +1 @@ +# End-to-end web tests using Playwright diff --git a/contributing/samples/rlm/tests/e2e/conftest.py b/contributing/samples/rlm/tests/e2e/conftest.py new file mode 100644 index 0000000000..6d9d4e3fe5 --- /dev/null +++ b/contributing/samples/rlm/tests/e2e/conftest.py @@ -0,0 +1,322 @@ +""" +Fixtures for E2E web tests with real server and mocked LLM. + +These tests use Playwright with a real FastAPI server but mock the LLM +calls to provide predictable, fast responses without API costs. +""" + +import asyncio +import os +import socket +import subprocess +import sys +import tempfile +import time +from typing import Generator +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from playwright.sync_api import Page +import pytest + + +@pytest.fixture(scope="session") +def e2e_server() -> Generator[str, None, None]: + """ + Start a real FastAPI server for E2E tests. + + This starts the actual web server on a random available port with: + - Isolated SQLite database (temp file) + - Real WebSocket connections + - Real session persistence + + The LLM calls are mocked at the application level via environment variable. 
+  """
+  import urllib.error
+  import urllib.request
+
+  # Find an available port
+  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+    s.bind(("127.0.0.1", 0))
+    port = s.getsockname()[1]
+
+  # Use a temp database for isolation
+  db_fd, db_path = tempfile.mkstemp(suffix=".db")
+  os.close(db_fd)
+  db_url = f"sqlite+aiosqlite:///{db_path}"
+
+  # Create temp log directory
+  log_dir = tempfile.mkdtemp(prefix="rlm_test_logs_")
+
+  # Start the server in a subprocess
+  env = os.environ.copy()
+  env["RLM_DB_URL"] = db_url
+  env["RLM_LOG_DIR"] = log_dir
+  env["RLM_MODEL"] = "gemini-3-flash-preview"  # Use faster model for tests
+  env["RLM_MAX_ITERATIONS"] = "5"  # Limit iterations for tests
+
+  # Start uvicorn via subprocess
+  proc = subprocess.Popen(
+      [
+          sys.executable,
+          "-m",
+          "uvicorn",
+          "adk_rlm.web:app",
+          "--host",
+          "127.0.0.1",
+          "--port",
+          str(port),
+      ],
+      env=env,
+      stdout=subprocess.PIPE,
+      stderr=subprocess.PIPE,
+  )
+
+  # Wait for server to be ready
+  max_retries = 50
+  for i in range(max_retries):
+    try:
+      urllib.request.urlopen(f"http://127.0.0.1:{port}/health", timeout=1)
+      break
+    except (urllib.error.URLError, ConnectionRefusedError):
+      time.sleep(0.2)
+  else:
+    # Stop the server first so communicate() can drain its output instead of
+    # timing out against a still-running process.
+    proc.terminate()
+    stdout, stderr = proc.communicate(timeout=5)
+    raise RuntimeError(
+        f"Server did not start within {max_retries * 0.2}s\n"
+        f"stdout: {stdout.decode()}\n"
+        f"stderr: {stderr.decode()}"
+    )
+
+  url = f"http://127.0.0.1:{port}"
+
+  yield url
+
+  # Cleanup
+  proc.terminate()
+  try:
+    proc.wait(timeout=5)
+  except subprocess.TimeoutExpired:
+    proc.kill()
+
+  # Remove temp database
+  if os.path.exists(db_path):
+    os.unlink(db_path)
+
+  # Remove temp log directory
+  import shutil
+
+  if os.path.exists(log_dir):
+    shutil.rmtree(log_dir, ignore_errors=True)
+
+
+@pytest.fixture
+def e2e_page(page: Page, e2e_server: str) -> Page:
+  """
+  Provide a page connected to the E2E server.
+
+  This fixture navigates to the server and waits for WebSocket connection.
+ """ + page.goto(e2e_server) + + # Wait for WebSocket connection + page.wait_for_function( + "() => document.querySelector('#status-badge')?.textContent ===" + " 'Connected'", + timeout=10000, + ) + + return page + + +class MockLLMResponse: + """Mock LLM response for testing.""" + + def __init__( + self, + text: str, + has_code: bool = False, + code: str | None = None, + is_final: bool = False, + ): + self.text = text + self.has_code = has_code + self.code = code + self.is_final = is_final + + def to_response_text(self) -> str: + """Generate response text that mimics LLM output.""" + if self.is_final: + return f"FINAL({self.text})" + + if self.has_code and self.code: + return f"Let me calculate that.\n\n```python\n{self.code}\n```" + + return self.text + + +# Pre-defined mock responses for common test scenarios +MOCK_RESPONSES = { + "simple_math": [ + MockLLMResponse( + text="4", + has_code=True, + code="result = 2 + 2\nFINAL_VAR('result')", + is_final=False, + ), + ], + "multi_step": [ + MockLLMResponse( + text="First, let me break this down.", + has_code=True, + code="step1 = 10 * 2\nprint(f'Step 1: {step1}')", + ), + MockLLMResponse( + text="Now let me finish.", + has_code=True, + code="final_result = step1 + 5\nFINAL_VAR('final_result')", + ), + ], + "no_code": [ + MockLLMResponse( + text="The capital of France is Paris.", + is_final=True, + ), + ], + "error_recovery": [ + MockLLMResponse( + text="Let me try.", + has_code=True, + code="result = undefined_variable", # Will error + ), + MockLLMResponse( + text="Let me fix that.", + has_code=True, + code="result = 42\nFINAL_VAR('result')", + ), + ], +} + + +def create_mock_llm_client(): + """Create a mock Gemini client for testing.""" + from unittest.mock import AsyncMock + from unittest.mock import MagicMock + + mock_client = MagicMock() + mock_aio = MagicMock() + mock_models = MagicMock() + + response_queue = [] + + async def mock_generate_content(*args, **kwargs): + """Mock generate_content that returns queued responses.""" + if response_queue: + response = response_queue.pop(0) + else: + # Default response if queue is empty + response = MockLLMResponse("FINAL(No response configured)", is_final=True) + + mock_response = MagicMock() + mock_response.text = response.to_response_text() + return mock_response + + mock_models.generate_content = AsyncMock(side_effect=mock_generate_content) + mock_aio.models = mock_models + mock_client.aio = mock_aio + + # Attach response queue for test manipulation + mock_client._response_queue = response_queue + + return mock_client + + +@pytest.fixture +def mock_llm(): + """ + Fixture that provides a mock LLM client. + + Usage: + def test_example(mock_llm, e2e_page): + mock_llm.queue_responses(MOCK_RESPONSES["simple_math"]) + # ... 
run test + """ + + class MockLLMManager: + + def __init__(self): + self.responses: list[MockLLMResponse] = [] + + def queue_response(self, response: MockLLMResponse): + """Queue a single response.""" + self.responses.append(response) + + def queue_responses(self, responses: list[MockLLMResponse]): + """Queue multiple responses.""" + self.responses.extend(responses) + + def clear(self): + """Clear all queued responses.""" + self.responses.clear() + + return MockLLMManager() + + +# Helper functions for E2E tests + + +def wait_for_query_complete(page: Page, timeout: int = 30000): + """Wait for a query to complete (processing indicator hidden).""" + # First wait for processing to start (not hidden) + try: + page.wait_for_function( + "() =>" + " !document.querySelector('#processing')?.classList.contains('hidden')", + timeout=5000, + ) + except Exception: + # Processing might have already started and finished + pass + + # Then wait for processing to finish (hidden again) + page.wait_for_function( + "() =>" + " document.querySelector('#processing')?.classList.contains('hidden')", + timeout=timeout, + ) + + +def wait_for_answer(page: Page, timeout: int = 30000): + """Wait for an answer panel to appear.""" + page.wait_for_selector(".answer-panel", timeout=timeout) + + +def get_answer_text(page: Page) -> str: + """Get the text from the last answer panel.""" + answer = page.locator(".answer-text").last + return answer.text_content() or "" + + +def get_event_count(page: Page) -> int: + """Get the current event count from the UI.""" + count_text = page.locator("#event-count").text_content() or "0 events" + return int(count_text.split()[0]) + + +def submit_query(page: Page, query: str): + """Submit a query via the input field.""" + prompt_input = page.locator("#prompt-input") + prompt_input.fill(query) + prompt_input.press("Enter") + + +def get_message_count(page: Page) -> int: + """Get the number of messages in the conversation.""" + return page.locator(".message").count() + + +def get_session_count(page: Page) -> int: + """Get the number of sessions in the sidebar.""" + return page.locator(".session-item").count() diff --git a/contributing/samples/rlm/tests/e2e/test_multi_turn.py b/contributing/samples/rlm/tests/e2e/test_multi_turn.py new file mode 100644 index 0000000000..ae692a2be2 --- /dev/null +++ b/contributing/samples/rlm/tests/e2e/test_multi_turn.py @@ -0,0 +1,283 @@ +""" +E2E tests for multi-turn conversations and file handling. + +These tests verify conversation context, file loading, and complex +interaction patterns with a real server. +""" + +import os +from pathlib import Path +import re +import tempfile + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import get_answer_text +from .conftest import get_message_count +from .conftest import submit_query +from .conftest import wait_for_query_complete + +pytestmark = [ + pytest.mark.e2e_web, + pytest.mark.skipif( + os.environ.get("RLM_E2E_TESTS") != "true", + reason="E2E tests disabled. 
Set RLM_E2E_TESTS=true to enable.", + ), +] + + +class TestMultiTurnConversation: + """Tests for multi-turn conversation handling.""" + + def test_conversation_history_grows(self, e2e_page: Page): + """Test that conversation history grows with each turn.""" + # First turn + submit_query(e2e_page, "Hello, my name is Alice") + wait_for_query_complete(e2e_page, timeout=60000) + + first_count = get_message_count(e2e_page) + assert first_count >= 2, "Should have user message and response" + + # Second turn + submit_query(e2e_page, "What is my name?") + wait_for_query_complete(e2e_page, timeout=60000) + + second_count = get_message_count(e2e_page) + assert second_count >= 4, f"Should have 4+ messages, got {second_count}" + + # Third turn + submit_query(e2e_page, "Thanks for remembering!") + wait_for_query_complete(e2e_page, timeout=60000) + + third_count = get_message_count(e2e_page) + assert third_count >= 6, f"Should have 6+ messages, got {third_count}" + + def test_context_maintained_across_turns(self, e2e_page: Page): + """Test that context is maintained across conversation turns.""" + # Establish context + submit_query(e2e_page, "Let x = 100") + wait_for_query_complete(e2e_page, timeout=60000) + + # Reference context + submit_query(e2e_page, "What is x + 50?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Answer should reference the context + answer = get_answer_text(e2e_page) + # Should contain 150 or reference to x + assert ( + "150" in answer or "x" in answer.lower() + ), f"Unexpected answer: {answer}" + + def test_follow_up_questions(self, e2e_page: Page): + """Test asking follow-up questions about previous answers.""" + # Initial question + submit_query(e2e_page, "What is the capital of France?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Follow-up + submit_query(e2e_page, "What is its population?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should understand "its" refers to Paris + answer = get_answer_text(e2e_page) + # Should mention numbers (population) or Paris + assert ( + any(c.isdigit() for c in answer) or "paris" in answer.lower() + ), f"Expected population info, got: {answer}" + + +class TestCodeExecutionAcrossTurns: + """Tests for code execution across multiple turns.""" + + def test_variables_persist_across_turns(self, e2e_page: Page): + """Test that REPL variables persist across conversation turns.""" + # Define a variable + submit_query(e2e_page, "Calculate result = 2 ** 10 and show me the value") + wait_for_query_complete(e2e_page, timeout=60000) + + first_answer = get_answer_text(e2e_page) + assert "1024" in first_answer, f"Expected 1024 in answer: {first_answer}" + + # Use the variable in next turn + submit_query(e2e_page, "Now divide result by 2") + wait_for_query_complete(e2e_page, timeout=60000) + + second_answer = get_answer_text(e2e_page) + assert "512" in second_answer, f"Expected 512 in answer: {second_answer}" + + def test_function_definition_persists(self, e2e_page: Page): + """Test that function definitions persist across turns.""" + # Define a function + submit_query( + e2e_page, + "Define a function called double that returns its argument times 2", + ) + wait_for_query_complete(e2e_page, timeout=60000) + + # Use the function + submit_query(e2e_page, "Use the double function on 21") + wait_for_query_complete(e2e_page, timeout=60000) + + answer = get_answer_text(e2e_page) + assert "42" in answer, f"Expected 42 in answer: {answer}" + + +class TestFileHandling: + """Tests for file handling functionality.""" + + 
@pytest.fixture + def test_files(self, tmp_path: Path) -> dict: + """Create temporary test files.""" + # Create a text file + txt_file = tmp_path / "test_data.txt" + txt_file.write_text("This is test content.\nLine 2.\nLine 3.") + + # Create a markdown file + md_file = tmp_path / "readme.md" + md_file.write_text("# Test Document\n\nThis is a test markdown file.") + + # Create a Python file + py_file = tmp_path / "sample.py" + py_file.write_text("def hello():\n return 'Hello, World!'\n") + + return { + "txt": str(txt_file), + "md": str(md_file), + "py": str(py_file), + "dir": str(tmp_path), + } + + def test_add_files_via_settings(self, e2e_page: Page, test_files: dict): + """Test adding files via the settings modal.""" + # Open settings + e2e_page.locator("#config-btn").click() + + # Add file pattern + files_input = e2e_page.locator("#config-files") + files_input.fill(f"{test_files['dir']}/*.txt") + + # Save + e2e_page.locator("#config-form button[type='submit']").click() + + e2e_page.wait_for_timeout(500) + + # Files section should be visible + files_section = e2e_page.locator("#files-section") + expect(files_section).not_to_have_class(re.compile(r"hidden")) + + def test_query_with_file_context(self, e2e_page: Page, test_files: dict): + """Test querying with file context.""" + # Open settings and add files + e2e_page.locator("#config-btn").click() + e2e_page.locator("#config-files").fill(f"{test_files['dir']}/*.txt") + e2e_page.locator("#config-form button[type='submit']").click() + + e2e_page.wait_for_timeout(500) + + # Query about the file content + submit_query(e2e_page, "How many lines are in the text file?") + wait_for_query_complete(e2e_page, timeout=60000) + + answer = get_answer_text(e2e_page) + # Should mention 3 lines + assert "3" in answer, f"Expected '3' in answer about lines: {answer}" + + def test_invalid_file_pattern_shows_error(self, e2e_page: Page): + """Test that invalid file patterns show an error.""" + # Open settings and add non-existent pattern + e2e_page.locator("#config-btn").click() + e2e_page.locator("#config-files").fill("/nonexistent/path/*.xyz") + e2e_page.locator("#config-form button[type='submit']").click() + + e2e_page.wait_for_timeout(500) + + # Files section should remain hidden (no files matched) + files_section = e2e_page.locator("#files-section") + expect(files_section).to_have_class(re.compile(r"hidden")) + + +class TestComplexInteractions: + """Tests for complex interaction patterns.""" + + def test_rapid_queries(self, e2e_page: Page): + """Test that the system handles queries correctly (one at a time).""" + # Submit first query + submit_query(e2e_page, "What is 1+1?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Submit second query immediately after + submit_query(e2e_page, "What is 2+2?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should have 4 messages (2 user + 2 assistant) + assert get_message_count(e2e_page) >= 4 + + def test_long_conversation(self, e2e_page: Page): + """Test a longer conversation with multiple turns.""" + queries = [ + "Let's count. 
Start with 1.", + "Add 1 to get the next number.", + "Add 1 again.", + "What number are we at now?", + ] + + for query in queries: + submit_query(e2e_page, query) + wait_for_query_complete(e2e_page, timeout=60000) + + # Should have 8 messages (4 user + 4 assistant) + message_count = get_message_count(e2e_page) + assert message_count >= 8, f"Expected 8+ messages, got {message_count}" + + # Final answer should mention 4 (or 3, depending on interpretation) + answer = get_answer_text(e2e_page) + assert any( + n in answer for n in ["3", "4"] + ), f"Expected 3 or 4 in answer: {answer}" + + +class TestEdgeCases: + """Tests for edge cases and boundary conditions.""" + + def test_empty_response_handling(self, e2e_page: Page): + """Test handling of queries that might produce empty responses.""" + submit_query(e2e_page, "Just say 'OK' and nothing else") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should still have an answer + answer_panels = e2e_page.locator(".answer-panel") + expect(answer_panels).to_have_count(1) + + def test_special_characters_in_query(self, e2e_page: Page): + """Test queries with special characters.""" + submit_query(e2e_page, "What is 'hello' + \" world\" in Python?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should handle special characters + answer = get_answer_text(e2e_page) + assert "hello" in answer.lower() or "world" in answer.lower() + + def test_unicode_in_query(self, e2e_page: Page): + """Test queries with unicode characters.""" + submit_query(e2e_page, "Print the emoji: \U0001F600") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should complete without error + answer_panels = e2e_page.locator(".answer-panel") + expect(answer_panels).to_have_count(1) + + def test_very_long_query(self, e2e_page: Page): + """Test handling of a very long query.""" + long_query = "Calculate the sum of: " + ", ".join( + str(i) for i in range(1, 51) + ) + submit_query(e2e_page, long_query) + wait_for_query_complete(e2e_page, timeout=90000) + + # Should complete and have an answer + answer = get_answer_text(e2e_page) + # Sum of 1 to 50 is 1275 + assert "1275" in answer, f"Expected 1275 in answer: {answer}" diff --git a/contributing/samples/rlm/tests/e2e/test_query_flow.py b/contributing/samples/rlm/tests/e2e/test_query_flow.py new file mode 100644 index 0000000000..413158ddfc --- /dev/null +++ b/contributing/samples/rlm/tests/e2e/test_query_flow.py @@ -0,0 +1,268 @@ +""" +E2E tests for the full query flow. + +These tests run against a real server with real WebSocket connections +and real session persistence. The LLM calls go to the real API. + +Note: These tests require API access and may be slow. They are marked +with @pytest.mark.e2e_web and can be skipped with -m "not e2e_web". +""" + +import os +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import get_answer_text +from .conftest import get_event_count +from .conftest import get_message_count +from .conftest import submit_query +from .conftest import wait_for_answer +from .conftest import wait_for_query_complete + +pytestmark = [ + pytest.mark.e2e_web, + pytest.mark.skipif( + os.environ.get("RLM_E2E_TESTS") != "true", + reason="E2E tests disabled. Set RLM_E2E_TESTS=true to enable.", + ), +] + + +class TestBasicQuery: + """Tests for basic query submission and response.""" + + def test_simple_query_flow(self, e2e_page: Page): + """ + Test a simple query that should complete successfully. 
+ + This test submits a query and verifies: + - User message appears + - Processing indicator shows + - Events are generated + - Answer is displayed + """ + # Submit a simple query + submit_query(e2e_page, "What is 2 + 2? Just give me the number.") + + # Verify user message appeared + user_message = e2e_page.locator(".message.user") + expect(user_message).to_be_visible() + expect(user_message).to_contain_text("What is 2 + 2") + + # Wait for completion (with generous timeout for LLM) + wait_for_query_complete(e2e_page, timeout=60000) + + # Verify we got some kind of response (answer or error) + # Either answer-panel or an error message + response = e2e_page.locator( + ".answer-panel, .message.assistant .message-content" + ) + expect(response.first).to_be_visible(timeout=10000) + + # Verify events were generated + event_count = get_event_count(e2e_page) + assert event_count > 0, "Should have generated some events" + + def test_query_generates_events(self, e2e_page: Page): + """Test that queries generate streaming events in the event log.""" + submit_query(e2e_page, "Calculate 10 * 5") + + # Wait for some events to appear + e2e_page.wait_for_function( + "() => parseInt(document.querySelector('#event-count')?.textContent ||" + " '0') > 0", + timeout=30000, + ) + + # Verify event items are visible + event_log = e2e_page.locator("#event-log-content") + expect(event_log).not_to_contain_text("Events will appear here") + + # Should have agent/iteration groups + wait_for_query_complete(e2e_page, timeout=60000) + event_count = get_event_count(e2e_page) + assert event_count >= 3, f"Expected at least 3 events, got {event_count}" + + def test_processing_state_during_query(self, e2e_page: Page): + """Test that processing indicator shows during query execution.""" + prompt_input = e2e_page.locator("#prompt-input") + prompt_input.fill("What is the square root of 144?") + + # Submit and immediately check processing state + prompt_input.press("Enter") + + # Processing should be visible + processing = e2e_page.locator("#processing") + expect(processing).not_to_have_class(re.compile(r"hidden")) + + # Send button should be disabled + send_btn = e2e_page.locator("#send-btn") + expect(send_btn).to_be_disabled() + + # Wait for completion + wait_for_query_complete(e2e_page, timeout=60000) + + # Processing should be hidden + expect(processing).to_have_class(re.compile(r"hidden")) + + # Send button should be enabled + expect(send_btn).to_be_enabled() + + +class TestMultipleQueries: + """Tests for multiple sequential queries.""" + + def test_two_queries_in_sequence(self, e2e_page: Page): + """Test submitting two queries in sequence.""" + # First query + submit_query(e2e_page, "What is 5 + 5?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Verify first answer + expect(e2e_page.locator(".answer-panel")).to_have_count(1) + + # Second query + submit_query(e2e_page, "What is 10 + 10?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should have two user messages and two answers + user_messages = e2e_page.locator(".message.user") + expect(user_messages).to_have_count(2) + + answer_panels = e2e_page.locator(".answer-panel") + expect(answer_panels).to_have_count(2) + + def test_conversation_context_maintained(self, e2e_page: Page): + """Test that conversation context is maintained across queries.""" + # First query establishes context + submit_query(e2e_page, "Remember this number: 42") + wait_for_query_complete(e2e_page, timeout=60000) + + # Second query references context + submit_query(e2e_page, "What 
number did I ask you to remember?") + wait_for_query_complete(e2e_page, timeout=60000) + + # The answer should reference 42 + answer_text = get_answer_text(e2e_page) + assert "42" in answer_text, f"Expected '42' in answer, got: {answer_text}" + + +class TestEventLogDuringQuery: + """Tests for event log behavior during query execution.""" + + def test_events_stream_in_realtime(self, e2e_page: Page): + """Test that events stream in real-time during query execution.""" + submit_query(e2e_page, "Count from 1 to 5") + + # Wait for at least one event + e2e_page.wait_for_function( + "() => parseInt(document.querySelector('#event-count')?.textContent ||" + " '0') > 0", + timeout=30000, + ) + + first_count = get_event_count(e2e_page) + assert first_count > 0, "Should have at least one event" + + # Wait for more events + e2e_page.wait_for_timeout(1000) + + # If query is still running, count should increase + # (or be the same if query completed quickly) + wait_for_query_complete(e2e_page, timeout=60000) + + final_count = get_event_count(e2e_page) + assert final_count >= first_count, "Event count should not decrease" + + def test_event_details_viewable(self, e2e_page: Page): + """Test that event details can be viewed by clicking events.""" + submit_query(e2e_page, "What is 3 * 3?") + wait_for_query_complete(e2e_page, timeout=60000) + + # Click on first event + event_item = e2e_page.locator(".event-item").first + event_item.click() + + # Modal should open + event_modal = e2e_page.locator("#event-modal") + expect(event_modal).not_to_have_class(re.compile(r"hidden")) + + # Modal should have content + modal_body = e2e_page.locator("#modal-body") + expect(modal_body).to_contain_text("Event Type") + + +class TestErrorHandling: + """Tests for error handling during queries.""" + + def test_query_with_invalid_code_recovers(self, e2e_page: Page): + """Test that the system can recover from code execution errors.""" + # This query might generate code that errors, but should still complete + submit_query( + e2e_page, + "Try to calculate something that might fail, then recover and give" + " me 42", + ) + + # Should eventually complete (possibly with error recovery) + wait_for_query_complete(e2e_page, timeout=90000) + + # Should have an answer (even if it's about the error) + answer_panels = e2e_page.locator(".answer-panel") + # Either we get an answer or an error message + messages = e2e_page.locator(".message.assistant") + assert messages.count() > 0, "Should have some response" + + +class TestUIStateAfterQuery: + """Tests for UI state after query completion.""" + + def test_input_cleared_after_submit(self, e2e_page: Page): + """Test that input is cleared after submitting query.""" + prompt_input = e2e_page.locator("#prompt-input") + prompt_input.fill("Test query") + prompt_input.press("Enter") + + # Input should be cleared immediately + expect(prompt_input).to_have_value("") + + def test_can_submit_new_query_after_completion(self, e2e_page: Page): + """Test that new queries can be submitted after completion.""" + # First query + submit_query(e2e_page, "Say hello") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should be able to type and submit again + prompt_input = e2e_page.locator("#prompt-input") + prompt_input.fill("Say goodbye") + + send_btn = e2e_page.locator("#send-btn") + expect(send_btn).to_be_enabled() + + send_btn.click() + + # Should start processing + processing = e2e_page.locator("#processing") + expect(processing).not_to_have_class(re.compile(r"hidden")) + + 
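# Let the second query finish too, so the page is back in an idle state when the test ends. +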
wait_for_query_complete(e2e_page, timeout=60000) + + def test_session_title_updates_from_first_message(self, e2e_page: Page): + """Test that session title is auto-generated from first message.""" + # Get initial title + session_title = e2e_page.locator("#session-title") + initial_title = session_title.text_content() + + # Submit first query + submit_query(e2e_page, "This is my test question about Python programming") + wait_for_query_complete(e2e_page, timeout=60000) + + # Title should be updated to reflect the query + e2e_page.wait_for_timeout(500) # Allow UI to update + new_title = session_title.text_content() + + # Title should have changed and contain part of the query + assert new_title != initial_title or "Python" in (new_title or "") diff --git a/contributing/samples/rlm/tests/e2e/test_session_lifecycle.py b/contributing/samples/rlm/tests/e2e/test_session_lifecycle.py new file mode 100644 index 0000000000..3ea726bcbe --- /dev/null +++ b/contributing/samples/rlm/tests/e2e/test_session_lifecycle.py @@ -0,0 +1,318 @@ +""" +E2E tests for session lifecycle management. + +These tests verify session creation, persistence, loading, and deletion +with a real server and database. +""" + +import os +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import get_message_count +from .conftest import get_session_count +from .conftest import submit_query +from .conftest import wait_for_query_complete + +pytestmark = [ + pytest.mark.e2e_web, + pytest.mark.skipif( + os.environ.get("RLM_E2E_TESTS") != "true", + reason="E2E tests disabled. Set RLM_E2E_TESTS=true to enable.", + ), +] + + +class TestSessionCreation: + """Tests for creating new sessions.""" + + def test_initial_session_created(self, e2e_page: Page): + """Test that an initial session is created on page load.""" + # Should have a session ID in the sidebar + session_title = e2e_page.locator("#session-title") + expect(session_title).not_to_be_empty() + + # Status should be connected + status_badge = e2e_page.locator("#status-badge") + expect(status_badge).to_have_text("Connected") + + def test_create_new_session(self, e2e_page: Page): + """Test creating a new session via the + button.""" + # Get initial session count + initial_count = get_session_count(e2e_page) + + # Click new session button + new_session_btn = e2e_page.locator("#new-session-btn") + new_session_btn.click() + + # Wait for session list to update + e2e_page.wait_for_timeout(500) + + # Session count should increase + new_count = get_session_count(e2e_page) + assert new_count >= initial_count, "Session count should not decrease" + + # UI should be cleared (empty state) + empty_state = e2e_page.locator("#empty-state") + expect(empty_state).to_be_visible() + + def test_new_session_has_default_title(self, e2e_page: Page): + """Test that new sessions have a default title with timestamp.""" + # Create new session + e2e_page.locator("#new-session-btn").click() + e2e_page.wait_for_timeout(500) + + # Title should contain "Session" and a date-like pattern + session_title = e2e_page.locator("#session-title") + title_text = session_title.text_content() or "" + + assert ( + "Session" in title_text or "-" in title_text + ), f"Unexpected title: {title_text}" + + +class TestSessionPersistence: + """Tests for session persistence across page reloads.""" + + def test_conversation_persists_on_reload( + self, e2e_page: Page, e2e_server: str + ): + """Test that conversation persists after page reload.""" + # Submit a query + 
submit_query(e2e_page, "Remember this: the magic word is abracadabra") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should have messages + initial_message_count = get_message_count(e2e_page) + assert initial_message_count >= 2, "Should have user message and answer" + + # Reload page + e2e_page.reload() + + # Wait for WebSocket connection + e2e_page.wait_for_function( + "() => document.querySelector('#status-badge')?.textContent ===" + " 'Connected'", + timeout=10000, + ) + + # Wait for conversation to restore + e2e_page.wait_for_timeout(1000) + + # Messages should be restored + restored_count = get_message_count(e2e_page) + assert restored_count >= initial_message_count, ( + f"Expected at least {initial_message_count} messages, got" + f" {restored_count}" + ) + + def test_session_title_persists(self, e2e_page: Page, e2e_server: str): + """Test that session title persists after page reload.""" + # Submit a query to trigger title generation + submit_query(e2e_page, "This is a test query about persistence") + wait_for_query_complete(e2e_page, timeout=60000) + + # Get the title + session_title = e2e_page.locator("#session-title") + original_title = session_title.text_content() + + # Reload page + e2e_page.reload() + + e2e_page.wait_for_function( + "() => document.querySelector('#status-badge')?.textContent ===" + " 'Connected'", + timeout=10000, + ) + e2e_page.wait_for_timeout(1000) + + # Title should be restored + restored_title = session_title.text_content() + assert restored_title == original_title, ( + f"Title not restored. Expected '{original_title}', got" + f" '{restored_title}'" + ) + + +class TestSessionSwitching: + """Tests for switching between sessions.""" + + def test_switch_to_different_session(self, e2e_page: Page): + """Test switching to a different session.""" + # Create first session with content + submit_query(e2e_page, "First session content") + wait_for_query_complete(e2e_page, timeout=60000) + + first_message_count = get_message_count(e2e_page) + + # Create new session + e2e_page.locator("#new-session-btn").click() + e2e_page.wait_for_timeout(500) + + # New session should be empty + empty_state = e2e_page.locator("#empty-state") + expect(empty_state).to_be_visible() + + # Add content to second session + submit_query(e2e_page, "Second session content") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should have messages in second session + second_message_count = get_message_count(e2e_page) + assert second_message_count >= 2 + + # Switch back to first session (should be first in list, or second) + session_items = e2e_page.locator(".session-item") + if session_items.count() >= 2: + # Click on a different session + first_item = session_items.first + first_item.click() + + e2e_page.wait_for_timeout(1000) + + # Should load that session's content + # (messages may differ based on which session we loaded) + + +class TestSessionDeletion: + """Tests for deleting sessions.""" + + def test_delete_session(self, e2e_page: Page): + """Test deleting a session.""" + # Create a session with content + submit_query(e2e_page, "Session to delete") + wait_for_query_complete(e2e_page, timeout=60000) + + # Create another session + e2e_page.locator("#new-session-btn").click() + e2e_page.wait_for_timeout(500) + + initial_session_count = get_session_count(e2e_page) + assert initial_session_count >= 2, "Should have at least 2 sessions" + + # Set up dialog handler to accept confirmation + e2e_page.on("dialog", lambda dialog: dialog.accept()) + + # Delete the first session + 
session_item = e2e_page.locator(".session-item").first + delete_btn = session_item.locator(".session-item-delete") + session_item.hover() + delete_btn.click() + + e2e_page.wait_for_timeout(500) + + # Session count should decrease + new_count = get_session_count(e2e_page) + assert new_count < initial_session_count, "Session should be deleted" + + +class TestClearSession: + """Tests for clearing session content.""" + + def test_clear_session_removes_messages(self, e2e_page: Page): + """Test that clearing a session removes messages.""" + # Add content + submit_query(e2e_page, "Content to clear") + wait_for_query_complete(e2e_page, timeout=60000) + + assert get_message_count(e2e_page) >= 2, "Should have messages" + + # Set up dialog handler to accept confirmation + e2e_page.on("dialog", lambda dialog: dialog.accept()) + + # Open settings and clear + e2e_page.locator("#config-btn").click() + e2e_page.locator("#config-clear").click() + + e2e_page.wait_for_timeout(500) + + # Messages should be cleared + empty_state = e2e_page.locator("#empty-state") + expect(empty_state).to_be_visible() + + def test_clear_session_removes_events(self, e2e_page: Page): + """Test that clearing a session removes events.""" + # Add content + submit_query(e2e_page, "Generate some events") + wait_for_query_complete(e2e_page, timeout=60000) + + # Should have events + event_count = e2e_page.locator("#event-count") + expect(event_count).not_to_have_text("0 events") + + # Set up dialog handler to accept confirmation + e2e_page.on("dialog", lambda dialog: dialog.accept()) + + # Open settings and clear + e2e_page.locator("#config-btn").click() + e2e_page.locator("#config-clear").click() + + e2e_page.wait_for_timeout(500) + + # Events should be cleared + expect(event_count).to_have_text("0 events") + + +class TestSessionConfiguration: + """Tests for session configuration changes.""" + + def test_change_model_setting(self, e2e_page: Page): + """Test changing the model setting.""" + # Open settings + e2e_page.locator("#config-btn").click() + + # Change model + model_input = e2e_page.locator("#config-model") + model_input.fill("gemini-3-flash-preview") + + # Save + e2e_page.locator("#config-form button[type='submit']").click() + + e2e_page.wait_for_timeout(500) + + # Modal should close + config_modal = e2e_page.locator("#config-modal") + expect(config_modal).to_have_class(re.compile(r"hidden")) + + # Reopen settings to verify + e2e_page.locator("#config-btn").click() + expect(model_input).to_have_value("gemini-3-flash-preview") + + def test_change_max_iterations(self, e2e_page: Page): + """Test changing the max iterations setting.""" + # Open settings + e2e_page.locator("#config-btn").click() + + # Change max iterations + iterations_input = e2e_page.locator("#config-iterations") + iterations_input.fill("10") + + # Save + e2e_page.locator("#config-form button[type='submit']").click() + + e2e_page.wait_for_timeout(500) + + # Reopen and verify + e2e_page.locator("#config-btn").click() + expect(iterations_input).to_have_value("10") + + def test_change_session_title(self, e2e_page: Page): + """Test changing the session title.""" + # Open settings + e2e_page.locator("#config-btn").click() + + # Change title + title_input = e2e_page.locator("#config-title") + title_input.fill("My Custom Title") + + # Save + e2e_page.locator("#config-form button[type='submit']").click() + + e2e_page.wait_for_timeout(500) + + # Title should be updated in header + session_title = e2e_page.locator("#session-title") + expect(session_title).to_have_text("My 
Custom Title") diff --git a/contributing/samples/rlm/tests/fixtures/contexts/medium.txt b/contributing/samples/rlm/tests/fixtures/contexts/medium.txt new file mode 100644 index 0000000000..4c73c13893 --- /dev/null +++ b/contributing/samples/rlm/tests/fixtures/contexts/medium.txt @@ -0,0 +1,76 @@ +Chapter 1: Introduction to Machine Learning + +Machine learning is a subset of artificial intelligence that enables computers +to learn from data without being explicitly programmed. The field has grown +tremendously over the past decade, with applications in various domains. + +Section 1.1: Types of Machine Learning + +There are three main types of machine learning: + +1. Supervised Learning: The algorithm learns from labeled training data, + making predictions based on that data. Examples include: + - Classification: Predicting categories (spam vs. not spam) + - Regression: Predicting continuous values (house prices) + +2. Unsupervised Learning: The algorithm learns patterns from unlabeled data. + Examples include: + - Clustering: Grouping similar data points + - Dimensionality Reduction: Reducing the number of features + +3. Reinforcement Learning: The algorithm learns by interacting with an + environment and receiving rewards or penalties. + +Section 1.2: Key Concepts + +- Features: Input variables used for predictions +- Labels: Target variables in supervised learning +- Training Data: Data used to train the model +- Test Data: Data used to evaluate model performance +- Overfitting: When a model performs well on training data but poorly on new data +- Underfitting: When a model is too simple to capture underlying patterns + +Section 1.3: Popular Algorithms + +- Linear Regression +- Logistic Regression +- Decision Trees +- Random Forests +- Support Vector Machines +- Neural Networks +- K-Nearest Neighbors +- K-Means Clustering + +Chapter 2: Deep Learning + +Deep learning is a subset of machine learning that uses neural networks with +multiple layers. These networks can automatically learn hierarchical +representations of data. + +Section 2.1: Neural Network Basics + +A neural network consists of: +- Input layer: Receives the raw data +- Hidden layers: Perform transformations +- Output layer: Produces the final prediction + +Each layer contains neurons (nodes) connected by weighted edges. The network +learns by adjusting these weights during training. + +Section 2.2: Types of Neural Networks + +1. Feedforward Neural Networks (FNN) +2. Convolutional Neural Networks (CNN) - for image processing +3. Recurrent Neural Networks (RNN) - for sequential data +4. Transformer Networks - for natural language processing + +Section 2.3: Training Neural Networks + +- Forward Propagation: Computing outputs from inputs +- Backpropagation: Computing gradients for weight updates +- Gradient Descent: Optimizing weights to minimize loss +- Learning Rate: Controls the size of weight updates +- Batch Size: Number of samples per weight update +- Epochs: Complete passes through the training data + +The magic number in this document is 42. diff --git a/contributing/samples/rlm/tests/fixtures/contexts/short.txt b/contributing/samples/rlm/tests/fixtures/contexts/short.txt new file mode 100644 index 0000000000..19e9ced9b6 --- /dev/null +++ b/contributing/samples/rlm/tests/fixtures/contexts/short.txt @@ -0,0 +1,3 @@ +The quick brown fox jumps over the lazy dog. +This sentence contains every letter of the English alphabet. +The magic number is 42. 
diff --git a/contributing/samples/rlm/tests/test_context_passing.py b/contributing/samples/rlm/tests/test_context_passing.py new file mode 100644 index 0000000000..43c423a31e --- /dev/null +++ b/contributing/samples/rlm/tests/test_context_passing.py @@ -0,0 +1,462 @@ +""" +Tests for context passing to child agents. + +This module tests the functionality of passing context objects (including +LazyFile and LazyFileCollection) to child agents via llm_query and llm_query_batched. +""" + +from pathlib import Path +from typing import Any + +from adk_rlm.files.lazy import LazyFile +from adk_rlm.files.lazy import LazyFileCollection +from adk_rlm.files.loader import FileLoader +from adk_rlm.repl.local_repl import LocalREPL +from adk_rlm.types import QueryMetadata +import pytest + + +class TestQueryMetadataWithLazyFiles: + """Tests for QueryMetadata handling of LazyFile types.""" + + @pytest.fixture + def sample_lazy_file(self, tmp_path): + """Create a sample LazyFile for testing.""" + subdir = tmp_path / "lazy_file_test" + subdir.mkdir() + test_file = subdir / "test.txt" + test_file.write_text("This is test content with some text.") + loader = FileLoader() + return loader.create_lazy_file(str(test_file)) + + @pytest.fixture + def sample_lazy_collection(self, tmp_path): + """Create a sample LazyFileCollection for testing.""" + subdir = tmp_path / "lazy_collection_test" + subdir.mkdir() + file1 = subdir / "file1.txt" + file2 = subdir / "file2.txt" + file3 = subdir / "file3.txt" + file1.write_text("Content of file 1") + file2.write_text("Content of file 2 with more text") + file3.write_text("Short") + loader = FileLoader() + # Use explicit file list instead of glob to avoid cwd issues + return loader.create_lazy_files([str(file1), str(file2), str(file3)]) + + def test_query_metadata_with_string(self): + """QueryMetadata handles string context.""" + metadata = QueryMetadata("This is a simple string context") + assert metadata.context_type == "str" + assert len(metadata.context_lengths) == 1 + assert metadata.context_total_length == 31 + + def test_query_metadata_with_dict(self): + """QueryMetadata handles dict context.""" + metadata = QueryMetadata({"key1": "value1", "key2": "value2"}) + assert metadata.context_type == "dict" + assert len(metadata.context_lengths) == 2 + + def test_query_metadata_with_list(self): + """QueryMetadata handles list context.""" + metadata = QueryMetadata(["item1", "item2", "item3"]) + assert metadata.context_type == "list" + assert len(metadata.context_lengths) == 3 + + def test_query_metadata_with_lazy_file(self, sample_lazy_file): + """QueryMetadata handles LazyFile context.""" + metadata = QueryMetadata(sample_lazy_file) + assert metadata.context_type == "lazy_file" + assert len(metadata.context_lengths) == 1 + # size_bytes may be 0 before file is loaded (lazy loading) + assert metadata.context_total_length >= 0 + + def test_query_metadata_with_lazy_collection(self, sample_lazy_collection): + """QueryMetadata handles LazyFileCollection context.""" + metadata = QueryMetadata(sample_lazy_collection) + assert metadata.context_type == "lazy_file_collection" + assert len(metadata.context_lengths) == 3 + # size_bytes may be 0 before files are loaded (lazy loading) + assert metadata.context_total_length >= 0 + + def test_query_metadata_with_empty_list(self): + """QueryMetadata handles empty list.""" + metadata = QueryMetadata([]) + assert metadata.context_type == "list" + assert metadata.context_lengths == [0] + + +class TestREPLContextLoading: + """Tests for REPL context 
loading with various types.""" + + @pytest.fixture + def sample_lazy_file(self, tmp_path): + """Create a sample LazyFile for testing.""" + subdir = tmp_path / "repl_lazy_file_test" + subdir.mkdir() + test_file = subdir / "test.txt" + test_file.write_text("This is test content.") + loader = FileLoader() + return loader.create_lazy_file(str(test_file)) + + @pytest.fixture + def sample_lazy_collection(self, tmp_path): + """Create a sample LazyFileCollection for testing.""" + subdir = tmp_path / "repl_lazy_collection_test" + subdir.mkdir() + file_a = subdir / "a.txt" + file_b = subdir / "b.txt" + file_a.write_text("Content A") + file_b.write_text("Content B") + loader = FileLoader() + # Use explicit file list instead of glob to avoid cwd issues + return loader.create_lazy_files([str(file_a), str(file_b)]) + + def test_load_string_context(self): + """REPL loads string context correctly.""" + repl = LocalREPL() + repl.load_context("test string context") + + result = repl.execute_code("print(context)") + assert "test string context" in result.stdout + + def test_load_dict_context(self): + """REPL loads dict context correctly.""" + repl = LocalREPL() + repl.load_context({"key": "value", "number": 42}) + + result = repl.execute_code("print(context['key'])") + assert "value" in result.stdout + + def test_load_lazy_file_context(self, sample_lazy_file): + """REPL loads LazyFile context correctly.""" + repl = LocalREPL() + repl.load_context(sample_lazy_file) + + # Check that context is a LazyFile + result = repl.execute_code("print(type(context).__name__)") + assert "LazyFile" in result.stdout + + # Check that we can access file properties + result = repl.execute_code("print(context.name)") + assert "test.txt" in result.stdout + + def test_load_lazy_collection_context(self, sample_lazy_collection): + """REPL loads LazyFileCollection context correctly.""" + repl = LocalREPL() + repl.load_context(sample_lazy_collection) + + # Check that context is a LazyFileCollection + result = repl.execute_code("print(type(context).__name__)") + assert "LazyFileCollection" in result.stdout + + # Check that we can iterate over files + result = repl.execute_code("print(len(list(context)))") + assert "2" in result.stdout + + +class TestLLMQueryContextPassing: + """Tests for llm_query context parameter.""" + + @pytest.fixture + def tracking_llm_query(self): + """Create an llm_query function that tracks calls.""" + calls = [] + + def _llm_query( + prompt: str, + context: Any = None, + model: str | None = None, + recursive: bool = True, + ) -> str: + calls.append({ + "prompt": prompt, + "context": context, + "model": model, + "recursive": recursive, + }) + return f"Response for: {prompt[:30]}..." + + _llm_query.calls = calls + return _llm_query + + def test_llm_query_without_context(self, tracking_llm_query): + """llm_query works without context.""" + repl = LocalREPL(llm_query_fn=tracking_llm_query) + result = repl.execute_code( + "response = llm_query('What is 2+2?')\nprint(response)" + ) + + assert len(tracking_llm_query.calls) == 1 + assert tracking_llm_query.calls[0]["prompt"] == "What is 2+2?" 
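+ # With no context argument, the REPL helper should forward context=None to the underlying query function.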
+ assert tracking_llm_query.calls[0]["context"] is None + assert "Response for:" in result.stdout + + def test_llm_query_with_string_context(self, tracking_llm_query): + """llm_query passes string context correctly.""" + repl = LocalREPL(llm_query_fn=tracking_llm_query) + result = repl.execute_code( + "response = llm_query('Summarize this', context='Some text to" + " summarize')\nprint(response)" + ) + + assert len(tracking_llm_query.calls) == 1 + assert tracking_llm_query.calls[0]["context"] == "Some text to summarize" + + def test_llm_query_with_dict_context(self, tracking_llm_query): + """llm_query passes dict context correctly.""" + repl = LocalREPL(llm_query_fn=tracking_llm_query) + result = repl.execute_code( + "ctx = {'data': 'important info'}\n" + "response = llm_query('Analyze this', context=ctx)\n" + "print(response)" + ) + + assert len(tracking_llm_query.calls) == 1 + assert tracking_llm_query.calls[0]["context"] == {"data": "important info"} + + def test_llm_query_with_lazy_file_context(self, tracking_llm_query, tmp_path): + """llm_query passes LazyFile context correctly.""" + # Create a test file in isolated subdir + subdir = tmp_path / "llm_query_lazy_test" + subdir.mkdir() + test_file = subdir / "data.txt" + test_file.write_text("File content here") + + loader = FileLoader() + lazy_file = loader.create_lazy_file(str(test_file)) + + repl = LocalREPL(llm_query_fn=tracking_llm_query) + repl.load_context(lazy_file) + + result = repl.execute_code( + "response = llm_query('Summarize the file', context=context)\n" + "print(response)" + ) + + assert len(tracking_llm_query.calls) == 1 + assert isinstance(tracking_llm_query.calls[0]["context"], LazyFile) + + def test_llm_query_with_model_override(self, tracking_llm_query): + """llm_query passes model parameter correctly.""" + repl = LocalREPL(llm_query_fn=tracking_llm_query) + result = repl.execute_code( + "response = llm_query('Test', model='custom-model')\nprint(response)" + ) + + assert tracking_llm_query.calls[0]["model"] == "custom-model" + + def test_llm_query_with_recursive_false(self, tracking_llm_query): + """llm_query passes recursive=False correctly.""" + repl = LocalREPL(llm_query_fn=tracking_llm_query) + result = repl.execute_code( + "response = llm_query('Test', recursive=False)\nprint(response)" + ) + + assert tracking_llm_query.calls[0]["recursive"] is False + + +class TestLLMQueryBatchedContextPassing: + """Tests for llm_query_batched contexts parameter.""" + + @pytest.fixture + def tracking_llm_query_batched(self): + """Create an llm_query_batched function that tracks calls.""" + calls = [] + + def _llm_query_batched( + prompts: list[str], + contexts: list[Any] | None = None, + model: str | None = None, + recursive: bool = False, + ) -> list[str]: + calls.append({ + "prompts": prompts, + "contexts": contexts, + "model": model, + "recursive": recursive, + }) + return [f"Response {i}" for i in range(len(prompts))] + + _llm_query_batched.calls = calls + return _llm_query_batched + + @pytest.fixture + def tracking_llm_query(self): + """Create an llm_query function for fallback.""" + + def _llm_query(prompt, context=None, model=None, recursive=True): + return f"Response for: {prompt[:20]}..." 
+ + return _llm_query + + def test_batched_without_contexts( + self, tracking_llm_query, tracking_llm_query_batched + ): + """llm_query_batched works without contexts.""" + repl = LocalREPL( + llm_query_fn=tracking_llm_query, + llm_query_batched_fn=tracking_llm_query_batched, + ) + result = repl.execute_code( + "prompts = ['Q1', 'Q2', 'Q3']\n" + "responses = llm_query_batched(prompts)\n" + "print(len(responses))" + ) + + assert len(tracking_llm_query_batched.calls) == 1 + assert tracking_llm_query_batched.calls[0]["contexts"] is None + assert "3" in result.stdout + + def test_batched_with_contexts( + self, tracking_llm_query, tracking_llm_query_batched + ): + """llm_query_batched passes contexts correctly.""" + repl = LocalREPL( + llm_query_fn=tracking_llm_query, + llm_query_batched_fn=tracking_llm_query_batched, + ) + result = repl.execute_code( + "prompts = ['Summarize A', 'Summarize B']\n" + "contexts = ['Content A', 'Content B']\n" + "responses = llm_query_batched(prompts, contexts=contexts)\n" + "print(len(responses))" + ) + + assert len(tracking_llm_query_batched.calls) == 1 + assert tracking_llm_query_batched.calls[0]["contexts"] == [ + "Content A", + "Content B", + ] + + def test_batched_with_lazy_files( + self, tracking_llm_query, tracking_llm_query_batched, tmp_path + ): + """llm_query_batched passes LazyFile contexts correctly.""" + # Create test files in isolated subdir + subdir = tmp_path / "batched_lazy_test" + subdir.mkdir() + file1 = subdir / "file1.txt" + file2 = subdir / "file2.txt" + file1.write_text("Content 1") + file2.write_text("Content 2") + + loader = FileLoader() + # Use explicit file list instead of glob to avoid cwd issues + files = loader.create_lazy_files([str(file1), str(file2)]) + file_list = list(files) + + repl = LocalREPL( + llm_query_fn=tracking_llm_query, + llm_query_batched_fn=tracking_llm_query_batched, + ) + repl.locals["files"] = file_list + + result = repl.execute_code( + "prompts = [f'Summarize {f.name}' for f in files]\n" + "responses = llm_query_batched(prompts, contexts=files)\n" + "print(len(responses))" + ) + + assert len(tracking_llm_query_batched.calls) == 1 + contexts = tracking_llm_query_batched.calls[0]["contexts"] + assert len(contexts) == 2 + assert all(isinstance(c, LazyFile) for c in contexts) + + def test_batched_with_recursive_true( + self, tracking_llm_query, tracking_llm_query_batched + ): + """llm_query_batched passes recursive=True correctly.""" + repl = LocalREPL( + llm_query_fn=tracking_llm_query, + llm_query_batched_fn=tracking_llm_query_batched, + ) + result = repl.execute_code( + "responses = llm_query_batched(['Q1', 'Q2'], recursive=True)\n" + "print(len(responses))" + ) + + assert tracking_llm_query_batched.calls[0]["recursive"] is True + + def test_batched_fallback_without_batched_fn(self, tracking_llm_query): + """llm_query_batched falls back to individual calls if no batched fn.""" + calls = [] + + def tracking_query(prompt, context=None, model=None, recursive=True): + calls.append({"prompt": prompt, "context": context}) + return f"Response for: {prompt}" + + repl = LocalREPL(llm_query_fn=tracking_query) + result = repl.execute_code( + "responses = llm_query_batched(['Q1', 'Q2'])\nprint(len(responses))" + ) + + # Should make 2 individual calls + assert len(calls) == 2 + assert "2" in result.stdout + + def test_batched_fallback_with_contexts(self, tracking_llm_query): + """llm_query_batched fallback passes contexts correctly.""" + calls = [] + + def tracking_query(prompt, context=None, model=None, recursive=True): + 
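# Record each individual call so the test can verify contexts are forwarded one-by-one in the fallback path. +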
calls.append({"prompt": prompt, "context": context}) + return f"Response for: {prompt}" + + repl = LocalREPL(llm_query_fn=tracking_query) + result = repl.execute_code( + "responses = llm_query_batched(['Q1', 'Q2'], contexts=['C1', 'C2'])\n" + "print(len(responses))" + ) + + assert len(calls) == 2 + assert calls[0]["context"] == "C1" + assert calls[1]["context"] == "C2" + + +class TestCodeExecutorContextPassing: + """Tests for RLMCodeExecutor context passing to child agents.""" + + @pytest.fixture + def sample_lazy_file(self, tmp_path): + """Create a sample LazyFile for testing.""" + subdir = tmp_path / "executor_lazy_test" + subdir.mkdir() + test_file = subdir / "test.txt" + test_file.write_text("Test file content for child agent.") + loader = FileLoader() + return loader.create_lazy_file(str(test_file)) + + def test_executor_llm_query_signature(self): + """Verify RLMCodeExecutor creates llm_query with correct signature.""" + import inspect + + from adk_rlm.code_executor import RLMCodeExecutor + + executor = RLMCodeExecutor() + llm_query_fn = executor._create_llm_query_fn() + + sig = inspect.signature(llm_query_fn) + params = list(sig.parameters.keys()) + + assert "prompt" in params + assert "context" in params + assert "model" in params + assert "recursive" in params + + def test_executor_llm_query_batched_signature(self): + """Verify RLMCodeExecutor creates llm_query_batched with correct signature.""" + import inspect + + from adk_rlm.code_executor import RLMCodeExecutor + + executor = RLMCodeExecutor() + llm_query_batched_fn = executor._create_llm_query_batched_fn() + + sig = inspect.signature(llm_query_batched_fn) + params = list(sig.parameters.keys()) + + assert "prompts" in params + assert "contexts" in params + assert "model" in params + assert "recursive" in params diff --git a/contributing/samples/rlm/tests/test_e2e.py b/contributing/samples/rlm/tests/test_e2e.py new file mode 100644 index 0000000000..3140e36e93 --- /dev/null +++ b/contributing/samples/rlm/tests/test_e2e.py @@ -0,0 +1,188 @@ +""" +End-to-end tests for ADK-RLM. + +These tests use real Gemini API calls and are slow. +They are skipped by default unless RLM_E2E_TESTS=true is set. 
+""" + +from functools import wraps +import time + +from adk_rlm import completion +from adk_rlm import RLM +from adk_rlm import RLMEventType +import pytest + + +def retry_on_api_error(max_retries: int = 3, delay: float = 5.0): + """Decorator to retry tests on transient API errors.""" + + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + last_error = None + for attempt in range(max_retries): + try: + return func(*args, **kwargs) + except Exception as e: + error_str = str(e).lower() + # Retry on transient errors + if any( + x in error_str + for x in [ + "quota", + "rate", + "cancelled", + "503", + "500", + "overloaded", + ] + ): + last_error = e + if attempt < max_retries - 1: + time.sleep(delay * (attempt + 1)) + continue + raise + raise last_error + + return wrapper + + return decorator + + +async def run_query(rlm: RLM, context: str, prompt: str) -> str: + """Helper to run a query and return the final answer.""" + final_answer = None + async for event in rlm.run_streaming(context, prompt): + if event.custom_metadata: + event_type = event.custom_metadata.get("event_type") + if event_type == RLMEventType.FINAL_ANSWER.value: + final_answer = event.custom_metadata.get("answer") + return final_answer or "" + + +@pytest.mark.e2e +@pytest.mark.timeout(180) +class TestE2EBasicFunctionality: + """Basic E2E tests.""" + + @retry_on_api_error(max_retries=3, delay=5.0) + def test_simple_computation(self): + """Test that RLM can do simple computation via REPL.""" + result = completion( + context="Calculate: 17 * 23", + prompt=( + "Compute 17 * 23 using Python code in the REPL. Return the result" + " with FINAL()." + ), + model="gemini-3-flash-preview", + max_iterations=10, + ) + + assert "391" in result.response + + @retry_on_api_error(max_retries=3, delay=5.0) + def test_context_access(self, sample_context): + """Test that RLM can access and analyze context.""" + result = completion( + context=sample_context, + prompt=( + "Read the context variable and find the magic number. Print it" + " using the REPL, then return it with FINAL()." + ), + model="gemini-3-flash-preview", + max_iterations=10, + ) + + assert "42" in result.response + + @retry_on_api_error(max_retries=3, delay=5.0) + def test_uses_llm_query(self, fixtures_dir): + """Test that RLM uses llm_query for analysis.""" + context_file = fixtures_dir / "contexts" / "medium.txt" + if not context_file.exists(): + pytest.skip("Medium context file not found") + + context = context_file.read_text() + + result = completion( + context=context, + prompt="What are the main topics covered in this document? Be brief.", + model="gemini-3-flash-preview", + sub_model="gemini-3-flash-preview", + max_iterations=15, + ) + + # Check that we got a reasonable response + assert len(result.response) > 50 + + +@pytest.mark.e2e +@pytest.mark.timeout(300) +class TestE2EMultiTurn: + """Multi-turn conversation tests.""" + + async def test_context_accumulation(self): + """Test that contexts accumulate across turns.""" + rlm = RLM( + model="gemini-3-flash-preview", + max_iterations=10, + persistent=True, + verbose=False, + ) + try: + # First turn - explicitly use REPL to read context + result1 = await run_query( + rlm, + context="Alice is 30 years old.", + prompt=( + "Print the context variable and extract the age number. Return it" + " with FINAL()." 
+ ), + ) + assert "30" in result1 + + # Second turn - should have access to first context via context_0 + result2 = await run_query( + rlm, + context="Bob is 25 years old.", + prompt=( + "Print context_0 to get Alice's age, and context_1 to get Bob's" + " age. Who is older? Return just the name with FINAL()." + ), + ) + assert "Alice" in result2 + finally: + rlm.close() + + +@pytest.mark.e2e +@pytest.mark.timeout(180) +class TestE2ELogging: + """Test logging and tracing functionality.""" + + @retry_on_api_error(max_retries=3, delay=5.0) + def test_jsonl_logging(self, temp_log_dir): + """Test that JSONL logs are created correctly.""" + import json + from pathlib import Path + + result = completion( + context="Test context", + prompt="Just say FINAL(ok).", + model="gemini-3-flash-preview", + log_dir=temp_log_dir, + max_iterations=5, + ) + + # Check log file was created + log_files = list(Path(temp_log_dir).glob("*.jsonl")) + assert len(log_files) == 1 + + # Check log contents + with open(log_files[0]) as f: + lines = f.readlines() + + entries = [json.loads(line) for line in lines] + assert entries[0]["type"] == "metadata" + assert any(e["type"] == "iteration" for e in entries) diff --git a/contributing/samples/rlm/tests/test_files.py b/contributing/samples/rlm/tests/test_files.py new file mode 100644 index 0000000000..2c62508e49 --- /dev/null +++ b/contributing/samples/rlm/tests/test_files.py @@ -0,0 +1,781 @@ +""" +Tests for the file handling module. + +Tests cover: +- LocalFileSource loading and glob patterns +- TextParser for various text formats +- PDFParser (when pdfplumber is available) +- LazyFile progressive disclosure +- LazyFileCollection filtering +- FileLoader orchestration +""" + +import json +import os +from pathlib import Path +import tempfile + +from adk_rlm.files import FileLoader +from adk_rlm.files import FileMetadata +from adk_rlm.files import LazyFile +from adk_rlm.files import LazyFileCollection +from adk_rlm.files import LoadedFile +from adk_rlm.files import LocalFileSource +from adk_rlm.files import ParsedContent +from adk_rlm.files import TextParser +import pytest + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def sample_text_file(temp_dir: Path): + """Create a sample text file.""" + path = temp_dir / "sample.txt" + path.write_text("Hello, world!\nThis is a test file.") + return path + + +@pytest.fixture +def sample_json_file(temp_dir: Path): + """Create a sample JSON file.""" + path = temp_dir / "data.json" + data = {"name": "Test", "values": [1, 2, 3], "nested": {"key": "value"}} + path.write_text(json.dumps(data)) + return path + + +@pytest.fixture +def sample_csv_file(temp_dir: Path): + """Create a sample CSV file.""" + path = temp_dir / "data.csv" + path.write_text("name,age,city\nAlice,30,NYC\nBob,25,LA\nCharlie,35,Chicago") + return path + + +@pytest.fixture +def sample_markdown_file(temp_dir: Path): + """Create a sample Markdown file.""" + path = temp_dir / "readme.md" + path.write_text( + "# Title\n\nThis is a **markdown** file.\n\n- Item 1\n- Item 2" + ) + return path + + +@pytest.fixture +def multiple_files(temp_dir: Path): + """Create multiple files of various types.""" + files = {} + + # Text files + (temp_dir / "doc1.txt").write_text("Document 1 
content") + (temp_dir / "doc2.txt").write_text("Document 2 content") + files["txt"] = [temp_dir / "doc1.txt", temp_dir / "doc2.txt"] + + # Markdown files + (temp_dir / "readme.md").write_text("# README\n\nProject documentation") + (temp_dir / "notes.md").write_text("# Notes\n\nImportant notes") + files["md"] = [temp_dir / "readme.md", temp_dir / "notes.md"] + + # JSON files + (temp_dir / "config.json").write_text('{"setting": "value"}') + files["json"] = [temp_dir / "config.json"] + + # Subdirectory with files + subdir = temp_dir / "subdir" + subdir.mkdir() + (subdir / "nested.txt").write_text("Nested file content") + files["nested"] = [subdir / "nested.txt"] + + return files + + +# ============================================================================ +# LocalFileSource Tests +# ============================================================================ + + +class TestLocalFileSource: + """Tests for LocalFileSource.""" + + def test_source_type(self): + """Test source type identifier.""" + source = LocalFileSource() + assert source.source_type == "local" + + def test_load_text_file(self, sample_text_file: Path): + """Test loading a text file.""" + source = LocalFileSource() + loaded = source.load(str(sample_text_file)) + + assert isinstance(loaded, LoadedFile) + assert loaded.metadata.name == "sample.txt" + assert loaded.as_text() == "Hello, world!\nThis is a test file." + assert loaded.metadata.size_bytes > 0 + assert loaded.metadata.source_type == "local" + + def test_load_with_base_path(self, temp_dir: Path, sample_text_file: Path): + """Test loading with base path.""" + source = LocalFileSource(base_path=temp_dir) + loaded = source.load("sample.txt") + + assert loaded.metadata.name == "sample.txt" + assert loaded.as_text() == "Hello, world!\nThis is a test file." 
+ + def test_resolve_single_file(self, sample_text_file: Path): + """Test resolving a single file path.""" + source = LocalFileSource() + paths = source.resolve(str(sample_text_file)) + + assert len(paths) == 1 + assert paths[0] == str(sample_text_file) + + def test_resolve_glob_pattern(self, temp_dir: Path, multiple_files): + """Test resolving glob patterns.""" + source = LocalFileSource(base_path=temp_dir) + + # Match all txt files + paths = source.resolve("*.txt") + assert len(paths) == 2 + assert all(p.endswith(".txt") for p in paths) + + def test_resolve_recursive_glob(self, temp_dir: Path, multiple_files): + """Test resolving recursive glob patterns.""" + source = LocalFileSource(base_path=temp_dir) + + # Match all txt files including subdirectories + paths = source.resolve("**/*.txt") + assert len(paths) == 3 # 2 in root + 1 in subdir + + def test_get_metadata_efficient(self, sample_text_file: Path): + """Test getting metadata without loading content.""" + source = LocalFileSource() + metadata = source.get_metadata(str(sample_text_file)) + + assert isinstance(metadata, FileMetadata) + assert metadata.name == "sample.txt" + assert metadata.size_bytes > 0 + assert metadata.last_modified is not None + + def test_exists_true(self, sample_text_file: Path): + """Test exists returns True for existing file.""" + source = LocalFileSource() + assert source.exists(str(sample_text_file)) is True + + def test_exists_false(self, temp_dir: Path): + """Test exists returns False for non-existing file.""" + source = LocalFileSource() + assert source.exists(str(temp_dir / "nonexistent.txt")) is False + + def test_load_nonexistent_raises(self, temp_dir: Path): + """Test loading non-existent file raises error.""" + source = LocalFileSource() + with pytest.raises(FileNotFoundError): + source.load(str(temp_dir / "nonexistent.txt")) + + +# ============================================================================ +# TextParser Tests +# ============================================================================ + + +class TestTextParser: + """Tests for TextParser.""" + + def test_supported_extensions(self): + """Test that common text extensions are supported.""" + parser = TextParser() + exts = parser.supported_extensions + assert ".txt" in exts + assert ".md" in exts + assert ".json" in exts + assert ".csv" in exts + assert ".py" in exts + + def test_parse_plain_text(self, sample_text_file: Path): + """Test parsing plain text file.""" + source = LocalFileSource() + parser = TextParser() + + loaded = source.load(str(sample_text_file)) + assert parser.can_parse(loaded) + + parsed = parser.parse(loaded) + assert isinstance(parsed, ParsedContent) + assert "Hello, world!" 
in parsed.text + assert parsed.metadata["format"] == ".txt" + + def test_parse_json(self, sample_json_file: Path): + """Test parsing JSON file.""" + source = LocalFileSource() + parser = TextParser() + + loaded = source.load(str(sample_json_file)) + parsed = parser.parse(loaded) + + assert "Test" in parsed.text + assert parsed.metadata["format"] == ".json" + assert parsed.metadata["json_type"] == "dict" + + def test_parse_csv(self, sample_csv_file: Path): + """Test parsing CSV file with table extraction.""" + source = LocalFileSource() + parser = TextParser() + + loaded = source.load(str(sample_csv_file)) + parsed = parser.parse(loaded) + + assert parsed.tables is not None + assert len(parsed.tables) == 3 # 3 data rows + assert parsed.tables[0]["name"] == "Alice" + assert parsed.metadata["row_count"] == 3 + assert "name" in parsed.metadata["columns"] + + def test_parse_markdown(self, sample_markdown_file: Path): + """Test parsing Markdown file.""" + source = LocalFileSource() + parser = TextParser() + + loaded = source.load(str(sample_markdown_file)) + parsed = parser.parse(loaded) + + assert "# Title" in parsed.text + assert "**markdown**" in parsed.text + assert parsed.metadata["format"] == ".md" + + +# ============================================================================ +# LazyFile Tests +# ============================================================================ + + +class TestLazyFile: + """Tests for LazyFile progressive disclosure.""" + + def test_level_0_no_io(self, sample_text_file: Path): + """Test Level 0 properties don't trigger I/O.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + # Level 0 access - no loading + assert lazy.name == "sample.txt" + assert lazy.extension == ".txt" + assert lazy.is_loaded is False + assert lazy.level == 0 + + def test_level_1_metadata(self, sample_text_file: Path): + """Test Level 1 metadata access.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + # Level 1 access - metadata only + assert lazy.size > 0 + assert lazy.level == 1 + assert lazy.is_loaded is False # Full content not loaded + assert lazy.mime_type == "text/plain" + + def test_level_2_content(self, sample_text_file: Path): + """Test Level 2 content access.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + # Level 2 access - full content + content = lazy.content + assert "Hello, world!" 
in content + assert lazy.level == 2 + assert lazy.is_loaded is True + assert lazy.is_parsed is True + + def test_size_properties(self, sample_text_file: Path): + """Test size conversion properties.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + assert lazy.size_kb == lazy.size / 1024 + assert lazy.size_mb == lazy.size / (1024 * 1024) + + def test_preload_metadata(self, sample_text_file: Path): + """Test preload_metadata method.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + result = lazy.preload_metadata() + assert result is lazy # Returns self for chaining + assert lazy.level >= 1 + + def test_preload_full(self, sample_text_file: Path): + """Test preload method.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + result = lazy.preload() + assert result is lazy + assert lazy.level == 2 + + def test_read_method(self, sample_text_file: Path): + """Test read method for raw text access.""" + source = LocalFileSource() + parser = TextParser() + lazy = LazyFile(path=str(sample_text_file), source=source, parser=parser) + + text = lazy.read() + assert "Hello, world!" in text + # Note: read() may trigger full load depending on source + + def test_repr(self, sample_text_file: Path): + """Test string representation.""" + source = LocalFileSource() + lazy = LazyFile(path=str(sample_text_file), source=source) + + repr_str = repr(lazy) + assert "LazyFile" in repr_str + assert "sample.txt" in repr_str + assert "level=0" in repr_str + + +# ============================================================================ +# LazyFileCollection Tests +# ============================================================================ + + +class TestLazyFileCollection: + """Tests for LazyFileCollection.""" + + def test_empty_collection(self): + """Test empty collection.""" + collection = LazyFileCollection([]) + assert len(collection) == 0 + assert bool(collection) is False + + def test_names_property(self, temp_dir: Path, multiple_files): + """Test names property without loading.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + lazy_files = [ + LazyFile(path=str(f), source=source, parser=parser) + for f in multiple_files["txt"] + ] + collection = LazyFileCollection(lazy_files) + + names = collection.names + assert len(names) == 2 + assert "doc1.txt" in names + assert "doc2.txt" in names + + def test_by_extension(self, temp_dir: Path, multiple_files): + """Test filtering by extension.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + all_files = ( + multiple_files["txt"] + multiple_files["md"] + multiple_files["json"] + ) + lazy_files = [ + LazyFile(path=str(f), source=source, parser=parser) for f in all_files + ] + collection = LazyFileCollection(lazy_files) + + # Filter by extension + txt_files = collection.by_extension(".txt") + assert len(txt_files) == 2 + assert all(f.extension == ".txt" for f in txt_files) + + md_files = collection.by_extension("md") # Without leading dot + assert len(md_files) == 2 + + def test_by_name_pattern(self, temp_dir: Path, multiple_files): + """Test filtering by name pattern.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + all_files = multiple_files["txt"] + multiple_files["md"] + lazy_files = [ + LazyFile(path=str(f), source=source, 
parser=parser) for f in all_files + ] + collection = LazyFileCollection(lazy_files) + + # Filter by pattern + docs = collection.by_name("doc*.txt") + assert len(docs) == 2 + + readme = collection.by_name("readme*") + assert len(readme) == 1 + + def test_search(self, temp_dir: Path, multiple_files): + """Test keyword search.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + all_files = multiple_files["txt"] + multiple_files["md"] + lazy_files = [ + LazyFile(path=str(f), source=source, parser=parser) for f in all_files + ] + collection = LazyFileCollection(lazy_files) + + # Case-insensitive search + results = collection.search("DOC") + assert len(results) == 2 + + def test_loaded_count(self, temp_dir: Path, multiple_files): + """Test loaded count tracking.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + lazy_files = [ + LazyFile(path=str(f), source=source, parser=parser) + for f in multiple_files["txt"] + ] + collection = LazyFileCollection(lazy_files) + + assert collection.loaded_count == 0 + + # Load first file + _ = collection[0].content + assert collection.loaded_count == 1 + + def test_extensions_property(self, temp_dir: Path, multiple_files): + """Test extensions set property.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + all_files = ( + multiple_files["txt"] + multiple_files["md"] + multiple_files["json"] + ) + lazy_files = [ + LazyFile(path=str(f), source=source, parser=parser) for f in all_files + ] + collection = LazyFileCollection(lazy_files) + + extensions = collection.extensions + assert ".txt" in extensions + assert ".md" in extensions + assert ".json" in extensions + + def test_summary(self, temp_dir: Path, multiple_files): + """Test summary output.""" + source = LocalFileSource(base_path=temp_dir) + parser = TextParser() + + all_files = multiple_files["txt"] + multiple_files["md"] + lazy_files = [ + LazyFile(path=str(f), source=source, parser=parser) for f in all_files + ] + collection = LazyFileCollection(lazy_files) + + summary = collection.summary() + assert "LazyFileCollection" in summary + assert ".txt: 2" in summary + assert ".md: 2" in summary + + +# ============================================================================ +# FileLoader Tests +# ============================================================================ + + +class TestFileLoader: + """Tests for FileLoader orchestrator.""" + + def test_default_sources_and_parsers(self): + """Test default configuration.""" + loader = FileLoader() + assert "local" in loader.sources + assert len(loader.parsers) >= 2 # At least TextParser and PDFParser + + def test_load_single_file(self, sample_text_file: Path): + """Test loading a single file.""" + loader = FileLoader() + parsed = loader.load_single(str(sample_text_file)) + + assert isinstance(parsed, ParsedContent) + assert "Hello, world!" 
in parsed.text + + def test_load_multiple_files(self, temp_dir: Path, multiple_files): + """Test loading multiple files.""" + loader = FileLoader(base_path=temp_dir) + files = ["doc1.txt", "doc2.txt"] + results = loader.load_files(files) + + assert len(results) == 2 + assert all(isinstance(r, ParsedContent) for r in results) + + def test_load_with_glob(self, temp_dir: Path, multiple_files): + """Test loading with glob pattern.""" + loader = FileLoader(base_path=temp_dir) + results = loader.load_files(["*.txt"]) + + assert len(results) == 2 + + def test_create_lazy_files(self, temp_dir: Path, multiple_files): + """Test creating lazy file collection.""" + loader = FileLoader(base_path=temp_dir) + collection = loader.create_lazy_files(["*.txt"]) + + assert isinstance(collection, LazyFileCollection) + assert len(collection) == 2 + assert collection.loaded_count == 0 # Not loaded yet + + def test_create_lazy_file_single(self, sample_text_file: Path): + """Test creating single lazy file.""" + loader = FileLoader() + lazy = loader.create_lazy_file(str(sample_text_file)) + + assert isinstance(lazy, LazyFile) + assert lazy.name == "sample.txt" + + def test_build_context_lazy(self, temp_dir: Path, multiple_files): + """Test building context with lazy loading.""" + loader = FileLoader(base_path=temp_dir) + context = loader.build_context(["*.txt"], lazy=True) + + assert "files" in context + assert "file_count" in context + assert "file_names" in context + assert isinstance(context["files"], LazyFileCollection) + + def test_build_context_eager(self, temp_dir: Path, multiple_files): + """Test building context with eager loading.""" + loader = FileLoader(base_path=temp_dir) + context = loader.build_context(["*.txt"], lazy=False) + + assert "files" in context + assert "file_count" in context + # Files are already parsed dicts + assert isinstance(context["files"], list) + + def test_register_parser(self, sample_text_file: Path): + """Test registering custom parser.""" + loader = FileLoader() + initial_count = len(loader.parsers) + + # Register another TextParser (for testing) + loader.register_parser(TextParser()) + assert len(loader.parsers) == initial_count + 1 + + def test_nonexistent_file_raises(self, temp_dir: Path): + """Test loading non-existent file raises error.""" + loader = FileLoader() + with pytest.raises(FileNotFoundError): + loader.load_single(str(temp_dir / "nonexistent.txt")) + + +# ============================================================================ +# FileMetadata Tests +# ============================================================================ + + +class TestFileMetadata: + """Tests for FileMetadata dataclass.""" + + def test_size_properties(self): + """Test size conversion properties.""" + metadata = FileMetadata( + name="test.txt", + path="/path/to/test.txt", + source_type="local", + size_bytes=1024 * 1024, # 1 MB + ) + + assert metadata.size_kb == 1024 + assert metadata.size_mb == 1.0 + + def test_extension_property(self): + """Test extension extraction.""" + metadata = FileMetadata( + name="document.PDF", + path="/path/document.PDF", + source_type="local", + size_bytes=100, + ) + + assert metadata.extension == ".pdf" # Lowercase + + def test_to_dict(self): + """Test serialization to dict.""" + from datetime import datetime + + metadata = FileMetadata( + name="test.txt", + path="/path/test.txt", + source_type="local", + size_bytes=100, + mime_type="text/plain", + last_modified=datetime(2024, 1, 1, 12, 0, 0), + extra={"key": "value"}, + ) + + d = metadata.to_dict() + 
assert d["name"] == "test.txt" + assert d["size_bytes"] == 100 + assert "2024" in d["last_modified"] + + +# ============================================================================ +# ParsedContent Tests +# ============================================================================ + + +class TestParsedContent: + """Tests for ParsedContent dataclass.""" + + def test_has_tables(self): + """Test has_tables property.""" + content_with_tables = ParsedContent( + text="data", + tables=[{"col": "val"}], + ) + content_without = ParsedContent(text="data") + + assert content_with_tables.has_tables is True + assert content_without.has_tables is False + + def test_has_chunks(self): + """Test has_chunks property.""" + content_with_chunks = ParsedContent( + text="data", + chunks=["chunk1", "chunk2"], + ) + content_without = ParsedContent(text="data") + + assert content_with_chunks.has_chunks is True + assert content_without.has_chunks is False + + def test_counts(self): + """Test count properties.""" + content = ParsedContent( + text="data", + chunks=["a", "b", "c"], + tables=[{"x": 1}, {"x": 2}], + ) + + assert content.chunk_count == 3 + assert content.table_count == 2 + + +# ============================================================================ +# Integration Tests +# ============================================================================ + + +class TestIntegration: + """Integration tests for the file handling system.""" + + def test_full_workflow_lazy(self, temp_dir: Path, multiple_files): + """Test complete workflow with lazy loading.""" + loader = FileLoader(base_path=temp_dir) + + # Create lazy collection + files = loader.create_lazy_files(["**/*.txt", "**/*.md"]) + + # Level 0 - no I/O + assert len(files) >= 4 + names = files.names + assert all(isinstance(n, str) for n in names) + + # Filter without loading + txt_files = files.by_extension(".txt") + assert len(txt_files) == 3 # Including nested + + # Level 2 - load specific files + for f in txt_files[:2]: + content = f.content + assert len(content) > 0 + assert f.level == 2 + + # Check stats + assert txt_files.loaded_count >= 2 + + def test_full_workflow_eager(self, temp_dir: Path, multiple_files): + """Test complete workflow with eager loading.""" + loader = FileLoader(base_path=temp_dir) + + # Eager load all files + results = loader.load_files(["*.txt", "*.md"]) + + assert len(results) == 4 + assert all(isinstance(r, ParsedContent) for r in results) + assert all(len(r.text) > 0 for r in results) + + def test_build_context_for_rlm(self, temp_dir: Path, multiple_files): + """Test building context suitable for RLM consumption.""" + loader = FileLoader(base_path=temp_dir) + + context = loader.build_context(["doc1.txt", "doc2.txt"], lazy=True) + + # Context should have expected structure + assert "files" in context + assert "file_count" in context + assert context["file_count"] == 2 + + # Files are lazy - can filter without loading + files = context["files"] + assert files.by_extension(".txt") + + +# ============================================================================ +# PDF Parser Tests (conditional on pdfplumber availability) +# ============================================================================ + + +class TestPDFParser: + """Tests for PDFParser (requires pdfplumber).""" + + @pytest.fixture + def sample_pdf(self, temp_dir: Path): + """Create a minimal PDF file for testing.""" + try: + from reportlab.pdfgen import canvas + + pdf_path = temp_dir / "test.pdf" + c = canvas.Canvas(str(pdf_path)) + c.drawString(100, 750, 
"Hello PDF World!") + c.drawString(100, 700, "This is a test PDF.") + c.save() + return pdf_path + except ImportError: + pytest.skip("reportlab not installed for PDF generation") + + @pytest.mark.skipif( + not os.path.exists("/usr/bin/python3"), reason="Test requires PDF library" + ) + def test_pdf_parser_import(self): + """Test PDFParser can be imported.""" + from adk_rlm.files import PDFParser + + parser = PDFParser() + assert parser.supported_extensions == [".pdf"] + + def test_pdf_can_parse(self, temp_dir: Path): + """Test can_parse identifies PDF files.""" + from adk_rlm.files import PDFParser + + parser = PDFParser() + + # Create fake loaded file with PDF extension + fake_metadata = FileMetadata( + name="document.pdf", + path=str(temp_dir / "document.pdf"), + source_type="local", + size_bytes=1000, + mime_type="application/pdf", + ) + fake_file = LoadedFile(metadata=fake_metadata, content=b"fake pdf") + + assert parser.can_parse(fake_file) is True diff --git a/contributing/samples/rlm/tests/test_gcs_integration.py b/contributing/samples/rlm/tests/test_gcs_integration.py new file mode 100644 index 0000000000..15cedd19f6 --- /dev/null +++ b/contributing/samples/rlm/tests/test_gcs_integration.py @@ -0,0 +1,245 @@ +""" +Integration tests for GCSFileSource with real GCS access. + +These tests require actual GCS access and are skipped unless +the RLM_GCS_TEST_BUCKET environment variable is set. + +Required environment variables: +- RLM_GCS_TEST_BUCKET: GCS bucket name for testing +- RLM_GCS_TEST_FILE: (optional) Path to test file in bucket (default: test/sample.txt) +- RLM_GCS_TEST_PREFIX: (optional) Prefix for glob tests (default: test) +- RLM_GCS_TEST_PROJECT: (optional) GCP project ID + +To run these tests: + RLM_GCS_TEST_BUCKET=my-test-bucket pytest tests/test_gcs_integration.py -v +""" + +import os + +import pytest + +# Skip all tests if GCS bucket not configured +pytestmark = pytest.mark.skipif( + not os.environ.get("RLM_GCS_TEST_BUCKET"), + reason="GCS integration tests require RLM_GCS_TEST_BUCKET env var", +) + +# Also skip if google-cloud-storage is not installed +pytest.importorskip("google.cloud.storage") + + +@pytest.fixture +def gcs_source(): + """Create GCSFileSource for integration tests.""" + from adk_rlm.files.sources.gcs import GCSFileSource + + return GCSFileSource( + bucket=os.environ["RLM_GCS_TEST_BUCKET"], + project=os.environ.get("RLM_GCS_TEST_PROJECT"), + ) + + +@pytest.fixture +def test_bucket(): + """Get test bucket name.""" + return os.environ["RLM_GCS_TEST_BUCKET"] + + +@pytest.fixture +def test_file_path(): + """Path to test file in GCS bucket.""" + return os.environ.get("RLM_GCS_TEST_FILE", "test/sample.txt") + + +@pytest.fixture +def test_prefix(): + """Prefix for glob pattern tests.""" + return os.environ.get("RLM_GCS_TEST_PREFIX", "test") + + +class TestGCSIntegration: + """Integration tests requiring real GCS access.""" + + def test_source_type(self, gcs_source): + """Test source type is 'gcs'.""" + assert gcs_source.source_type == "gcs" + + def test_exists_real_file(self, gcs_source, test_bucket, test_file_path): + """Test checking existence of a real file.""" + result = gcs_source.exists(f"gs://{test_bucket}/{test_file_path}") + assert result is True + + def test_exists_nonexistent_file(self, gcs_source, test_bucket): + """Test exists returns False for missing file.""" + result = gcs_source.exists(f"gs://{test_bucket}/nonexistent-file-12345.txt") + assert result is False + + def test_get_metadata(self, gcs_source, test_bucket, test_file_path): + """Test 
fetching real file metadata.""" + path = f"gs://{test_bucket}/{test_file_path}" + + metadata = gcs_source.get_metadata(path) + + assert metadata.name == test_file_path.split("/")[-1] + assert metadata.size_bytes > 0 + assert metadata.source_type == "gcs" + assert metadata.path == path + assert "bucket" in metadata.extra + assert metadata.extra["bucket"] == test_bucket + + def test_load(self, gcs_source, test_bucket, test_file_path): + """Test loading real file content.""" + path = f"gs://{test_bucket}/{test_file_path}" + + loaded = gcs_source.load(path) + + assert len(loaded.content) > 0 + assert loaded.metadata.path == path + assert loaded.metadata.source_type == "gcs" + + def test_load_not_found(self, gcs_source, test_bucket): + """Test proper error for missing file.""" + with pytest.raises(FileNotFoundError): + gcs_source.load(f"gs://{test_bucket}/nonexistent-file-12345.txt") + + def test_resolve_single_file(self, gcs_source, test_bucket, test_file_path): + """Test resolving a single file path.""" + path = f"gs://{test_bucket}/{test_file_path}" + + result = gcs_source.resolve(path) + + assert result == [path] + + def test_resolve_nonexistent_file(self, gcs_source, test_bucket): + """Test resolving a nonexistent file returns empty list.""" + path = f"gs://{test_bucket}/nonexistent-file-12345.txt" + + result = gcs_source.resolve(path) + + assert result == [] + + def test_resolve_glob_pattern(self, gcs_source, test_bucket, test_prefix): + """Test glob pattern resolution.""" + pattern = f"gs://{test_bucket}/{test_prefix}/*" + + paths = gcs_source.resolve(pattern) + + # May return empty if no files, but should not error + assert isinstance(paths, list) + for path in paths: + assert path.startswith(f"gs://{test_bucket}/") + + def test_load_many_single(self, gcs_source, test_bucket, test_file_path): + """Test load_many with single file.""" + path = f"gs://{test_bucket}/{test_file_path}" + + results = list(gcs_source.load_many([path])) + + assert len(results) == 1 + assert len(results[0].content) > 0 + + def test_load_many_multiple(self, gcs_source, test_bucket, test_file_path): + """Test load_many with same file twice (tests parallelism).""" + path = f"gs://{test_bucket}/{test_file_path}" + paths = [path, path] + + results = list(gcs_source.load_many(paths)) + + assert len(results) == 2 + + +class TestGCSWithFileLoader: + """Test GCS source integration with FileLoader.""" + + def test_file_loader_with_gcs(self, gcs_source, test_bucket, test_file_path): + """Test FileLoader works with GCS source.""" + from adk_rlm.files.loader import FileLoader + + path = f"gs://{test_bucket}/{test_file_path}" + + loader = FileLoader(sources={"gcs": gcs_source}) + collection = loader.create_lazy_files([path]) + + assert len(collection) == 1 + assert collection[0].name == test_file_path.split("/")[-1] + + def test_lazy_loading_with_gcs(self, gcs_source, test_bucket, test_file_path): + """Test lazy file loading from GCS.""" + from adk_rlm.files.loader import FileLoader + + path = f"gs://{test_bucket}/{test_file_path}" + + loader = FileLoader(sources={"gcs": gcs_source}) + collection = loader.create_lazy_files([path]) + + lazy_file = collection[0] + + # Level 0 - no I/O yet + assert lazy_file.level == 0 + _ = lazy_file.name # Still no I/O + + # Level 1 - metadata fetch + size = lazy_file.size + assert lazy_file.level == 1 + assert size > 0 + + def test_lazy_file_content_access( + self, gcs_source, test_bucket, test_file_path + ): + """Test accessing lazy file content triggers download.""" + from 
adk_rlm.files.loader import FileLoader + from adk_rlm.files.parsers.text import TextParser + + path = f"gs://{test_bucket}/{test_file_path}" + + loader = FileLoader(sources={"gcs": gcs_source}, parsers=[TextParser()]) + collection = loader.create_lazy_files([path]) + + lazy_file = collection[0] + + # Access raw content (no parsing needed) + raw_content = lazy_file.read() + assert len(raw_content) > 0 + assert lazy_file.level >= 1 # At least metadata loaded + + +class TestGCSRetryBehavior: + """Test retry behavior with real GCS (best-effort tests).""" + + def test_timeout_configuration(self, gcs_source): + """Test that timeout is configurable.""" + from adk_rlm.files.sources.gcs import GCSFileSource + from adk_rlm.files.sources.gcs import RetryConfig + + # Create source with custom timeout + source = GCSFileSource( + bucket=os.environ["RLM_GCS_TEST_BUCKET"], + timeout=5.0, + retry_config=RetryConfig(max_attempts=2, initial_delay=0.1), + ) + + assert source.timeout == 5.0 + assert source.retry_config.max_attempts == 2 + + +class TestGCSEdgeCases: + """Test edge cases with real GCS.""" + + def test_path_without_gs_prefix(self, gcs_source, test_file_path): + """Test loading with path without gs:// prefix uses default bucket.""" + # This should work because GCSFileSource has default bucket set + loaded = gcs_source.load(test_file_path) + + assert len(loaded.content) > 0 + + def test_metadata_extra_fields(self, gcs_source, test_bucket, test_file_path): + """Test that extra metadata fields are populated.""" + path = f"gs://{test_bucket}/{test_file_path}" + + metadata = gcs_source.get_metadata(path) + + # Check that GCS-specific fields are present + assert "blob_name" in metadata.extra + assert "storage_class" in metadata.extra + # generation and metageneration should be present + assert "generation" in metadata.extra diff --git a/contributing/samples/rlm/tests/test_gcs_pickle.py b/contributing/samples/rlm/tests/test_gcs_pickle.py new file mode 100644 index 0000000000..75c86a7e5b --- /dev/null +++ b/contributing/samples/rlm/tests/test_gcs_pickle.py @@ -0,0 +1,102 @@ +""" +Test that GCSFileSource can be pickled. + +This is important because LazyFile objects store a reference to their source, +and if the source can't be pickled, serialization of file collections will fail. 
+""" + +import pickle + +import pytest + + +class TestGCSFileSourcePickle: + """Test pickling of GCSFileSource.""" + + @pytest.fixture + def gcs_source(self): + """Create a GCSFileSource instance.""" + pytest.importorskip("google.cloud.storage") + from adk_rlm.files.sources.gcs import GCSFileSource + + return GCSFileSource(bucket="test-bucket") + + def test_gcs_source_can_be_pickled(self, gcs_source): + """GCSFileSource should be pickleable.""" + # Pickle and unpickle + pickled = pickle.dumps(gcs_source) + unpickled = pickle.loads(pickled) + + # Check that the unpickled source has the same config + assert unpickled.default_bucket == gcs_source.default_bucket + assert unpickled.timeout == gcs_source.timeout + assert unpickled.max_concurrent == gcs_source.max_concurrent + + def test_gcs_source_client_is_lazy(self, gcs_source): + """GCSFileSource client should be lazily initialized.""" + # Before accessing client, _client should be None + assert gcs_source._client is None + + # After pickling, _client should still be None + pickled = pickle.dumps(gcs_source) + unpickled = pickle.loads(pickled) + assert unpickled._client is None + + def test_gcs_source_pickle_after_client_access(self, gcs_source): + """GCSFileSource should be pickleable even after client is accessed.""" + # Access the client to initialize it + # Note: This may fail if no credentials are available, which is fine for this test + try: + _ = gcs_source.client + except Exception: + pytest.skip("GCS credentials not available") + + # Should still be pickleable - client should be excluded + pickled = pickle.dumps(gcs_source) + unpickled = pickle.loads(pickled) + + # Client should be None after unpickling (will be re-created on demand) + assert unpickled._client is None + assert unpickled.default_bucket == gcs_source.default_bucket + + def test_lazy_file_with_gcs_source_can_be_pickled(self, gcs_source): + """LazyFile with GCSFileSource should be pickleable.""" + from adk_rlm.files.lazy import LazyFile + + lazy_file = LazyFile( + path="gs://test-bucket/test.txt", + source=gcs_source, + parser=None, + ) + + # Should be pickleable + pickled = pickle.dumps(lazy_file) + unpickled = pickle.loads(pickled) + + assert unpickled.path == lazy_file.path + assert unpickled.source.default_bucket == gcs_source.default_bucket + + def test_lazy_file_collection_with_gcs_source_can_be_pickled( + self, gcs_source + ): + """LazyFileCollection with GCS files should be pickleable.""" + from adk_rlm.files.lazy import LazyFile + from adk_rlm.files.lazy import LazyFileCollection + + files = [ + LazyFile( + path="gs://test-bucket/file1.txt", source=gcs_source, parser=None + ), + LazyFile( + path="gs://test-bucket/file2.txt", source=gcs_source, parser=None + ), + ] + collection = LazyFileCollection(files) + + # Should be pickleable + pickled = pickle.dumps(collection) + unpickled = pickle.loads(pickled) + + assert len(unpickled) == 2 + assert unpickled[0].path == "gs://test-bucket/file1.txt" + assert unpickled[1].path == "gs://test-bucket/file2.txt" diff --git a/contributing/samples/rlm/tests/test_gcs_source.py b/contributing/samples/rlm/tests/test_gcs_source.py new file mode 100644 index 0000000000..098241dd1f --- /dev/null +++ b/contributing/samples/rlm/tests/test_gcs_source.py @@ -0,0 +1,643 @@ +""" +Unit tests for GCSFileSource with mocked GCS client. + +These tests do not require actual GCS access - all GCS operations are mocked. 
+""" + +from datetime import datetime +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +# Skip all tests if google-cloud-storage is not installed +pytest.importorskip("google.cloud.storage") + + +@pytest.fixture +def mock_storage(): + """Mock google.cloud.storage module.""" + with patch("adk_rlm.files.sources.gcs.storage") as mock_storage: + mock_client = MagicMock() + mock_storage.Client.return_value = mock_client + mock_storage.Client.from_service_account_json.return_value = mock_client + yield mock_storage, mock_client + + +@pytest.fixture +def gcs_source(mock_storage): + """Create GCSFileSource with mocked client.""" + from adk_rlm.files.sources.gcs import GCSFileSource + + _, mock_client = mock_storage + source = GCSFileSource(bucket="test-bucket") + source.client = mock_client + return source + + +class TestGCSFileSourceInit: + """Test GCSFileSource initialization.""" + + def test_init_with_default_credentials(self, mock_storage): + """Test initialization with Application Default Credentials.""" + from adk_rlm.files.sources.gcs import GCSFileSource + + mock_storage_mod, _ = mock_storage + source = GCSFileSource(bucket="test-bucket") + + assert source.default_bucket == "test-bucket" + assert source.source_type == "gcs" + mock_storage_mod.Client.assert_called_once() + + def test_init_with_service_account(self, mock_storage): + """Test initialization with service account JSON.""" + from adk_rlm.files.sources.gcs import GCSFileSource + + mock_storage_mod, _ = mock_storage + GCSFileSource( + bucket="test-bucket", + credentials_path="/path/to/sa.json", + project="my-project", + ) + + mock_storage_mod.Client.from_service_account_json.assert_called_once_with( + "/path/to/sa.json", project="my-project" + ) + + def test_init_with_custom_settings(self, mock_storage): + """Test initialization with custom timeout and retry config.""" + from adk_rlm.files.sources.gcs import GCSFileSource + from adk_rlm.files.sources.gcs import RetryConfig + + retry_config = RetryConfig(max_attempts=5, initial_delay=1.0) + source = GCSFileSource( + bucket="test-bucket", + timeout=120.0, + retry_config=retry_config, + max_concurrent=5, + large_file_threshold=50_000_000, + ) + + assert source.timeout == 120.0 + assert source.retry_config.max_attempts == 5 + assert source.max_concurrent == 5 + assert source.large_file_threshold == 50_000_000 + + +class TestPathParsing: + """Test GCS path parsing.""" + + def test_parse_path_with_gs_prefix(self, gcs_source): + """Test parsing gs:// URIs.""" + bucket, blob = gcs_source._parse_path("gs://my-bucket/path/to/file.pdf") + assert bucket == "my-bucket" + assert blob == "path/to/file.pdf" + + def test_parse_path_with_gs_prefix_root(self, gcs_source): + """Test parsing gs:// URI with file at root.""" + bucket, blob = gcs_source._parse_path("gs://my-bucket/file.pdf") + assert bucket == "my-bucket" + assert blob == "file.pdf" + + def test_parse_path_without_prefix(self, gcs_source): + """Test parsing paths without gs:// uses default bucket.""" + bucket, blob = gcs_source._parse_path("path/to/file.pdf") + assert bucket == "test-bucket" + assert blob == "path/to/file.pdf" + + def test_parse_path_no_bucket_raises(self, mock_storage): + """Test that missing bucket raises ValueError.""" + from adk_rlm.files.sources.gcs import GCSFileSource + + source = GCSFileSource() # No default bucket + with pytest.raises(ValueError, match="No bucket specified"): + source._parse_path("path/to/file.pdf") + + def test_parse_path_empty_blob_name(self, gcs_source): 
+ """Test parsing gs:// with only bucket.""" + bucket, blob = gcs_source._parse_path("gs://my-bucket") + assert bucket == "my-bucket" + assert blob == "" + + +class TestResolve: + """Test path resolution including glob patterns.""" + + def test_resolve_single_file_exists(self, gcs_source): + """Test resolving a single file that exists.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.exists.return_value = True + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + result = gcs_source.resolve("gs://test-bucket/file.pdf") + + assert result == ["gs://test-bucket/file.pdf"] + mock_blob.exists.assert_called_once() + + def test_resolve_single_file_not_exists(self, gcs_source): + """Test resolving a file that doesn't exist.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.exists.return_value = False + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + result = gcs_source.resolve("gs://test-bucket/missing.pdf") + + assert result == [] + + def test_resolve_glob_pattern(self, gcs_source): + """Test resolving glob patterns.""" + mock_bucket = MagicMock() + mock_blob1 = MagicMock() + mock_blob1.name = "data/file1.pdf" + mock_blob2 = MagicMock() + mock_blob2.name = "data/file2.pdf" + mock_blob3 = MagicMock() + mock_blob3.name = "data/file.txt" + + mock_bucket.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3] + gcs_source.client.bucket.return_value = mock_bucket + + result = gcs_source.resolve("gs://test-bucket/data/*.pdf") + + assert result == [ + "gs://test-bucket/data/file1.pdf", + "gs://test-bucket/data/file2.pdf", + ] + mock_bucket.list_blobs.assert_called_once() + + def test_resolve_recursive_glob(self, gcs_source): + """Test resolving ** recursive patterns.""" + mock_bucket = MagicMock() + mock_blob1 = MagicMock() + mock_blob1.name = "data/2024/report.pdf" + mock_blob2 = MagicMock() + mock_blob2.name = "data/2023/report.pdf" + mock_blob3 = MagicMock() + mock_blob3.name = "data/readme.txt" + + mock_bucket.list_blobs.return_value = [mock_blob1, mock_blob2, mock_blob3] + gcs_source.client.bucket.return_value = mock_bucket + + result = gcs_source.resolve("gs://test-bucket/data/**/*.pdf") + + assert len(result) == 2 + assert "gs://test-bucket/data/2024/report.pdf" in result + assert "gs://test-bucket/data/2023/report.pdf" in result + + def test_resolve_glob_no_matches(self, gcs_source): + """Test glob pattern with no matches.""" + mock_bucket = MagicMock() + mock_bucket.list_blobs.return_value = [] + gcs_source.client.bucket.return_value = mock_bucket + + result = gcs_source.resolve("gs://test-bucket/data/*.pdf") + + assert result == [] + + +class TestGetMetadata: + """Test metadata fetching.""" + + def test_get_metadata(self, gcs_source): + """Test fetching metadata without downloading content.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 1024 + mock_blob.content_type = "application/pdf" + mock_blob.updated = datetime(2024, 1, 15, 10, 30, 0) + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 12345 + mock_blob.metageneration = 1 + mock_blob.etag = "abc123" + mock_blob.md5_hash = "xyz789" + mock_blob.crc32c = "crc123" + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + metadata = gcs_source.get_metadata("gs://test-bucket/report.pdf") + + assert metadata.name == "report.pdf" + assert metadata.size_bytes == 1024 + assert 
metadata.mime_type == "application/pdf" + assert metadata.source_type == "gcs" + assert metadata.extra["bucket"] == "test-bucket" + assert metadata.extra["storage_class"] == "STANDARD" + mock_blob.reload.assert_called_once() + + def test_get_metadata_guesses_mime_type(self, gcs_source): + """Test that MIME type is guessed when not provided by GCS.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 500 + mock_blob.content_type = None # GCS didn't provide MIME type + mock_blob.updated = None + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 1 + mock_blob.metageneration = 1 + mock_blob.etag = None + mock_blob.md5_hash = None + mock_blob.crc32c = None + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + metadata = gcs_source.get_metadata("gs://test-bucket/data.json") + + assert metadata.mime_type == "application/json" + + +class TestLoad: + """Test file loading.""" + + def test_load_small_file(self, gcs_source): + """Test loading a small file directly into memory.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 1024 + mock_blob.content_type = "text/plain" + mock_blob.updated = datetime.now() + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 1 + mock_blob.metageneration = 1 + mock_blob.etag = "abc" + mock_blob.md5_hash = None + mock_blob.crc32c = None + mock_blob.download_as_bytes.return_value = b"Hello, world!" + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + loaded = gcs_source.load("gs://test-bucket/file.txt") + + assert loaded.content == b"Hello, world!" + assert loaded.metadata.name == "file.txt" + assert loaded.metadata.source_type == "gcs" + + def test_load_uses_chunked_for_large_files(self, gcs_source): + """Test that large files use chunked loading strategy.""" + gcs_source.large_file_threshold = 1000 # Low threshold for testing + + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 2000 # Above threshold + mock_blob.content_type = "application/octet-stream" + mock_blob.updated = None + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 1 + mock_blob.metageneration = 1 + mock_blob.etag = None + mock_blob.md5_hash = None + mock_blob.crc32c = None + + # Mock the download_to_file to write to temp file + def download_side_effect(file_obj, timeout=None): + file_obj.write(b"large file content") + + mock_blob.download_to_file.side_effect = download_side_effect + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + loaded = gcs_source.load("gs://test-bucket/large.bin") + + assert loaded.content == b"large file content" + mock_blob.download_to_file.assert_called_once() + + +class TestLoadErrors: + """Test error handling during load.""" + + def test_load_not_found_raises(self, gcs_source): + """Test loading a nonexistent file raises FileNotFoundError.""" + from google.cloud.exceptions import NotFound + + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.reload.side_effect = NotFound("Not found") + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + with pytest.raises(FileNotFoundError): + gcs_source.load("gs://test-bucket/missing.txt") + + def test_load_permission_denied_raises(self, gcs_source): + """Test loading without permission raises PermissionError.""" + from 
google.cloud.exceptions import Forbidden + + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.reload.side_effect = Forbidden("Access denied") + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + with pytest.raises(PermissionError, match="Access denied"): + gcs_source.load("gs://test-bucket/secret.txt") + + +class TestRetryLogic: + """Test retry behavior for transient errors.""" + + def test_retry_on_transient_error(self, mock_storage): + """Test that transient errors trigger retries.""" + from adk_rlm.files.sources.gcs import GCSFileSource + from adk_rlm.files.sources.gcs import RetryConfig + + _, mock_client = mock_storage + source = GCSFileSource( + bucket="test-bucket", + retry_config=RetryConfig(max_attempts=3, initial_delay=0.01), + ) + source.client = mock_client + + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 100 + mock_blob.content_type = "text/plain" + mock_blob.updated = None + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 1 + mock_blob.metageneration = 1 + mock_blob.etag = None + mock_blob.md5_hash = None + mock_blob.crc32c = None + + # Fail twice, then succeed for all subsequent calls + call_count = 0 + + def reload_side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise Exception("ServiceUnavailable") + return mock_blob + + mock_blob.reload.side_effect = reload_side_effect + mock_blob.download_as_bytes.return_value = b"data" + mock_bucket.blob.return_value = mock_blob + mock_client.bucket.return_value = mock_bucket + + loaded = source.load("gs://test-bucket/file.txt") + + assert loaded.content == b"data" + # load() calls get_metadata() first (which retries reload 3 times), + # then _load_direct() which does a best-effort reload after download + assert call_count >= 3 + + def test_max_retries_exceeded(self, mock_storage): + """Test that exceeding max retries raises error.""" + from adk_rlm.files.sources.gcs import GCSFileSource + from adk_rlm.files.sources.gcs import RetryConfig + + _, mock_client = mock_storage + source = GCSFileSource( + bucket="test-bucket", + retry_config=RetryConfig(max_attempts=2, initial_delay=0.01), + ) + source.client = mock_client + + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.reload.side_effect = Exception("ServiceUnavailable") + mock_bucket.blob.return_value = mock_blob + mock_client.bucket.return_value = mock_bucket + + with pytest.raises(RuntimeError, match="failed after 2 attempts"): + source.load("gs://test-bucket/file.txt") + + def test_non_retryable_error_not_retried(self, mock_storage): + """Test that non-retryable errors are not retried.""" + from adk_rlm.files.sources.gcs import GCSFileSource + from adk_rlm.files.sources.gcs import RetryConfig + + _, mock_client = mock_storage + source = GCSFileSource( + bucket="test-bucket", + retry_config=RetryConfig(max_attempts=3, initial_delay=0.01), + ) + source.client = mock_client + + mock_bucket = MagicMock() + mock_blob = MagicMock() + + call_count = 0 + + def reload_side_effect(*args, **kwargs): + nonlocal call_count + call_count += 1 + raise ValueError("Bad request - not retryable") + + mock_blob.reload.side_effect = reload_side_effect + mock_bucket.blob.return_value = mock_blob + mock_client.bucket.return_value = mock_bucket + + with pytest.raises(ValueError, match="Bad request"): + source.load("gs://test-bucket/file.txt") + + # Should only try once (no retries) + assert call_count == 1 + + 
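+# Illustrative sketch (not used by the tests, and not GCSFileSource's actual
+# implementation): the retry contract the TestRetryLogic cases above assume.
+# Retryable failures are retried up to RetryConfig.max_attempts times, with a
+# delay starting at RetryConfig.initial_delay (exponential growth here is an
+# assumption); non-retryable errors propagate unchanged, and exhausting every
+# attempt raises a RuntimeError that mentions the attempt count, matching
+# `pytest.raises(RuntimeError, match="failed after 2 attempts")` above. The
+# names `operation` and `is_retryable` are hypothetical placeholders.
+def _retry_loop_sketch(operation, retry_config, is_retryable):
+  """Run `operation` with retries; illustrative only."""
+  import time
+
+  delay = retry_config.initial_delay
+  for attempt in range(1, retry_config.max_attempts + 1):
+    try:
+      return operation()
+    except Exception as exc:  # broad catch mirrors "retry transient errors"
+      if not is_retryable(exc):
+        raise
+      if attempt == retry_config.max_attempts:
+        raise RuntimeError(
+            f"Operation failed after {retry_config.max_attempts} attempts"
+        ) from exc
+      time.sleep(delay)
+      delay *= 2  # assumed backoff factor
+
+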
+class TestLoadMany: + """Test parallel loading functionality.""" + + def test_load_many_empty_list(self, gcs_source): + """Test loading empty list.""" + results = list(gcs_source.load_many([])) + assert results == [] + + def test_load_many_single_file(self, gcs_source): + """Test loading single file doesn't use parallelism.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 100 + mock_blob.content_type = "text/plain" + mock_blob.updated = None + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 1 + mock_blob.metageneration = 1 + mock_blob.etag = None + mock_blob.md5_hash = None + mock_blob.crc32c = None + mock_blob.download_as_bytes.return_value = b"content" + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + results = list(gcs_source.load_many(["gs://test-bucket/file.txt"])) + + assert len(results) == 1 + assert results[0].content == b"content" + + def test_load_many_parallel(self, gcs_source): + """Test loading multiple files in parallel.""" + + def make_blob(name, content): + blob = MagicMock() + blob.size = len(content) + blob.content_type = "text/plain" + blob.updated = None + blob.content_encoding = None + blob.storage_class = "STANDARD" + blob.generation = 1 + blob.metageneration = 1 + blob.etag = None + blob.md5_hash = None + blob.crc32c = None + blob.download_as_bytes.return_value = content + return blob + + mock_bucket = MagicMock() + blobs = { + "file1.txt": make_blob("file1.txt", b"content1"), + "file2.txt": make_blob("file2.txt", b"content2"), + "file3.txt": make_blob("file3.txt", b"content3"), + } + + def get_blob(name): + blob_name = name.split("/")[-1] if "/" in name else name + return blobs.get(blob_name, MagicMock()) + + mock_bucket.blob.side_effect = get_blob + gcs_source.client.bucket.return_value = mock_bucket + + paths = [ + "gs://test-bucket/file1.txt", + "gs://test-bucket/file2.txt", + "gs://test-bucket/file3.txt", + ] + results = list(gcs_source.load_many(paths)) + + assert len(results) == 3 + contents = {r.content for r in results} + assert contents == {b"content1", b"content2", b"content3"} + + +class TestExists: + """Test existence checking.""" + + def test_exists_true(self, gcs_source): + """Test exists returns True for existing blob.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.exists.return_value = True + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + assert gcs_source.exists("gs://test-bucket/file.txt") is True + + def test_exists_false(self, gcs_source): + """Test exists returns False for missing blob.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.exists.return_value = False + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + assert gcs_source.exists("gs://test-bucket/missing.txt") is False + + def test_exists_handles_errors(self, gcs_source): + """Test exists returns False on errors.""" + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.exists.side_effect = Exception("Network error") + mock_bucket.blob.return_value = mock_blob + gcs_source.client.bucket.return_value = mock_bucket + + assert gcs_source.exists("gs://test-bucket/file.txt") is False + + +class TestLazyFileIntegration: + """Test GCSFileSource with LazyFile.""" + + def test_lazy_file_level_0(self, mock_storage): + """Test Level 0 access (name/extension) requires no I/O.""" + from adk_rlm.files.lazy import 
LazyFile + from adk_rlm.files.sources.gcs import GCSFileSource + + _, mock_client = mock_storage + source = GCSFileSource(bucket="test-bucket") + source.client = mock_client + + lazy = LazyFile(path="gs://test-bucket/data/report.pdf", source=source) + + # Level 0 - no I/O + assert lazy.name == "report.pdf" + assert lazy.extension == ".pdf" + assert lazy.level == 0 + + # Verify no GCS calls were made + mock_client.bucket.assert_not_called() + + def test_lazy_file_level_1(self, mock_storage): + """Test Level 1 access (metadata) triggers reload.""" + from adk_rlm.files.lazy import LazyFile + from adk_rlm.files.sources.gcs import GCSFileSource + + _, mock_client = mock_storage + source = GCSFileSource(bucket="test-bucket") + source.client = mock_client + + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.size = 2048 + mock_blob.content_type = "application/pdf" + mock_blob.updated = datetime.now() + mock_blob.content_encoding = None + mock_blob.storage_class = "STANDARD" + mock_blob.generation = 1 + mock_blob.metageneration = 1 + mock_blob.etag = "abc" + mock_blob.md5_hash = "xyz" + mock_blob.crc32c = "crc" + mock_bucket.blob.return_value = mock_blob + mock_client.bucket.return_value = mock_bucket + + lazy = LazyFile(path="gs://test-bucket/report.pdf", source=source) + + # Level 1 - triggers metadata fetch + size = lazy.size + + assert size == 2048 + assert lazy.level == 1 + mock_blob.reload.assert_called_once() + + +class TestFileLoaderIntegration: + """Test GCS source integration with FileLoader.""" + + def test_file_loader_gcs_detection(self, mock_storage): + """Test FileLoader auto-detects gs:// paths.""" + from adk_rlm.files.loader import FileLoader + from adk_rlm.files.sources.gcs import GCSFileSource + + _, mock_client = mock_storage + gcs_source = GCSFileSource(bucket="test-bucket") + gcs_source.client = mock_client + + # Mock resolve to return the path + mock_bucket = MagicMock() + mock_blob = MagicMock() + mock_blob.exists.return_value = True + mock_bucket.blob.return_value = mock_blob + mock_client.bucket.return_value = mock_bucket + + loader = FileLoader(sources={"gcs": gcs_source}) + collection = loader.create_lazy_files(["gs://test-bucket/file.txt"]) + + assert len(collection) == 1 + assert collection[0].name == "file.txt" + + def test_file_loader_gcs_not_configured_raises(self): + """Test FileLoader raises clear error when GCS not configured.""" + from adk_rlm.files.loader import FileLoader + + loader = FileLoader() + + with pytest.raises(ValueError, match="GCS source not configured"): + loader.create_lazy_files(["gs://some-bucket/file.txt"]) diff --git a/contributing/samples/rlm/tests/test_logger.py b/contributing/samples/rlm/tests/test_logger.py new file mode 100644 index 0000000000..0d632c4b57 --- /dev/null +++ b/contributing/samples/rlm/tests/test_logger.py @@ -0,0 +1,132 @@ +""" +Tests for JSONL logging. 
+""" + +import json +from pathlib import Path + +from adk_rlm.logging.rlm_logger import RLMLogger +from adk_rlm.types import CodeBlock +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata +import pytest + + +class TestRLMLogger: + """Tests for RLMLogger.""" + + def test_creates_log_file(self, temp_log_dir): + """Logger creates log file.""" + logger = RLMLogger(temp_log_dir) + + assert Path(logger.log_file_path).parent.exists() + + def test_log_metadata(self, temp_log_dir): + """Log metadata as first entry.""" + logger = RLMLogger(temp_log_dir) + metadata = RLMMetadata( + root_model="gemini-pro", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={"model_name": "gemini-pro"}, + environment_type="local", + environment_kwargs={}, + ) + logger.log_metadata(metadata) + + # Read log file + with open(logger.log_file_path) as f: + lines = f.readlines() + + assert len(lines) == 1 + entry = json.loads(lines[0]) + assert entry["type"] == "metadata" + assert entry["root_model"] == "gemini-pro" + + def test_log_iteration(self, temp_log_dir): + """Log iteration.""" + logger = RLMLogger(temp_log_dir) + iteration = RLMIteration( + prompt="test prompt", + response="test response", + code_blocks=[], + final_answer=None, + iteration_time=1.0, + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + lines = f.readlines() + + assert len(lines) == 1 + entry = json.loads(lines[0]) + assert entry["type"] == "iteration" + assert entry["iteration"] == 1 + assert entry["response"] == "test response" + + def test_iteration_count(self, temp_log_dir): + """Iteration counter increments.""" + logger = RLMLogger(temp_log_dir) + + for i in range(3): + iteration = RLMIteration( + prompt="", response=f"response {i}", code_blocks=[] + ) + logger.log(iteration) + + assert logger.iteration_count == 3 + + def test_log_with_code_blocks(self, temp_log_dir): + """Log iteration with code blocks.""" + logger = RLMLogger(temp_log_dir) + + result = REPLResult( + stdout="42", stderr="", locals={"x": 42}, execution_time=0.1 + ) + code_block = CodeBlock(code="print(42)", result=result) + iteration = RLMIteration( + prompt="test", + response="Let me calculate", + code_blocks=[code_block], + iteration_time=0.5, + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert len(entry["code_blocks"]) == 1 + assert entry["code_blocks"][0]["code"] == "print(42)" + assert entry["code_blocks"][0]["result"]["stdout"] == "42" + + def test_metadata_logged_once(self, temp_log_dir): + """Metadata only logged once.""" + logger = RLMLogger(temp_log_dir) + metadata = RLMMetadata( + root_model="gemini-pro", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={}, + environment_type="local", + environment_kwargs={}, + ) + + logger.log_metadata(metadata) + logger.log_metadata(metadata) # Second call should be ignored + + with open(logger.log_file_path) as f: + lines = f.readlines() + + metadata_entries = [l for l in lines if '"type": "metadata"' in l] + assert len(metadata_entries) == 1 + + def test_get_log_path(self, temp_log_dir): + """Get log path.""" + logger = RLMLogger(temp_log_dir) + path = logger.get_log_path() + + assert path == logger.log_file_path + assert temp_log_dir in path diff --git a/contributing/samples/rlm/tests/test_multi_turn.py b/contributing/samples/rlm/tests/test_multi_turn.py new file mode 100644 index 0000000000..0b5e19d641 --- /dev/null +++ 
b/contributing/samples/rlm/tests/test_multi_turn.py @@ -0,0 +1,244 @@ +""" +Tests for multi-turn persistence in ADK-RLM. + +These tests verify that: +1. REPL environments persist across calls +2. Contexts accumulate (context_0, context_1, ...) +3. Histories accumulate (history_0, history_1, ...) +4. Variables persist across calls +""" + +from adk_rlm.repl.local_repl import LocalREPL +import pytest + + +class TestLocalREPLMultiContext: + """Tests for multi-context support.""" + + def test_add_context_versioning(self, mock_llm_query): + """Add_context creates versioned variables.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.add_context("First", 0) + repl.add_context("Second", 1) + + assert repl.locals["context_0"] == "First" + assert repl.locals["context_1"] == "Second" + assert repl.locals["context"] == "First" # Alias to first + assert repl.get_context_count() == 2 + + def test_add_context_auto_increment(self, mock_llm_query): + """Add_context auto-increments when no index provided.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + idx1 = repl.add_context("First") + idx2 = repl.add_context("Second") + + assert idx1 == 0 + assert idx2 == 1 + assert repl.locals["context_0"] == "First" + assert repl.locals["context_1"] == "Second" + assert repl.get_context_count() == 2 + + def test_contexts_accessible_in_code(self, mock_llm_query): + """Multiple contexts can be accessed in code execution.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.add_context("Document A content") + repl.add_context("Document B content") + + result = repl.execute_code("combined = f'{context_0} + {context_1}'") + assert result.stderr == "" + assert repl.locals["combined"] == "Document A content + Document B content" + + def test_context_alias_points_to_first(self, mock_llm_query): + """'context' always aliases context_0.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.add_context("First") + repl.add_context("Second") + repl.add_context("Third") + + result = repl.execute_code("is_first = context == context_0") + assert result.stderr == "" + assert repl.locals["is_first"] is True + + +class TestLocalREPLHistory: + """Tests for message history storage.""" + + def test_add_history_basic(self, mock_llm_query): + """Add_history stores message history correctly.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + history = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + + index = repl.add_history(history) + + assert index == 0 + assert "history_0" in repl.locals + assert "history" in repl.locals + assert repl.locals["history_0"] == history + assert repl.locals["history"] == history + assert repl.get_history_count() == 1 + + def test_add_multiple_histories(self, mock_llm_query): + """Adding multiple conversation histories.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + history1 = [{"role": "user", "content": "First conversation"}] + history2 = [{"role": "user", "content": "Second conversation"}] + + repl.add_history(history1) + repl.add_history(history2) + + assert repl.get_history_count() == 2 + assert repl.locals["history_0"] == history1 + assert repl.locals["history_1"] == history2 + assert repl.locals["history"] == history1 # Alias stays on first + + def test_history_accessible_via_code(self, mock_llm_query): + """Stored history is accessible via code execution.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + history = [{"role": "user", "content": "Test message"}] + 
repl.add_history(history) + + result = repl.execute_code("msg = history[0]['content']") + assert result.stderr == "" + assert repl.locals["msg"] == "Test message" + + def test_history_is_copy(self, mock_llm_query): + """Stored history is a copy, not a reference.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + history = [{"role": "user", "content": "Original"}] + repl.add_history(history) + + # Modify original + history[0]["content"] = "Modified" + + # Stored copy should be unchanged + assert repl.locals["history_0"][0]["content"] == "Original" + + def test_can_iterate_histories_in_code(self, mock_llm_query): + """Iterating through multiple histories in code.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + repl.add_history([{"role": "user", "content": "Query 1"}]) + repl.add_history([{"role": "user", "content": "Query 2"}]) + repl.add_history([{"role": "user", "content": "Query 3"}]) + + code = """ +all_contents = [ + history_0[0]['content'], + history_1[0]['content'], + history_2[0]['content'], +] +""" + result = repl.execute_code(code) + assert result.stderr == "" + assert repl.locals["all_contents"] == ["Query 1", "Query 2", "Query 3"] + + +class TestLocalREPLPersistentState: + """Tests for state persistence across operations.""" + + def test_variables_persist_with_contexts(self, mock_llm_query): + """Variables and contexts coexist.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + repl.add_context("My context data") + repl.execute_code("summary = context.upper()") + assert repl.locals["summary"] == "MY CONTEXT DATA" + + repl.add_context("New context") + + # Previous variable should still exist + assert repl.locals["summary"] == "MY CONTEXT DATA" + assert repl.locals["context_1"] == "New context" + + def test_variables_persist_with_histories(self, mock_llm_query): + """Variables and histories coexist.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + repl.add_history([{"role": "user", "content": "Hello"}]) + repl.execute_code("extracted = history[0]['content']") + assert repl.locals["extracted"] == "Hello" + + repl.add_history([{"role": "user", "content": "World"}]) + + # Previous variable should still exist + assert repl.locals["extracted"] == "Hello" + assert repl.locals["history_1"][0]["content"] == "World" + + def test_full_persistent_session_simulation(self, mock_llm_query): + """Simulate a multi-turn persistent session.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + # Turn 1: Load first document + repl.add_context("Document: Sales were $1000") + repl.execute_code("sales = 1000") + + # Turn 2: Load second document, use previous variable + repl.add_context("Document: Costs were $400") + result = repl.execute_code("profit = sales - 400") + assert result.stderr == "" + assert repl.locals["profit"] == 600 + + # Turn 3: Store history and reference everything + repl.add_history([ + {"role": "user", "content": "What were the sales?"}, + {"role": "assistant", "content": "Sales were $1000"}, + ]) + + code = """ +summary = f"Sales: {context_0}, Costs: {context_1}, Profit: {profit}" +prev_question = history_0[0]['content'] +""" + result = repl.execute_code(code) + assert result.stderr == "" + assert "Profit: 600" in repl.locals["summary"] + assert repl.locals["prev_question"] == "What were the sales?" 
+ + assert repl.get_context_count() == 2 + assert repl.get_history_count() == 1 + + +class TestNonPersistentBehavior: + """Tests simulating non-persistent RLM behavior.""" + + def test_simulated_non_persistent_completions(self, mock_llm_query): + """Simulate 2 RLM completions to show env resets between calls.""" + # Completion 1 + completion_1_env = LocalREPL(llm_query_fn=mock_llm_query) + completion_1_env.execute_code("important_result = 42") + assert completion_1_env.locals["important_result"] == 42 + completion_1_env.cleanup() + + # Completion 2 - fresh environment + completion_2_env = LocalREPL(llm_query_fn=mock_llm_query) + result = completion_2_env.execute_code("print(important_result)") + + assert "NameError" in result.stderr + assert "important_result" in result.stderr + completion_2_env.cleanup() + + def test_simulated_non_persistent_functions(self, mock_llm_query): + """Simulate 2 RLM completions to show functions don't persist.""" + # Completion 1 + completion_1_env = LocalREPL(llm_query_fn=mock_llm_query) + completion_1_env.execute_code("def my_helper(): return 'useful'") + assert ( + completion_1_env.execute_code("print(my_helper())").stdout.strip() + == "useful" + ) + completion_1_env.cleanup() + + # Completion 2 - fresh environment + completion_2_env = LocalREPL(llm_query_fn=mock_llm_query) + result = completion_2_env.execute_code("my_helper()") + + assert "NameError" in result.stderr + assert "my_helper" in result.stderr + completion_2_env.cleanup() diff --git a/contributing/samples/rlm/tests/test_parallel_batched.py b/contributing/samples/rlm/tests/test_parallel_batched.py new file mode 100644 index 0000000000..8b12d16c2b --- /dev/null +++ b/contributing/samples/rlm/tests/test_parallel_batched.py @@ -0,0 +1,326 @@ +""" +Tests for parallel batched queries (llm_query_batched with recursive=True). + +This module tests the parallel execution of recursive RLM child agents, +including batch metadata propagation and iteration linking. 
+""" + +import time +from typing import Any +from unittest.mock import MagicMock +from unittest.mock import patch + +from adk_rlm.code_executor import RLMCodeExecutor +from adk_rlm.events import RLMEventData +from adk_rlm.events import RLMEventType +import pytest + + +class TestParallelRecursiveBatchedExecution: + """Tests for parallel execution of recursive batched queries.""" + + def test_parallel_recursive_method_exists(self): + """Verify _run_parallel_recursive method exists.""" + executor = RLMCodeExecutor() + assert hasattr(executor, "_run_parallel_recursive") + + def test_run_recursive_rlm_accepts_batch_params(self): + """Verify _run_recursive_rlm accepts batch metadata parameters.""" + import inspect + + executor = RLMCodeExecutor() + sig = inspect.signature(executor._run_recursive_rlm) + params = list(sig.parameters.keys()) + + assert "parallel_batch_id" in params + assert "batch_index" in params + assert "batch_size" in params + + +class TestBatchMetadataEventData: + """Tests for batch metadata in RLMEventData.""" + + def test_event_data_has_batch_fields(self): + """Verify RLMEventData has parallel batch fields.""" + event_data = RLMEventData( + event_type=RLMEventType.ITERATION_START, + parallel_batch_id="test-batch-123", + batch_index=0, + batch_size=3, + ) + + assert event_data.parallel_batch_id == "test-batch-123" + assert event_data.batch_index == 0 + assert event_data.batch_size == 3 + + def test_event_data_to_dict_includes_batch_fields(self): + """Verify to_dict includes batch fields when set.""" + event_data = RLMEventData( + event_type=RLMEventType.ITERATION_START, + iteration=1, + parallel_batch_id="batch-abc", + batch_index=2, + batch_size=5, + ) + + data = event_data.to_dict() + assert data["parallel_batch_id"] == "batch-abc" + assert data["batch_index"] == 2 + assert data["batch_size"] == 5 + + def test_event_data_to_dict_excludes_none_batch_fields(self): + """Verify to_dict excludes batch fields when None.""" + event_data = RLMEventData( + event_type=RLMEventType.ITERATION_START, + iteration=1, + ) + + data = event_data.to_dict() + assert "parallel_batch_id" not in data + assert "batch_index" not in data + assert "batch_size" not in data + + +class TestLoggerBatchTracking: + """Tests for batch tracking in RLMLogger.""" + + def test_logger_accepts_batch_params(self): + """Verify logger.log accepts batch metadata parameters.""" + import inspect + + from adk_rlm.logging.rlm_logger import RLMLogger + + sig = inspect.signature(RLMLogger.log) + params = list(sig.parameters.keys()) + + assert "parent_iteration" in params + assert "parent_block_index" in params + assert "parallel_batch_id" in params + assert "batch_index" in params + assert "batch_size" in params + + def test_logger_writes_batch_metadata(self, temp_log_dir): + """Verify logger writes batch metadata to log file.""" + import json + + from adk_rlm.logging.rlm_logger import RLMLogger + from adk_rlm.types import RLMIteration + + logger = RLMLogger(log_dir=temp_log_dir) + iteration = RLMIteration( + prompt=[{"role": "user", "content": "test"}], + response="test response", + code_blocks=[], + ) + + logger.log( + iteration, + depth=1, + agent_name="rlm_agent_depth_1_0", + parent_agent="rlm_agent", + parent_iteration=2, + parent_block_index=0, + parallel_batch_id="batch-123", + batch_index=1, + batch_size=3, + ) + + with open(logger.get_log_path()) as f: + entry = json.loads(f.readline()) + + assert entry["parent_iteration"] == 2 + assert entry["parent_block_index"] == 0 + assert entry["parallel_batch_id"] == 
"batch-123" + assert entry["batch_index"] == 1 + assert entry["batch_size"] == 3 + + +class TestParallelExecutionBehavior: + """Tests for actual parallel execution behavior.""" + + def test_parallel_batched_preserves_order(self): + """Verify results are returned in original prompt order.""" + executor = RLMCodeExecutor( + current_depth=0, + max_depth=1, + ) + + # Mock _run_recursive_rlm to return predictable results with delays + call_order = [] + original_run = executor._run_recursive_rlm + + def mock_run(prompt, model, context_obj=None, **kwargs): + call_order.append(prompt) + # Simulate varying execution times + if "Q1" in prompt: + time.sleep(0.05) + return f"Result for {prompt}" + + executor._run_recursive_rlm = mock_run + + prompts = ["Q1", "Q2", "Q3"] + results = executor._run_parallel_recursive(prompts, None, "test-model") + + # Results should be in original order + assert results[0] == "Result for Q1" + assert results[1] == "Result for Q2" + assert results[2] == "Result for Q3" + + def test_parallel_batched_generates_batch_id(self): + """Verify parallel execution generates a batch ID.""" + executor = RLMCodeExecutor( + current_depth=0, + max_depth=1, + ) + + batch_ids_seen = [] + + def mock_run( + prompt, model, context_obj=None, parallel_batch_id=None, **kwargs + ): + batch_ids_seen.append(parallel_batch_id) + return f"Result for {prompt}" + + executor._run_recursive_rlm = mock_run + + prompts = ["Q1", "Q2", "Q3"] + executor._run_parallel_recursive(prompts, None, "test-model") + + # All calls should have the same batch ID + assert len(batch_ids_seen) == 3 + assert batch_ids_seen[0] is not None + assert batch_ids_seen[0] == batch_ids_seen[1] == batch_ids_seen[2] + # Batch ID should be a valid UUID format + import uuid + + uuid.UUID(batch_ids_seen[0]) # Will raise if invalid + + def test_parallel_batched_passes_batch_index(self): + """Verify parallel execution passes correct batch indices.""" + executor = RLMCodeExecutor( + current_depth=0, + max_depth=1, + ) + + indices_seen = {} + + def mock_run(prompt, model, context_obj=None, batch_index=None, **kwargs): + indices_seen[prompt] = batch_index + return f"Result for {prompt}" + + executor._run_recursive_rlm = mock_run + + prompts = ["Q1", "Q2", "Q3"] + executor._run_parallel_recursive(prompts, None, "test-model") + + assert indices_seen["Q1"] == 0 + assert indices_seen["Q2"] == 1 + assert indices_seen["Q3"] == 2 + + def test_parallel_batched_passes_batch_size(self): + """Verify parallel execution passes correct batch size.""" + executor = RLMCodeExecutor( + current_depth=0, + max_depth=1, + ) + + sizes_seen = [] + + def mock_run(prompt, model, context_obj=None, batch_size=None, **kwargs): + sizes_seen.append(batch_size) + return f"Result for {prompt}" + + executor._run_recursive_rlm = mock_run + + prompts = ["Q1", "Q2", "Q3", "Q4"] + executor._run_parallel_recursive(prompts, None, "test-model") + + assert all(s == 4 for s in sizes_seen) + + def test_parallel_batched_passes_contexts(self): + """Verify contexts are passed correctly to each child.""" + executor = RLMCodeExecutor( + current_depth=0, + max_depth=1, + ) + + contexts_seen = {} + + def mock_run(prompt, model, context_obj=None, **kwargs): + contexts_seen[prompt] = context_obj + return f"Result for {prompt}" + + executor._run_recursive_rlm = mock_run + + prompts = ["Q1", "Q2", "Q3"] + contexts = ["Context A", "Context B", "Context C"] + executor._run_parallel_recursive(prompts, contexts, "test-model") + + assert contexts_seen["Q1"] == "Context A" + assert contexts_seen["Q2"] 
== "Context B" + assert contexts_seen["Q3"] == "Context C" + + def test_parallel_batched_handles_exceptions(self): + """Verify parallel execution handles individual failures gracefully.""" + executor = RLMCodeExecutor( + current_depth=0, + max_depth=1, + ) + + def mock_run(prompt, model, context_obj=None, **kwargs): + if "Q2" in prompt: + raise ValueError("Simulated failure") + return f"Result for {prompt}" + + executor._run_recursive_rlm = mock_run + + prompts = ["Q1", "Q2", "Q3"] + results = executor._run_parallel_recursive(prompts, None, "test-model") + + assert results[0] == "Result for Q1" + assert "Error" in results[1] + assert "Simulated failure" in results[1] + assert results[2] == "Result for Q3" + + +class TestIterationLinking: + """Tests for linking child agents to spawning iteration.""" + + def test_ancestry_includes_iteration_and_block(self): + """Verify ancestry entry includes iteration and block_index.""" + executor = RLMCodeExecutor( + parent_agent="rlm_agent", + ) + + executor.set_iteration_context(iteration=3, block_index=1) + entry = executor._get_current_ancestry_entry() + + assert entry["agent"] == "rlm_agent" + assert entry["iteration"] == 3 + assert entry["block_index"] == 1 + + def test_child_ancestry_chain_preserved(self): + """Verify child agents receive full ancestry chain.""" + parent_ancestry = [ + {"agent": "rlm_agent", "depth": 0, "iteration": 1, "block_index": 0} + ] + + executor = RLMCodeExecutor( + parent_agent="rlm_agent_depth_1_0", + ancestry=parent_ancestry, + current_depth=1, + ) + + executor.set_iteration_context(iteration=2, block_index=0) + entry = executor._get_current_ancestry_entry() + + # Current entry should reflect current agent's context + assert entry["agent"] == "rlm_agent_depth_1_0" + assert entry["iteration"] == 2 + assert entry["block_index"] == 0 + + # Full ancestry should include parent + current + full_ancestry = executor._ancestry + [entry] + assert len(full_ancestry) == 2 + assert full_ancestry[0]["agent"] == "rlm_agent" + assert full_ancestry[1]["agent"] == "rlm_agent_depth_1_0" diff --git a/contributing/samples/rlm/tests/test_parsing.py b/contributing/samples/rlm/tests/test_parsing.py new file mode 100644 index 0000000000..d327c3526e --- /dev/null +++ b/contributing/samples/rlm/tests/test_parsing.py @@ -0,0 +1,265 @@ +""" +Tests for code block parsing and final answer detection. +""" + +from adk_rlm.callbacks.code_execution import find_code_blocks +from adk_rlm.callbacks.code_execution import find_final_answer +from adk_rlm.callbacks.code_execution import format_execution_result +from adk_rlm.callbacks.code_execution import format_iteration +from adk_rlm.repl.local_repl import LocalREPL +from adk_rlm.types import CodeBlock +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMIteration +import pytest + + +class TestFindCodeBlocks: + """Tests for find_code_blocks function.""" + + def test_find_single_code_block(self): + """One repl block.""" + text = """ +Let me analyze this: +```repl +x = 42 +print(x) +``` +That should work. 
+""" + blocks = find_code_blocks(text) + assert len(blocks) == 1 + assert "x = 42" in blocks[0] + assert "print(x)" in blocks[0] + + def test_find_multiple_code_blocks(self): + """Multiple repl blocks.""" + text = """ +First block: +```repl +x = 1 +``` + +Second block: +```repl +y = 2 +``` +""" + blocks = find_code_blocks(text) + assert len(blocks) == 2 + assert "x = 1" in blocks[0] + assert "y = 2" in blocks[1] + + def test_find_no_code_blocks(self): + """No repl blocks.""" + text = "This is just plain text without any code blocks." + blocks = find_code_blocks(text) + assert blocks == [] + + def test_ignore_other_languages(self): + """Only extracts ```repl blocks.""" + text = """ +```python +x = 1 +``` + +```repl +y = 2 +``` + +```javascript +z = 3 +``` +""" + blocks = find_code_blocks(text) + assert len(blocks) == 1 + assert "y = 2" in blocks[0] + + def test_preserve_indentation(self): + """Code with indentation is preserved.""" + text = """ +```repl +def foo(): + return 42 + +result = foo() +``` +""" + blocks = find_code_blocks(text) + assert len(blocks) == 1 + assert " return 42" in blocks[0] + + def test_preserve_blank_lines(self): + """Code with blank lines.""" + text = """ +```repl +x = 1 + +y = 2 +``` +""" + blocks = find_code_blocks(text) + assert len(blocks) == 1 + assert "\n\n" in blocks[0] or blocks[0].count("\n") >= 2 + + +class TestFindFinalAnswer: + """Tests for find_final_answer function.""" + + def test_find_final_simple(self): + """FINAL(answer).""" + text = "Based on my analysis:\nFINAL(The answer is 42)" + answer = find_final_answer(text) + assert answer == "The answer is 42" + + def test_find_final_multiword(self): + """FINAL with multiple words.""" + text = "FINAL(This is a longer answer with multiple words)" + answer = find_final_answer(text) + assert answer == "This is a longer answer with multiple words" + + def test_find_final_with_quotes(self): + """FINAL with quoted content.""" + text = 'FINAL("quoted answer")' + answer = find_final_answer(text) + assert '"quoted answer"' in answer or "quoted answer" in answer + + def test_no_final_answer(self): + """No FINAL pattern.""" + text = "This is just regular text without any final answer." 
+ answer = find_final_answer(text) + assert answer is None + + def test_final_at_line_start(self): + """FINAL must be at line start.""" + text = "Some text FINAL(not a final)" + answer = find_final_answer(text) + # Should not match if not at line start + assert answer is None + + def test_final_with_leading_whitespace(self): + """FINAL with leading whitespace is allowed.""" + text = "Some intro\n FINAL(The answer)" + answer = find_final_answer(text) + assert answer == "The answer" + + def test_find_final_var(self, mock_llm_query): + """FINAL_VAR resolves variable.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("my_result = 'Computed answer'") + + text = "After computation:\nFINAL_VAR(my_result)" + answer = find_final_answer(text, repl) + assert answer == "Computed answer" + + def test_find_final_var_quoted(self, mock_llm_query): + """FINAL_VAR with quoted variable name.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("my_result = 'Answer here'") + + text = 'FINAL_VAR("my_result")' + answer = find_final_answer(text, repl) + assert answer == "Answer here" + + def test_find_final_var_missing(self, mock_llm_query): + """FINAL_VAR with missing variable returns None (no final answer).""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + text = "FINAL_VAR(nonexistent)" + answer = find_final_answer(text, repl) + assert answer is None + + def test_multiple_finals_returns_first(self): + """Multiple FINAL patterns returns first.""" + text = "FINAL(first)\nFINAL(second)" + answer = find_final_answer(text) + assert answer == "first" + + def test_final_var_takes_precedence(self, mock_llm_query): + """FINAL_VAR is checked before FINAL.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("var = 'from variable'") + + text = "FINAL_VAR(var)\nFINAL(direct)" + answer = find_final_answer(text, repl) + assert answer == "from variable" + + +class TestFormatExecutionResult: + """Tests for format_execution_result function.""" + + def test_format_with_stdout(self): + """Format result with stdout.""" + result = REPLResult( + stdout="Hello, World!", stderr="", locals={"x": 42}, execution_time=0.1 + ) + formatted = format_execution_result(result) + assert "Hello, World!" in formatted + + def test_format_with_stderr(self): + """Format result with stderr.""" + result = REPLResult( + stdout="", stderr="Error occurred", locals={}, execution_time=0.1 + ) + formatted = format_execution_result(result) + assert "Error occurred" in formatted + + def test_format_with_variables(self): + """Format result showing variables.""" + result = REPLResult( + stdout="", stderr="", locals={"x": 42, "y": "hello"}, execution_time=0.1 + ) + formatted = format_execution_result(result) + assert "REPL variables" in formatted or "x" in formatted or "y" in formatted + + def test_format_no_output(self): + """Format result with no output.""" + result = REPLResult(stdout="", stderr="", locals={}, execution_time=0.1) + formatted = format_execution_result(result) + assert formatted # Should return something + + +class TestFormatIteration: + """Tests for format_iteration function.""" + + def test_format_iteration_response(self): + """Format assistant response.""" + iteration = RLMIteration( + prompt="test", response="This is my response.", code_blocks=[] + ) + messages = format_iteration(iteration) + + assert len(messages) == 1 + assert messages[0]["role"] == "assistant" + assert "This is my response." 
in messages[0]["content"] + + def test_format_iteration_with_code(self): + """Format iteration with code execution.""" + result = REPLResult( + stdout="42", stderr="", locals={"x": 42}, execution_time=0.1 + ) + code_block = CodeBlock(code="print(42)", result=result) + iteration = RLMIteration( + prompt="test", response="Let me calculate:", code_blocks=[code_block] + ) + messages = format_iteration(iteration) + + assert len(messages) == 2 + assert messages[0]["role"] == "assistant" + assert messages[1]["role"] == "user" + assert "Code executed" in messages[1]["content"] + + def test_truncate_long_output(self): + """Long output is truncated.""" + long_output = "x" * 30000 + result = REPLResult( + stdout=long_output, stderr="", locals={}, execution_time=0.1 + ) + code_block = CodeBlock(code="print('x' * 30000)", result=result) + iteration = RLMIteration( + prompt="test", response="Computing...", code_blocks=[code_block] + ) + + messages = format_iteration(iteration, max_character_length=1000) + + # Should be truncated + assert len(messages[1]["content"]) < 25000 diff --git a/contributing/samples/rlm/tests/test_prompts.py b/contributing/samples/rlm/tests/test_prompts.py new file mode 100644 index 0000000000..767cb8da9a --- /dev/null +++ b/contributing/samples/rlm/tests/test_prompts.py @@ -0,0 +1,93 @@ +""" +Tests for prompt building. +""" + +from adk_rlm.prompts import build_rlm_system_prompt +from adk_rlm.prompts import build_user_prompt +from adk_rlm.prompts import RLM_SYSTEM_PROMPT +from adk_rlm.types import QueryMetadata +import pytest + + +class TestBuildRLMSystemPrompt: + """Tests for build_rlm_system_prompt.""" + + def test_build_system_prompt(self): + """Build with metadata.""" + metadata = QueryMetadata("test context") + messages = build_rlm_system_prompt(RLM_SYSTEM_PROMPT, metadata) + + assert len(messages) == 2 + assert messages[0]["role"] == "system" + assert messages[1]["role"] == "assistant" + assert "context" in messages[1]["content"].lower() + + def test_includes_context_length(self): + """Includes context length info.""" + metadata = QueryMetadata("x" * 1000) + messages = build_rlm_system_prompt(RLM_SYSTEM_PROMPT, metadata) + + assert "1000" in messages[1]["content"] + + def test_custom_system_prompt(self): + """Uses custom system prompt.""" + custom_prompt = "You are a custom assistant." + metadata = QueryMetadata("test") + messages = build_rlm_system_prompt(custom_prompt, metadata) + + assert messages[0]["content"] == custom_prompt + + +class TestBuildUserPrompt: + """Tests for build_user_prompt.""" + + def test_first_iteration(self): + """First iteration prompt.""" + prompt = build_user_prompt(root_prompt=None, iteration=0) + + assert prompt["role"] == "user" + assert ( + "haven't seen" in prompt["content"].lower() + or "not interacted" in prompt["content"].lower() + ) + + def test_subsequent_iteration(self): + """Later iteration prompt.""" + prompt = build_user_prompt(root_prompt=None, iteration=1) + + assert prompt["role"] == "user" + assert "history" in prompt["content"].lower() + + def test_with_root_prompt(self): + """Include root prompt.""" + prompt = build_user_prompt(root_prompt="What is the answer?", iteration=0) + + assert "What is the answer?" 
in prompt["content"] + + def test_multiple_contexts(self): + """Notes multiple contexts.""" + prompt = build_user_prompt(root_prompt=None, iteration=1, context_count=3) + + assert "3 contexts" in prompt["content"] + assert "context_0" in prompt["content"] + assert "context_2" in prompt["content"] + + def test_with_histories(self): + """Notes prior histories.""" + prompt = build_user_prompt( + root_prompt=None, iteration=1, context_count=1, history_count=2 + ) + + assert "2 prior conversation histories" in prompt["content"] + assert "history_0" in prompt["content"] + + def test_single_history(self): + """Notes single history differently.""" + prompt = build_user_prompt( + root_prompt=None, iteration=1, context_count=1, history_count=1 + ) + + assert "1 prior conversation history" in prompt["content"] + assert ( + "history_0" not in prompt["content"] + ) # Just mentions `history` variable diff --git a/contributing/samples/rlm/tests/test_repl.py b/contributing/samples/rlm/tests/test_repl.py new file mode 100644 index 0000000000..1433f3b667 --- /dev/null +++ b/contributing/samples/rlm/tests/test_repl.py @@ -0,0 +1,368 @@ +""" +Tests for the LocalREPL environment. +""" + +from adk_rlm.repl.local_repl import LocalREPL +import pytest + + +class TestBasicExecution: + """Tests for basic code execution.""" + + def test_basic_execution(self, mock_llm_query): + """Execute simple Python code.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("x = 42") + + assert result.stderr == "" + assert "x" in result.locals + assert result.locals["x"] == 42 + + def test_print_capture(self, mock_llm_query): + """Capture print() output.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("print('Hello, World!')") + + assert "Hello, World!" 
in result.stdout + assert result.stderr == "" + + def test_stderr_capture(self, mock_llm_query): + """Capture exceptions in stderr.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("raise ValueError('test error')") + + assert "ValueError" in result.stderr + assert "test error" in result.stderr + + def test_variable_persistence(self, mock_llm_query): + """Variables persist across executions.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("x = 42") + result = repl.execute_code("y = x * 2\nprint(y)") + + assert result.stdout.strip() == "84" + assert repl.locals["x"] == 42 + assert repl.locals["y"] == 84 + + def test_multiline_code(self, mock_llm_query): + """Execute multiline code blocks.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + code = """ +def factorial(n): + if n <= 1: + return 1 + return n * factorial(n - 1) + +result = factorial(5) +print(result) +""" + result = repl.execute_code(code) + + assert result.stdout.strip() == "120" + assert result.stderr == "" + + def test_execution_timing(self, mock_llm_query): + """Execution time is tracked.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("x = 1 + 1") + + assert result.execution_time > 0 + assert result.execution_time < 1.0 # Should be fast + + def test_syntax_error_handling(self, mock_llm_query): + """Handle syntax errors gracefully.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("if True print('bad')") + + assert "SyntaxError" in result.stderr + + def test_runtime_error_handling(self, mock_llm_query): + """Handle runtime errors gracefully.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("x = 1 / 0") + + assert "ZeroDivisionError" in result.stderr + + +class TestContextLoading: + """Tests for context loading functionality.""" + + def test_context_loading_string(self, mock_llm_query, sample_context): + """Load string context.""" + repl = LocalREPL( + llm_query_fn=mock_llm_query, context_payload=sample_context + ) + result = repl.execute_code("print(type(context).__name__)") + + assert "str" in result.stdout + assert "context" in repl.locals + assert repl.locals["context"] == sample_context + + def test_context_loading_dict(self, mock_llm_query, sample_context_dict): + """Load dict context.""" + repl = LocalREPL( + llm_query_fn=mock_llm_query, context_payload=sample_context_dict + ) + result = repl.execute_code("print(context['title'])") + + assert "Test Document" in result.stdout + + def test_context_loading_list(self, mock_llm_query, sample_context_list): + """Load list context.""" + repl = LocalREPL( + llm_query_fn=mock_llm_query, context_payload=sample_context_list + ) + result = repl.execute_code("print(len(context))") + + assert "3" in result.stdout + + def test_multiple_contexts(self, mock_llm_query): + """Add multiple contexts.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.add_context("First context", 0) + repl.add_context("Second context", 1) + + assert repl.get_context_count() == 2 + assert "context_0" in repl.locals + assert "context_1" in repl.locals + assert "context" in repl.locals + assert repl.locals["context"] == repl.locals["context_0"] + + def test_context_loading_non_serializable_object(self, mock_llm_query): + """Load non-JSON-serializable object (e.g., custom class).""" + + class CustomObject: + + def __init__(self, value): + self.value = value + + def get_value(self): + return self.value + + obj = CustomObject(42) + repl = 
LocalREPL(llm_query_fn=mock_llm_query, context_payload=obj) + + # Object should be accessible in the REPL + assert "context" in repl.locals + assert repl.locals["context"] is obj + + # Should be able to call methods on it + result = repl.execute_code("result = context.get_value()\nprint(result)") + assert "42" in result.stdout + + def test_context_loading_lazy_file_collection(self, mock_llm_query, tmp_path): + """Load LazyFileCollection (the actual use case that was failing).""" + from adk_rlm.files import FileLoader + + # Create test files + (tmp_path / "test1.txt").write_text("Content 1") + (tmp_path / "test2.txt").write_text("Content 2") + + # Create lazy file collection + loader = FileLoader(base_path=tmp_path) + files = loader.create_lazy_files(["*.txt"]) + + # Load into REPL - this was failing before the fix + repl = LocalREPL(llm_query_fn=mock_llm_query, context_payload=files) + + # Collection should be accessible + assert "context" in repl.locals + assert repl.locals["context"] is files + + # Should be able to use collection methods + result = repl.execute_code("print(len(context))") + assert "2" in result.stdout + + result = repl.execute_code("print(context.names)") + assert "test1.txt" in result.stdout or "test2.txt" in result.stdout + + def test_context_loading_dict_with_lazy_files(self, mock_llm_query, tmp_path): + """Load dict containing LazyFileCollection (RLM build_context format).""" + from adk_rlm.files import FileLoader + + # Create test files + (tmp_path / "doc.txt").write_text("Document content") + + # Create context dict like build_context does + loader = FileLoader(base_path=tmp_path) + files = loader.create_lazy_files(["*.txt"]) + context_dict = { + "files": files, + "file_count": len(files), + "file_names": files.names, + } + + # Load into REPL + repl = LocalREPL(llm_query_fn=mock_llm_query, context_payload=context_dict) + + # Context should be accessible + assert "context" in repl.locals + + # Should be able to access the files + result = repl.execute_code("print(context['file_count'])") + assert "1" in result.stdout + + result = repl.execute_code("print(context['files'].names)") + assert "doc.txt" in result.stdout + + +class TestSafeBuiltins: + """Tests for safe builtins.""" + + def test_safe_builtins_available(self, mock_llm_query): + """Allowed builtins work.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + # Test various safe builtins + result = repl.execute_code("print(len([1, 2, 3]))") + assert "3" in result.stdout + + result = repl.execute_code("print(str(42))") + assert "42" in result.stdout + + result = repl.execute_code("print(list(range(3)))") + assert "[0, 1, 2]" in result.stdout + + def test_dangerous_builtins_blocked(self, mock_llm_query): + """Blocked builtins raise errors.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + result = repl.execute_code("eval('1 + 1')") + assert ( + "Error" in result.stderr + or "None" in result.stderr + or "TypeError" in result.stderr + ) + + result = repl.execute_code("exec('x = 1')") + assert ( + "Error" in result.stderr + or "None" in result.stderr + or "TypeError" in result.stderr + ) + + def test_import_allowed_modules(self, mock_llm_query): + """Can import safe modules.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + + result = repl.execute_code("import re\nprint(re.match('a', 'abc'))") + assert result.stderr == "" or "Error" not in result.stderr + + result = repl.execute_code("import json\nprint(json.dumps({'a': 1}))") + assert result.stderr == "" or "Error" not in result.stderr + + +class 
TestLLMQuery: + """Tests for LLM query functionality.""" + + def test_llm_query_basic(self, mock_llm_query): + """Basic llm_query call works.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code( + "response = llm_query('What is 2+2?')\nprint(response)" + ) + + assert "Mock response" in result.stdout + assert result.stderr == "" + + def test_llm_query_batched(self, mock_llm_query, mock_llm_query_batched): + """Batched llm_query works.""" + repl = LocalREPL( + llm_query_fn=mock_llm_query, llm_query_batched_fn=mock_llm_query_batched + ) + result = repl.execute_code( + "responses = llm_query_batched(['Q1?', 'Q2?'])\nprint(len(responses))" + ) + + assert "2" in result.stdout + assert result.stderr == "" + + def test_llm_query_no_function(self): + """Error when no llm_query function configured.""" + repl = LocalREPL() + result = repl.execute_code("response = llm_query('test')\nprint(response)") + + assert "Error" in result.stdout + + +class TestFinalVar: + """Tests for FINAL_VAR functionality.""" + + def test_final_var_string(self, mock_llm_query): + """FINAL_VAR resolves string variable.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("my_answer = 'The answer is 42'") + result = repl.execute_code("print(FINAL_VAR('my_answer'))") + + assert "The answer is 42" in result.stdout + + def test_final_var_missing(self, mock_llm_query): + """FINAL_VAR with missing variable returns error.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + result = repl.execute_code("print(FINAL_VAR('nonexistent'))") + + assert "Error" in result.stdout + + +class TestHistoryManagement: + """Tests for history management.""" + + def test_add_history(self, mock_llm_query): + """Add conversation history.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + history = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + repl.add_history(history) + + assert repl.get_history_count() == 1 + assert "history_0" in repl.locals + assert "history" in repl.locals + + def test_multiple_histories(self, mock_llm_query): + """Add multiple histories.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.add_history([{"role": "user", "content": "First"}]) + repl.add_history([{"role": "user", "content": "Second"}]) + + assert repl.get_history_count() == 2 + assert "history_0" in repl.locals + assert "history_1" in repl.locals + + +class TestCleanup: + """Tests for cleanup functionality.""" + + def test_cleanup(self, mock_llm_query): + """Cleanup removes temp files and resets state.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("x = 42") + repl.cleanup() + + assert len(repl.globals) == 0 + assert len(repl.locals) == 0 + + def test_context_manager(self, mock_llm_query): + """Context manager cleans up on exit.""" + with LocalREPL(llm_query_fn=mock_llm_query) as repl: + repl.execute_code("x = 42") + assert "x" in repl.locals + + # After exit, cleanup should have been called + assert len(repl.locals) == 0 + + def test_reset(self, mock_llm_query): + """Reset clears state but keeps the REPL usable.""" + repl = LocalREPL(llm_query_fn=mock_llm_query) + repl.execute_code("x = 42") + repl.add_context("test context") + repl.reset() + + assert "x" not in repl.locals + assert repl.get_context_count() == 0 + assert repl.get_history_count() == 0 + + # Should still be usable + result = repl.execute_code("y = 100\nprint(y)") + assert "100" in result.stdout diff --git a/contributing/samples/rlm/tests/test_simple_llm_events.py 
b/contributing/samples/rlm/tests/test_simple_llm_events.py new file mode 100644 index 0000000000..b138c9ee45 --- /dev/null +++ b/contributing/samples/rlm/tests/test_simple_llm_events.py @@ -0,0 +1,399 @@ +""" +Tests for non-recursive (simple) LLM call events and logging. + +These tests verify that when llm_query() or llm_query_batched() is called with +recursive=False, proper events are emitted and calls are logged. +""" + +import json +from unittest.mock import MagicMock +from unittest.mock import patch + +from adk_rlm.code_executor import RLMCodeExecutor +from adk_rlm.events import RLMEventType +from adk_rlm.logging.rlm_logger import RLMLogger +import pytest + + +class TestLoggerSimpleLLMCall: + """Tests for RLMLogger.log_simple_llm_call method.""" + + def test_log_simple_llm_call_success(self, temp_log_dir): + """Log a successful simple LLM call.""" + logger = RLMLogger(temp_log_dir) + logger.log_simple_llm_call( + prompt="What is 2+2?", + response="The answer is 4.", + model="gemini-3-flash-preview", + execution_time_ms=150.5, + depth=0, + agent_name="rlm_agent", + ) + + with open(logger.get_log_path()) as f: + entry = json.loads(f.readline()) + + assert entry["type"] == "simple_llm_call" + assert entry["prompt"] == "What is 2+2?" + assert entry["response"] == "The answer is 4." + assert entry["model"] == "gemini-3-flash-preview" + assert entry["execution_time_ms"] == 150.5 + assert entry["depth"] == 0 + assert entry["agent_name"] == "rlm_agent" + assert entry["recursive"] is False + assert entry["success"] is True + assert "error" not in entry + + def test_log_simple_llm_call_failure(self, temp_log_dir): + """Log a failed simple LLM call.""" + logger = RLMLogger(temp_log_dir) + logger.log_simple_llm_call( + prompt="What is 2+2?", + response="Error: LLM query failed - Connection timeout", + model="gemini-3-flash-preview", + execution_time_ms=5000.0, + error="Connection timeout", + ) + + with open(logger.get_log_path()) as f: + entry = json.loads(f.readline()) + + assert entry["type"] == "simple_llm_call" + assert entry["success"] is False + assert entry["error"] == "Connection timeout" + assert "Error:" in entry["response"] + + def test_log_simple_llm_call_with_batch_metadata(self, temp_log_dir): + """Log a simple LLM call with batch metadata.""" + logger = RLMLogger(temp_log_dir) + logger.log_simple_llm_call( + prompt="Query 1", + response="Response 1", + model="gemini-3-flash-preview", + execution_time_ms=100.0, + batch_index=0, + batch_size=3, + ) + + with open(logger.get_log_path()) as f: + entry = json.loads(f.readline()) + + assert entry["batch_index"] == 0 + assert entry["batch_size"] == 3 + + def test_log_simple_llm_call_with_parent_context(self, temp_log_dir): + """Log a simple LLM call with parent iteration context.""" + logger = RLMLogger(temp_log_dir) + logger.log_simple_llm_call( + prompt="Sub query", + response="Sub response", + model="gemini-3-flash-preview", + execution_time_ms=100.0, + parent_iteration=2, + parent_block_index=1, + ) + + with open(logger.get_log_path()) as f: + entry = json.loads(f.readline()) + + assert entry["parent_iteration"] == 2 + assert entry["parent_block_index"] == 1 + + def test_log_simple_llm_call_truncates_long_prompts(self, temp_log_dir): + """Long prompts are truncated in summary but preserved in full.""" + logger = RLMLogger(temp_log_dir) + long_prompt = "x" * 1000 + logger.log_simple_llm_call( + prompt=long_prompt, + response="Short response", + model="test-model", + execution_time_ms=100.0, + ) + + with open(logger.get_log_path()) as f: 
+ entry = json.loads(f.readline()) + + assert len(entry["prompt"]) == 500 + assert len(entry["prompt_full"]) == 1000 + + +class TestCodeExecutorEmitSubLLMEvent: + """Tests for RLMCodeExecutor._emit_sub_llm_event method.""" + + def test_emit_sub_llm_start_event(self): + """Emit SUB_LLM_START event.""" + executor = RLMCodeExecutor( + sub_model="gemini-3-flash-preview", + current_depth=0, + max_depth=5, + parent_agent="rlm_agent", + ) + executor._current_iteration = 2 + executor._current_block_index = 1 + + executor._emit_sub_llm_event( + RLMEventType.SUB_LLM_START, + model="gemini-3-flash-preview", + prompt="Test prompt", + ) + + # Check event was queued + assert not executor._event_queue.empty() + event = executor._event_queue.get() + + metadata = event.custom_metadata + assert metadata["event_type"] == RLMEventType.SUB_LLM_START.value + assert metadata["model"] == "gemini-3-flash-preview" + assert metadata["prompt_preview"] == "Test prompt" + assert metadata["iteration"] == 2 + assert metadata["block_index"] == 1 + assert metadata["agent_name"] == "rlm_agent" + assert metadata["agent_depth"] == 0 + assert metadata["metadata"]["recursive"] is False + + def test_emit_sub_llm_end_event_success(self): + """Emit SUB_LLM_END event on success.""" + executor = RLMCodeExecutor(sub_model="test-model") + + executor._emit_sub_llm_event( + RLMEventType.SUB_LLM_END, + model="test-model", + response="The answer is 42.", + execution_time_ms=150.0, + ) + + event = executor._event_queue.get() + metadata = event.custom_metadata + + assert metadata["event_type"] == RLMEventType.SUB_LLM_END.value + assert metadata["response_preview"] == "The answer is 42." + assert metadata["response_full"] == "The answer is 42." + assert metadata["execution_time_ms"] == 150.0 + assert metadata.get("error") is None + + def test_emit_sub_llm_end_event_failure(self): + """Emit SUB_LLM_END event on failure.""" + executor = RLMCodeExecutor(sub_model="test-model") + + executor._emit_sub_llm_event( + RLMEventType.SUB_LLM_END, + model="test-model", + error="API rate limit exceeded", + execution_time_ms=100.0, + ) + + event = executor._event_queue.get() + metadata = event.custom_metadata + + assert metadata["event_type"] == RLMEventType.SUB_LLM_END.value + assert metadata["error"] == "API rate limit exceeded" + assert metadata.get("response_preview") is None + + def test_emit_sub_llm_event_with_batch_metadata(self): + """Emit SUB_LLM event with batch metadata.""" + executor = RLMCodeExecutor(sub_model="test-model") + + executor._emit_sub_llm_event( + RLMEventType.SUB_LLM_START, + model="test-model", + prompt="Batch query", + batch_index=2, + batch_size=5, + ) + + event = executor._event_queue.get() + metadata = event.custom_metadata + + assert metadata["batch_index"] == 2 + assert metadata["batch_size"] == 5 + + +class TestCodeExecutorSimpleLLMCall: + """Tests for RLMCodeExecutor._simple_llm_call method.""" + + def test_simple_llm_call_emits_events(self): + """Simple LLM call emits START and END events.""" + executor = RLMCodeExecutor(sub_model="test-model") + + # Mock the genai client (fresh client is created inside _simple_llm_call) + mock_response = MagicMock() + mock_response.text = "Mocked response" + mock_response.usage_metadata = MagicMock() + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 20 + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + + with patch("adk_rlm.code_executor.genai.Client", return_value=mock_client): 
+ result = executor._simple_llm_call("Test prompt", "test-model") + + assert result == "Mocked response" + + # Should have 2 events: START and END + events = [] + while not executor._event_queue.empty(): + events.append(executor._event_queue.get()) + + assert len(events) == 2 + + start_event = events[0] + end_event = events[1] + + assert ( + start_event.custom_metadata["event_type"] + == RLMEventType.SUB_LLM_START.value + ) + assert ( + end_event.custom_metadata["event_type"] + == RLMEventType.SUB_LLM_END.value + ) + assert end_event.custom_metadata["response_full"] == "Mocked response" + assert end_event.custom_metadata.get("error") is None + + def test_simple_llm_call_emits_events_on_error(self): + """Simple LLM call emits events even when it fails.""" + executor = RLMCodeExecutor(sub_model="test-model") + + mock_client = MagicMock() + mock_client.models.generate_content.side_effect = Exception("API error") + + with patch("adk_rlm.code_executor.genai.Client", return_value=mock_client): + result = executor._simple_llm_call("Test prompt", "test-model") + + assert "Error: LLM query failed" in result + assert "API error" in result + + events = [] + while not executor._event_queue.empty(): + events.append(executor._event_queue.get()) + + assert len(events) == 2 + + end_event = events[1] + assert ( + end_event.custom_metadata["event_type"] + == RLMEventType.SUB_LLM_END.value + ) + assert end_event.custom_metadata["error"] == "API error" + + def test_simple_llm_call_logs_to_jsonl(self, temp_log_dir): + """Simple LLM call logs to JSONL logger.""" + logger = RLMLogger(temp_log_dir) + executor = RLMCodeExecutor( + sub_model="test-model", + logger=logger, + parent_agent="rlm_agent", + ) + executor._current_iteration = 1 + executor._current_block_index = 0 + + mock_response = MagicMock() + mock_response.text = "Logged response" + mock_response.usage_metadata = MagicMock() + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 20 + + mock_client = MagicMock() + mock_client.models.generate_content.return_value = mock_response + + with patch("adk_rlm.code_executor.genai.Client", return_value=mock_client): + executor._simple_llm_call("Logged prompt", "test-model") + + with open(logger.get_log_path()) as f: + entry = json.loads(f.readline()) + + assert entry["type"] == "simple_llm_call" + assert entry["prompt_full"] == "Logged prompt" + assert entry["response_full"] == "Logged response" + assert entry["agent_name"] == "rlm_agent" + assert entry["parent_iteration"] == 1 + assert entry["success"] is True + + +class TestCodeExecutorBatchedNonRecursive: + """Tests for llm_query_batched with recursive=False.""" + + def test_batched_non_recursive_emits_events(self): + """Batched non-recursive calls emit events for each query.""" + executor = RLMCodeExecutor(sub_model="test-model") + + mock_response = MagicMock() + mock_response.text = "Batch response" + mock_response.usage_metadata = MagicMock() + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 20 + + # Mock the fresh client created inside run_all() + mock_client = MagicMock() + mock_client.aio.models.generate_content.return_value = mock_response + + llm_query_batched = executor._create_llm_query_batched_fn() + + with patch("adk_rlm.code_executor.genai.Client", return_value=mock_client): + results = llm_query_batched( + ["Query 1", "Query 2", "Query 3"], + recursive=False, + ) + + assert len(results) == 3 + + # Collect all events + events = [] + 
while not executor._event_queue.empty(): + events.append(executor._event_queue.get()) + + # Should have 2 events per query (START + END) = 6 events + assert len(events) == 6 + + start_events = [ + e + for e in events + if e.custom_metadata["event_type"] == RLMEventType.SUB_LLM_START.value + ] + end_events = [ + e + for e in events + if e.custom_metadata["event_type"] == RLMEventType.SUB_LLM_END.value + ] + + assert len(start_events) == 3 + assert len(end_events) == 3 + + # Check batch metadata + for event in events: + assert event.custom_metadata["batch_size"] == 3 + assert event.custom_metadata["batch_index"] in [0, 1, 2] + + def test_batched_non_recursive_logs_all_calls(self, temp_log_dir): + """Batched non-recursive calls log all queries.""" + logger = RLMLogger(temp_log_dir) + executor = RLMCodeExecutor( + sub_model="test-model", + logger=logger, + ) + + mock_response = MagicMock() + mock_response.text = "Batch response" + mock_response.usage_metadata = MagicMock() + mock_response.usage_metadata.prompt_token_count = 10 + mock_response.usage_metadata.candidates_token_count = 20 + + # Mock the fresh client created inside run_all() + mock_client = MagicMock() + mock_client.aio.models.generate_content.return_value = mock_response + + llm_query_batched = executor._create_llm_query_batched_fn() + + with patch("adk_rlm.code_executor.genai.Client", return_value=mock_client): + llm_query_batched(["Q1", "Q2"], recursive=False) + + with open(logger.get_log_path()) as f: + entries = [json.loads(line) for line in f] + + assert len(entries) == 2 + assert all(e["type"] == "simple_llm_call" for e in entries) + assert all(e["batch_size"] == 2 for e in entries) + + batch_indices = {e["batch_index"] for e in entries} + assert batch_indices == {0, 1} diff --git a/contributing/samples/rlm/tests/test_types.py b/contributing/samples/rlm/tests/test_types.py new file mode 100644 index 0000000000..1b592e2e48 --- /dev/null +++ b/contributing/samples/rlm/tests/test_types.py @@ -0,0 +1,275 @@ +""" +Tests for data types and serialization. 
+""" + +import json + +from adk_rlm.types import CodeBlock +from adk_rlm.types import ModelUsageSummary +from adk_rlm.types import QueryMetadata +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMChatCompletion +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata +from adk_rlm.types import UsageSummary +import pytest + + +class TestModelUsageSummary: + """Tests for ModelUsageSummary.""" + + def test_to_dict(self): + """Serialize to dict.""" + usage = ModelUsageSummary( + total_calls=5, total_input_tokens=1000, total_output_tokens=500 + ) + d = usage.to_dict() + + assert d["total_calls"] == 5 + assert d["total_input_tokens"] == 1000 + assert d["total_output_tokens"] == 500 + + def test_from_dict(self): + """Deserialize from dict.""" + d = { + "total_calls": 5, + "total_input_tokens": 1000, + "total_output_tokens": 500, + } + usage = ModelUsageSummary.from_dict(d) + + assert usage.total_calls == 5 + assert usage.total_input_tokens == 1000 + assert usage.total_output_tokens == 500 + + def test_round_trip(self): + """Serialize then deserialize.""" + original = ModelUsageSummary( + total_calls=10, total_input_tokens=2000, total_output_tokens=1000 + ) + restored = ModelUsageSummary.from_dict(original.to_dict()) + + assert restored.total_calls == original.total_calls + assert restored.total_input_tokens == original.total_input_tokens + assert restored.total_output_tokens == original.total_output_tokens + + +class TestUsageSummary: + """Tests for UsageSummary.""" + + def test_to_dict(self): + """Serialize to dict.""" + usage = UsageSummary( + model_usage_summaries={ + "gemini-pro": ModelUsageSummary( + total_calls=3, total_input_tokens=500, total_output_tokens=200 + ), + "gemini-flash": ModelUsageSummary( + total_calls=10, total_input_tokens=1000, total_output_tokens=500 + ), + } + ) + d = usage.to_dict() + + assert "gemini-pro" in d["model_usage_summaries"] + assert "gemini-flash" in d["model_usage_summaries"] + + def test_total_properties(self): + """Test total properties.""" + usage = UsageSummary( + model_usage_summaries={ + "model1": ModelUsageSummary( + total_calls=3, total_input_tokens=500, total_output_tokens=200 + ), + "model2": ModelUsageSummary( + total_calls=7, total_input_tokens=500, total_output_tokens=300 + ), + } + ) + + assert usage.total_calls == 10 + assert usage.total_input_tokens == 1000 + assert usage.total_output_tokens == 500 + + +class TestREPLResult: + """Tests for REPLResult.""" + + def test_to_dict(self): + """Serialize to dict.""" + result = REPLResult( + stdout="Hello", + stderr="", + locals={"x": 42, "y": "test"}, + execution_time=0.5, + rlm_calls=[], + ) + d = result.to_dict() + + assert d["stdout"] == "Hello" + assert d["stderr"] == "" + assert d["execution_time"] == 0.5 + assert "x" in d["locals"] + + def test_serialize_complex_locals(self): + """Locals with complex types are serialized safely.""" + import re + + result = REPLResult( + stdout="", + stderr="", + locals={ + "x": 42, + "func": lambda: None, + "module": re, + "nested": {"a": [1, 2, 3]}, + }, + execution_time=0.1, + ) + d = result.to_dict() + + # Should not raise, should convert to string representations + json_str = json.dumps(d) + assert json_str # Valid JSON + + def test_str_representation(self): + """String representation.""" + result = REPLResult( + stdout="output", + stderr="error", + locals={"x": 1}, + execution_time=0.123, + rlm_calls=[], + ) + s = str(result) + + assert "REPLResult" in s + assert "0.123" in s + + +class TestCodeBlock: + """Tests for 
CodeBlock.""" + + def test_to_dict(self): + """Serialize to dict.""" + result = REPLResult( + stdout="42", stderr="", locals={"x": 42}, execution_time=0.1 + ) + block = CodeBlock(code="x = 42\nprint(x)", result=result) + d = block.to_dict() + + assert d["code"] == "x = 42\nprint(x)" + assert "result" in d + assert d["result"]["stdout"] == "42" + + +class TestRLMIteration: + """Tests for RLMIteration.""" + + def test_to_dict(self): + """Serialize to dict.""" + result = REPLResult(stdout="", stderr="", locals={}, execution_time=0.1) + iteration = RLMIteration( + prompt="test prompt", + response="test response", + code_blocks=[CodeBlock(code="x = 1", result=result)], + final_answer="final", + iteration_time=1.5, + ) + d = iteration.to_dict() + + assert d["prompt"] == "test prompt" + assert d["response"] == "test response" + assert len(d["code_blocks"]) == 1 + assert d["final_answer"] == "final" + assert d["iteration_time"] == 1.5 + + def test_from_dict(self): + """Deserialize from dict.""" + d = { + "prompt": "prompt", + "response": "response", + "code_blocks": [], + "final_answer": None, + "iteration_time": 2.0, + } + iteration = RLMIteration.from_dict(d) + + assert iteration.prompt == "prompt" + assert iteration.response == "response" + assert iteration.iteration_time == 2.0 + + +class TestRLMMetadata: + """Tests for RLMMetadata.""" + + def test_to_dict(self): + """Serialize to dict.""" + metadata = RLMMetadata( + root_model="gemini-pro", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={"model_name": "gemini-pro"}, + environment_type="local", + environment_kwargs={}, + other_backends=["gemini-flash"], + ) + d = metadata.to_dict() + + assert d["root_model"] == "gemini-pro" + assert d["max_iterations"] == 30 + assert d["backend"] == "gemini" + assert d["other_backends"] == ["gemini-flash"] + + +class TestRLMChatCompletion: + """Tests for RLMChatCompletion.""" + + def test_to_dict(self): + """Serialize to dict.""" + completion = RLMChatCompletion( + root_model="gemini-pro", + prompt="test prompt", + response="test response", + usage_summary=UsageSummary(), + execution_time=5.0, + ) + d = completion.to_dict() + + assert d["root_model"] == "gemini-pro" + assert d["response"] == "test response" + assert d["execution_time"] == 5.0 + + +class TestQueryMetadata: + """Tests for QueryMetadata.""" + + def test_string_context(self): + """Metadata for string context.""" + meta = QueryMetadata("Hello, World!") + + assert meta.context_type == "str" + assert meta.context_total_length == 13 + assert meta.context_lengths == [13] + + def test_dict_context(self): + """Metadata for dict context.""" + meta = QueryMetadata({"key1": "value1", "key2": "value2"}) + + assert meta.context_type == "dict" + assert len(meta.context_lengths) == 2 + + def test_list_context(self): + """Metadata for list context.""" + meta = QueryMetadata(["chunk1", "chunk2", "chunk3"]) + + assert meta.context_type == "list" + assert len(meta.context_lengths) == 3 + + def test_empty_list(self): + """Metadata for empty list.""" + meta = QueryMetadata([]) + + assert meta.context_type == "list" + assert meta.context_total_length == 0 diff --git a/contributing/samples/rlm/tests/test_usage.py b/contributing/samples/rlm/tests/test_usage.py new file mode 100644 index 0000000000..59ecbb98e4 --- /dev/null +++ b/contributing/samples/rlm/tests/test_usage.py @@ -0,0 +1,71 @@ +""" +Tests for usage tracking. 
+""" + +from adk_rlm.usage import UsageTracker +import pytest + + +class TestUsageTracker: + """Tests for UsageTracker.""" + + def test_track_single_call(self): + """Track single call.""" + tracker = UsageTracker() + tracker.add("gemini-pro", input_tokens=100, output_tokens=50) + + assert tracker.total_calls == 1 + assert tracker.total_input_tokens == 100 + assert tracker.total_output_tokens == 50 + + def test_track_multiple_calls_same_model(self): + """Multiple calls to same model aggregate.""" + tracker = UsageTracker() + tracker.add("gemini-pro", input_tokens=100, output_tokens=50) + tracker.add("gemini-pro", input_tokens=200, output_tokens=100) + + assert tracker.total_calls == 2 + assert tracker.total_input_tokens == 300 + assert tracker.total_output_tokens == 150 + + def test_track_multiple_models(self): + """Calls to different models tracked separately.""" + tracker = UsageTracker() + tracker.add("gemini-pro", input_tokens=100, output_tokens=50) + tracker.add("gemini-flash", input_tokens=200, output_tokens=100) + + summary = tracker.get_summary() + + assert len(summary.model_usage_summaries) == 2 + assert summary.model_usage_summaries["gemini-pro"].total_calls == 1 + assert summary.model_usage_summaries["gemini-flash"].total_calls == 1 + + def test_get_summary(self): + """Get usage summary.""" + tracker = UsageTracker() + tracker.add("model1", input_tokens=100, output_tokens=50) + tracker.add("model2", input_tokens=200, output_tokens=100) + + summary = tracker.get_summary() + + assert "model1" in summary.model_usage_summaries + assert "model2" in summary.model_usage_summaries + assert summary.total_calls == 2 + + def test_reset(self): + """Reset clears all tracking.""" + tracker = UsageTracker() + tracker.add("model", input_tokens=100, output_tokens=50) + tracker.reset() + + assert tracker.total_calls == 0 + assert tracker.total_input_tokens == 0 + assert tracker.total_output_tokens == 0 + + def test_zero_usage(self): + """No calls returns zeros.""" + tracker = UsageTracker() + summary = tracker.get_summary() + + assert len(summary.model_usage_summaries) == 0 + assert tracker.total_calls == 0 diff --git a/contributing/samples/rlm/tests/test_visualizer_compat.py b/contributing/samples/rlm/tests/test_visualizer_compat.py new file mode 100644 index 0000000000..dfbb98611b --- /dev/null +++ b/contributing/samples/rlm/tests/test_visualizer_compat.py @@ -0,0 +1,332 @@ +""" +Visualizer compatibility tests. + +These tests verify that ADK-RLM JSONL output is compatible with the +original RLM visualizer by validating the log schema. 
+""" + +import json +from pathlib import Path + +from adk_rlm.logging.rlm_logger import RLMLogger +from adk_rlm.types import CodeBlock +from adk_rlm.types import ModelUsageSummary +from adk_rlm.types import REPLResult +from adk_rlm.types import RLMChatCompletion +from adk_rlm.types import RLMIteration +from adk_rlm.types import RLMMetadata +from adk_rlm.types import UsageSummary +import pytest + + +class TestVisualizerSchemaCompatibility: + """Tests that verify JSONL output matches visualizer expectations.""" + + def test_metadata_schema(self, temp_log_dir): + """Verify metadata entry matches expected schema.""" + logger = RLMLogger(temp_log_dir) + metadata = RLMMetadata( + root_model="gemini-3-pro-preview", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={"model_name": "gemini-3-pro-preview"}, + environment_type="local", + environment_kwargs={}, + other_backends=["gemini-3-flash-preview"], + ) + logger.log_metadata(metadata) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + # Required fields for visualizer + assert entry["type"] == "metadata" + assert "timestamp" in entry + assert "root_model" in entry + assert "max_depth" in entry + assert "max_iterations" in entry + assert "backend" in entry + assert "backend_kwargs" in entry + assert "environment_type" in entry + assert "environment_kwargs" in entry + assert "other_backends" in entry + + def test_iteration_schema(self, temp_log_dir): + """Verify iteration entry matches expected schema.""" + logger = RLMLogger(temp_log_dir) + + # Create an iteration with code blocks and sub-calls + sub_call = RLMChatCompletion( + root_model="gemini-3-flash-preview", + prompt="What is 2+2?", + response="4", + usage_summary=UsageSummary( + model_usage_summaries={ + "gemini-3-flash-preview": ModelUsageSummary( + total_calls=1, total_input_tokens=10, total_output_tokens=5 + ) + } + ), + execution_time=0.5, + ) + + result = REPLResult( + stdout="Output here", + stderr="", + locals={"x": 42, "y": "test"}, + execution_time=0.1, + rlm_calls=[sub_call], + ) + + code_block = CodeBlock(code="x = 42\nprint(x)", result=result) + + iteration = RLMIteration( + prompt=[{"role": "user", "content": "test"}], + response="Let me calculate...\n```repl\nx = 42\nprint(x)\n```", + code_blocks=[code_block], + final_answer=None, + iteration_time=1.5, + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + # Required fields for visualizer iteration + assert entry["type"] == "iteration" + assert "iteration" in entry + assert "timestamp" in entry + assert "prompt" in entry + assert "response" in entry + assert "code_blocks" in entry + assert "final_answer" in entry + assert "iteration_time" in entry + + # Check code block structure + assert len(entry["code_blocks"]) == 1 + cb = entry["code_blocks"][0] + assert "code" in cb + assert "result" in cb + + # Check result structure + result_entry = cb["result"] + assert "stdout" in result_entry + assert "stderr" in result_entry + assert "locals" in result_entry + assert "execution_time" in result_entry + assert "rlm_calls" in result_entry + + # Check rlm_calls (sub-calls) structure + assert len(result_entry["rlm_calls"]) == 1 + call = result_entry["rlm_calls"][0] + assert "root_model" in call + assert "prompt" in call + assert "response" in call + assert "usage_summary" in call + assert "execution_time" in call + + def test_full_log_file_structure(self, temp_log_dir): + """Test complete log file with metadata and multiple 
iterations.""" + logger = RLMLogger(temp_log_dir) + + # Log metadata + metadata = RLMMetadata( + root_model="gemini-3-pro-preview", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={"model_name": "gemini-3-pro-preview"}, + environment_type="local", + environment_kwargs={}, + ) + logger.log_metadata(metadata) + + # Log multiple iterations + for i in range(3): + result = REPLResult( + stdout=f"Output {i}", + stderr="", + locals={}, + execution_time=0.1, + ) + iteration = RLMIteration( + prompt=f"Iteration {i}", + response=f"Response {i}", + code_blocks=[CodeBlock(code=f"print({i})", result=result)], + final_answer="Final" if i == 2 else None, + iteration_time=0.5, + ) + logger.log(iteration) + + # Verify structure + with open(logger.log_file_path) as f: + lines = f.readlines() + + assert len(lines) == 4 # 1 metadata + 3 iterations + + entries = [json.loads(line) for line in lines] + assert entries[0]["type"] == "metadata" + for i, entry in enumerate(entries[1:], 1): + assert entry["type"] == "iteration" + assert entry["iteration"] == i + + def test_field_naming_consistency(self): + """Verify we use the same field names as original RLM.""" + # REPLResult must use 'rlm_calls' not 'llm_calls' + result = REPLResult( + stdout="", + stderr="", + locals={}, + execution_time=0.1, + rlm_calls=[], + ) + d = result.to_dict() + assert "rlm_calls" in d + assert "llm_calls" not in d + + # RLMIteration must have specific fields + iteration = RLMIteration( + prompt="test", + response="response", + code_blocks=[], + final_answer=None, + iteration_time=1.0, + ) + d = iteration.to_dict() + assert "prompt" in d + assert "response" in d + assert "code_blocks" in d + assert "final_answer" in d + assert "iteration_time" in d + + # RLMMetadata must have specific fields + metadata = RLMMetadata( + root_model="model", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={}, + environment_type="local", + environment_kwargs={}, + ) + d = metadata.to_dict() + assert "root_model" in d + assert "max_depth" in d + assert "max_iterations" in d + assert "backend" in d + assert "backend_kwargs" in d + assert "environment_type" in d + assert "environment_kwargs" in d + + def test_json_serializable(self, temp_log_dir): + """All log entries must be valid JSON.""" + logger = RLMLogger(temp_log_dir) + + # Create complex structures + metadata = RLMMetadata( + root_model="gemini-pro", + max_depth=1, + max_iterations=30, + backend="gemini", + backend_kwargs={"nested": {"key": "value"}}, + environment_type="local", + environment_kwargs={"list": [1, 2, 3]}, + ) + logger.log_metadata(metadata) + + # Log with complex locals + result = REPLResult( + stdout="test", + stderr="", + locals={ + "list_var": [1, 2, 3], + "dict_var": {"a": 1, "b": {"c": 2}}, + "tuple_var": (1, 2), + }, + execution_time=0.1, + ) + iteration = RLMIteration( + prompt="test", + response="response", + code_blocks=[CodeBlock(code="pass", result=result)], + ) + logger.log(iteration) + + # Verify all entries are valid JSON + with open(logger.log_file_path) as f: + for line in f: + entry = json.loads(line) + # Re-serialize to ensure no issues + json.dumps(entry) + + +class TestVisualizerFieldTypes: + """Tests for correct field types in log entries.""" + + def test_timestamp_is_iso_format(self, temp_log_dir): + """Timestamps should be ISO format strings.""" + from datetime import datetime + + logger = RLMLogger(temp_log_dir) + logger.log_metadata( + RLMMetadata( + root_model="model", + max_depth=1, + max_iterations=30, + 
backend="gemini", + backend_kwargs={}, + environment_type="local", + environment_kwargs={}, + ) + ) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + # Should be parseable as ISO timestamp + timestamp = entry["timestamp"] + datetime.fromisoformat(timestamp) + + def test_iteration_number_is_integer(self, temp_log_dir): + """Iteration number should be an integer.""" + logger = RLMLogger(temp_log_dir) + logger.log( + RLMIteration( + prompt="test", + response="response", + code_blocks=[], + ) + ) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert isinstance(entry["iteration"], int) + assert entry["iteration"] == 1 + + def test_execution_times_are_floats(self, temp_log_dir): + """Execution times should be floats.""" + logger = RLMLogger(temp_log_dir) + + result = REPLResult( + stdout="", + stderr="", + locals={}, + execution_time=0.123456, + ) + iteration = RLMIteration( + prompt="test", + response="response", + code_blocks=[CodeBlock(code="pass", result=result)], + iteration_time=1.5, + ) + logger.log(iteration) + + with open(logger.log_file_path) as f: + entry = json.loads(f.readline()) + + assert isinstance(entry["iteration_time"], float) + assert isinstance( + entry["code_blocks"][0]["result"]["execution_time"], float + ) diff --git a/contributing/samples/rlm/tests/ui/__init__.py b/contributing/samples/rlm/tests/ui/__init__.py new file mode 100644 index 0000000000..207ea3bc73 --- /dev/null +++ b/contributing/samples/rlm/tests/ui/__init__.py @@ -0,0 +1 @@ +# UI tests using Playwright diff --git a/contributing/samples/rlm/tests/ui/conftest.py b/contributing/samples/rlm/tests/ui/conftest.py new file mode 100644 index 0000000000..0856a989f4 --- /dev/null +++ b/contributing/samples/rlm/tests/ui/conftest.py @@ -0,0 +1,472 @@ +""" +Fixtures for UI tests with mocked WebSocket responses. + +These tests use Playwright to test the frontend UI components +with a mock WebSocket server that simulates backend responses. 
+""" + +import asyncio +import json +import time +from typing import Any +from unittest.mock import AsyncMock + +from playwright.sync_api import Page +from playwright.sync_api import Route +from playwright.sync_api import WebSocket +import pytest + + +# Sample event data for mocking +def create_mock_event( + event_type: str, + iteration: int = 1, + event_id: int = 0, + **metadata: Any, +) -> dict: + """Create a mock event matching the format from web.py.""" + icons = { + "rlm.run.start": "play_arrow", + "rlm.run.end": "stop", + "rlm.iteration.start": "loop", + "rlm.iteration.end": "check_circle", + "rlm.llm.start": "psychology", + "rlm.llm.end": "psychology", + "rlm.code.found": "code", + "rlm.code.start": "terminal", + "rlm.code.end": "terminal", + "rlm.final.detected": "star", + "rlm.final.answer": "check", + } + colors = { + "rlm.run.start": "#7AA2F7", + "rlm.run.end": "#9ECE6A", + "rlm.iteration.start": "#7AA2F7", + "rlm.iteration.end": "#565F89", + "rlm.llm.start": "#BB9AF7", + "rlm.llm.end": "#BB9AF7", + "rlm.code.found": "#9ECE6A", + "rlm.code.start": "#7DCFFF", + "rlm.code.end": "#7DCFFF", + "rlm.final.detected": "#E0AF68", + "rlm.final.answer": "#E0AF68", + } + label = event_type.replace("rlm.", "").replace(".", " ").title() + + return { + "id": event_id, + "type": "event", + "event_type": event_type, + "iteration": iteration, + "timestamp": 0.1 * (event_id + 1), + "icon": icons.get(event_type, "circle"), + "color": colors.get(event_type, "#A9B1D6"), + "label": label, + "metadata": metadata, + } + + +def create_mock_session( + session_id: str = "test-session-123", + title: str = "Test Session", + model: str = "gemini-3-pro-preview", + conversation: list | None = None, + events: list | None = None, + files: list | None = None, +) -> dict: + """Create a mock session response.""" + return { + "type": "status_response", + "session_id": session_id, + "title": title, + "model": model, + "sub_model": model, + "max_iterations": 30, + "files": files or [], + "conversation": conversation or [], + "events": events or [], + } + + +def create_mock_sessions_list(sessions: list[dict] | None = None) -> dict: + """Create a mock sessions list response.""" + if sessions is None: + sessions = [ + { + "session_id": "session-1", + "title": "First Session", + "updated_at": "2024-01-15T10:00:00", + "message_count": 2, + }, + { + "session_id": "session-2", + "title": "Second Session", + "updated_at": "2024-01-14T09:00:00", + "message_count": 5, + }, + ] + return { + "type": "sessions_list", + "sessions": sessions, + } + + +class MockWebSocketServer: + """Mock WebSocket server for UI testing.""" + + def __init__(self): + self.messages_received: list[dict] = [] + self.messages_to_send: list[dict] = [] + self.auto_responses: dict[str, list[dict]] = {} + self.connected = False + self._ws: WebSocket | None = None + + def queue_message(self, message: dict): + """Queue a message to be sent to the client.""" + self.messages_to_send.append(message) + + def queue_messages(self, messages: list[dict]): + """Queue multiple messages.""" + self.messages_to_send.extend(messages) + + def set_auto_response(self, action: str, responses: list[dict]): + """Set automatic responses for a given action.""" + self.auto_responses[action] = responses + + def get_received_messages(self) -> list[dict]: + """Get all messages received from the client.""" + return self.messages_received.copy() + + def clear(self): + """Clear all queued messages and received messages.""" + self.messages_received.clear() + self.messages_to_send.clear() + + 
+class WebSocketInterceptor: + """Intercept and mock WebSocket connections in Playwright.""" + + # JavaScript code for WebSocket mock - defined once as class attribute + MOCK_WS_SCRIPT = """ + // Store original WebSocket + window._OriginalWebSocket = window.WebSocket; + window._mockWsMessages = []; + window._mockWsReceived = []; + window._mockWsConnected = false; + window._mockWsAutoResponses = {}; + window._mockWs = null; + + // Create mock WebSocket class + class MockWebSocket { + constructor(url) { + this.url = url; + this.readyState = 0; // CONNECTING + this.onopen = null; + this.onclose = null; + this.onerror = null; + this.onmessage = null; + window._mockWs = this; + + // Auto-connect after a small delay + setTimeout(() => { + this.readyState = 1; // OPEN + window._mockWsConnected = true; + if (this.onopen) { + this.onopen({ type: 'open' }); + } + // Send any queued messages + this._processQueue(); + }, 50); + } + + send(data) { + const parsed = JSON.parse(data); + window._mockWsReceived.push(parsed); + + // Check for auto-responses + const action = parsed.action; + if (window._mockWsAutoResponses[action]) { + const responses = window._mockWsAutoResponses[action]; + responses.forEach((resp, i) => { + setTimeout(() => { + if (this.onmessage) { + this.onmessage({ data: JSON.stringify(resp) }); + } + }, 10 * (i + 1)); + }); + } + } + + close() { + this.readyState = 3; // CLOSED + window._mockWsConnected = false; + if (this.onclose) { + this.onclose({ type: 'close' }); + } + } + + _processQueue() { + while (window._mockWsMessages.length > 0) { + const msg = window._mockWsMessages.shift(); + if (this.onmessage) { + this.onmessage({ data: JSON.stringify(msg) }); + } + } + } + + _receiveMessage(data) { + if (this.onmessage) { + this.onmessage({ data: JSON.stringify(data) }); + } + } + } + + MockWebSocket.CONNECTING = 0; + MockWebSocket.OPEN = 1; + MockWebSocket.CLOSING = 2; + MockWebSocket.CLOSED = 3; + + window.WebSocket = MockWebSocket; + """ + + def __init__(self, page: Page): + self.page = page + self.mock_server = MockWebSocketServer() + self._setup_complete = False + self._pending_auto_responses: dict[str, list[dict]] = {} + + def setup(self): + """Set up WebSocket interception via add_init_script (runs before page load).""" + if self._setup_complete: + return + + # Use add_init_script so the mock is injected BEFORE page JavaScript runs + self.page.add_init_script(self.MOCK_WS_SCRIPT) + self._setup_complete = True + + def set_auto_response(self, action: str, responses: list[dict]): + """ + Set automatic responses for a given action. + + IMPORTANT: Call this BEFORE page.goto() for responses needed during + initial connection (get_status, list_sessions). 
+ """ + self._pending_auto_responses[action] = responses + # Add an init script to set this auto-response before page JS runs + script = ( + f"window._mockWsAutoResponses[{json.dumps(action)}] =" + f" {json.dumps(responses)};" + ) + self.page.add_init_script(script) + + def set_auto_response_after_load(self, action: str, responses: list[dict]): + """Set automatic responses after page has loaded.""" + self.page.evaluate( + "(data) => { window._mockWsAutoResponses[data.action] =" + " data.responses; }", + {"action": action, "responses": responses}, + ) + + def send_message(self, message: dict): + """Send a message from the mock server to the client.""" + self.page.evaluate( + """(msg) => { + if (window._mockWs && window._mockWs.readyState === 1) { + window._mockWs._receiveMessage(msg); + } else { + window._mockWsMessages.push(msg); + } + }""", + message, + ) + + def send_messages(self, messages: list[dict], delay_ms: int = 10): + """Send multiple messages with delay between them.""" + for i, msg in enumerate(messages): + self.page.evaluate( + f"""(msg) => {{ + setTimeout(() => {{ + if (window._mockWs && window._mockWs.readyState === 1) {{ + window._mockWs._receiveMessage(msg); + }} + }}, {i * delay_ms}); + }}""", + msg, + ) + + def get_received_messages(self) -> list[dict]: + """Get all messages received by the mock server.""" + return self.page.evaluate("() => window._mockWsReceived || []") + + def is_connected(self) -> bool: + """Check if WebSocket is connected.""" + return self.page.evaluate("() => window._mockWsConnected || false") + + def wait_for_connection(self, timeout: int = 5000): + """Wait for WebSocket connection to be established.""" + self.page.wait_for_function( + "() => window._mockWsConnected === true", + timeout=timeout, + ) + + +@pytest.fixture +def mock_ws(page: Page) -> WebSocketInterceptor: + """ + Fixture that provides a WebSocket interceptor for mocking. + + Usage: + def test_example(mock_ws, page): + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + page.goto("http://localhost:8000") + mock_ws.wait_for_connection() + """ + return WebSocketInterceptor(page) + + +@pytest.fixture +def connected_page(page: Page, mock_ws: WebSocketInterceptor) -> Page: + """ + Fixture that provides a page with mocked WebSocket already connected. + + Sets up default auto-responses for initial connection handshake. + """ + mock_ws.setup() + + # Set up default auto-responses + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + return page + + +@pytest.fixture +def mock_query_response() -> list[dict]: + """ + Fixture that provides a sequence of events for a mock query response. 
+  """
+  return [
+      {"type": "query_start", "prompt": "What is 2+2?"},
+      create_mock_event("rlm.run.start", iteration=0, event_id=0),
+      create_mock_event("rlm.iteration.start", iteration=1, event_id=1),
+      create_mock_event("rlm.llm.start", iteration=1, event_id=2),
+      create_mock_event(
+          "rlm.llm.end",
+          iteration=1,
+          event_id=3,
+          response_preview="Let me calculate 2+2...",
+      ),
+      create_mock_event(
+          "rlm.code.found",
+          iteration=1,
+          event_id=4,
+          code="result = 2 + 2\nFINAL_VAR('result')",
+      ),
+      create_mock_event("rlm.code.start", iteration=1, event_id=5),
+      create_mock_event(
+          "rlm.code.end",
+          iteration=1,
+          event_id=6,
+          output="4",
+      ),
+      create_mock_event("rlm.final.detected", iteration=1, event_id=7),
+      create_mock_event("rlm.iteration.end", iteration=1, event_id=8),
+      create_mock_event("rlm.run.end", iteration=1, event_id=9),
+      {
+          "type": "query_complete",
+          "elapsed_seconds": 1.5,
+          "total_events": 10,
+          "final_answer": "4",
+          "title": "What is 2+2?",
+      },
+  ]
+
+
+# HTML content for a minimal test server (used when we don't want to run the full app)
+MINIMAL_HTML = """
+<!DOCTYPE html>
+<html>
+  <head>
+    <title>Test</title>
+  </head>
+  <body>
+    Loading...
+  </body>
+</html>
+ + + +""" + + +@pytest.fixture(scope="session") +def live_server(): + """ + Start a live FastAPI server for UI tests. + + This starts the actual web server on a random available port. + The server is shared across all tests in the session for efficiency. + """ + import os + import socket + import subprocess + import sys + import tempfile + + # Find an available port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + + # Use a temp database for isolation + db_file = tempfile.mktemp(suffix=".db") + db_url = f"sqlite+aiosqlite:///{db_file}" + + # Start the server in a subprocess + env = os.environ.copy() + env["RLM_DB_URL"] = db_url + + # Start uvicorn via subprocess + proc = subprocess.Popen( + [ + sys.executable, + "-m", + "uvicorn", + "adk_rlm.web:app", + "--host", + "127.0.0.1", + "--port", + str(port), + ], + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + # Wait for server to be ready + import urllib.error + import urllib.request + + max_retries = 30 + for i in range(max_retries): + try: + urllib.request.urlopen(f"http://127.0.0.1:{port}/health", timeout=1) + break + except (urllib.error.URLError, ConnectionRefusedError): + time.sleep(0.2) + else: + proc.terminate() + raise RuntimeError(f"Server did not start within {max_retries * 0.2}s") + + url = f"http://127.0.0.1:{port}" + + yield url + + # Cleanup + proc.terminate() + proc.wait(timeout=5) + + # Remove temp database + if os.path.exists(db_file): + os.unlink(db_file) diff --git a/contributing/samples/rlm/tests/ui/test_event_log.py b/contributing/samples/rlm/tests/ui/test_event_log.py new file mode 100644 index 0000000000..4a80c333fd --- /dev/null +++ b/contributing/samples/rlm/tests/ui/test_event_log.py @@ -0,0 +1,555 @@ +""" +UI tests for the event log panel. + +These tests verify event log display, toggling, and event item interactions. 
+""" + +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import create_mock_event +from .conftest import create_mock_session +from .conftest import create_mock_sessions_list +from .conftest import WebSocketInterceptor + +pytestmark = pytest.mark.ui + + +class TestEventLogPanel: + """Tests for the event log panel layout and toggle.""" + + def test_event_log_visible_by_default( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event log panel should be visible by default.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + event_log = page.locator("#event-log-panel") + expect(event_log).to_be_visible() + expect(event_log).not_to_have_class(re.compile(r"collapsed")) + + def test_event_log_has_title( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event log should display 'Event Log' title.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + title = page.locator(".event-log-title") + expect(title).to_have_text("Event Log") + + def test_event_count_shows_zero_initially( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event count should show '0 events' initially.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session(events=[])]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + event_count = page.locator("#event-count") + expect(event_count).to_have_text("0 events") + + def test_toggle_collapses_panel( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking toggle button should collapse the event log.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + event_log = page.locator("#event-log-panel") + toggle_btn = page.locator("#toggle-log-btn") + + # Should be expanded initially + expect(event_log).not_to_have_class(re.compile(r"collapsed")) + + # Click toggle + toggle_btn.click() + + # Should be collapsed + expect(event_log).to_have_class(re.compile(r"collapsed")) + + def test_toggle_expands_panel( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking toggle button again should expand the event log.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + event_log = page.locator("#event-log-panel") + toggle_btn = page.locator("#toggle-log-btn") + + # Collapse first + toggle_btn.click() + expect(event_log).to_have_class(re.compile(r"collapsed")) + + # Click again to expand + toggle_btn.click() + expect(event_log).not_to_have_class(re.compile(r"collapsed")) + + def test_empty_state_shown_when_no_events( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Empty state should be shown when no events exist.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session(events=[])]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + 
page.goto(live_server) + mock_ws.wait_for_connection() + + empty_state = page.locator("#event-log-content .empty-state") + expect(empty_state).to_be_visible() + expect(empty_state).to_contain_text("Events will appear here") + + +class TestEventDisplay: + """Tests for individual event display.""" + + def test_events_displayed_during_query( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Events should be displayed as they arrive during query.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + create_mock_event("rlm.iteration.start", iteration=1, event_id=1), + create_mock_event("rlm.llm.start", iteration=1, event_id=2), + create_mock_event("rlm.llm.end", iteration=1, event_id=3), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Event count should be updated + event_count = page.locator("#event-count") + expect(event_count).to_contain_text("4 events") + + def test_event_items_have_icon( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event items should display icons.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Event items should have icons + event_icons = page.locator(".event-icon") + expect(event_icons.first).to_be_visible() + + def test_event_items_have_label( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event items should display labels.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Event items should have labels + event_labels = page.locator(".event-label") + expect(event_labels.first).to_be_visible() + + def test_event_items_have_timestamp( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event items should display timestamp.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + 
prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Event items should have timestamps + event_times = page.locator(".event-time") + expect(event_times.first).to_be_visible() + # Should contain seconds notation + expect(event_times.first).to_contain_text("s") + + def test_event_with_preview( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Events with content should show preview.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event( + "rlm.llm.end", + iteration=1, + event_id=0, + response_preview="This is a preview of the response...", + ), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Should show preview + preview = page.locator(".event-preview") + expect(preview).to_be_visible() + expect(preview).to_contain_text("This is a preview") + + +class TestAgentGroups: + """Tests for agent group display in event log.""" + + def test_agent_groups_created( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Agent groups should be created for events.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + { + **create_mock_event( + "rlm.iteration.start", iteration=1, event_id=0 + ), + "metadata": {"agent_name": "rlm_agent", "agent_depth": 0}, + }, + { + **create_mock_event("rlm.llm.start", iteration=1, event_id=1), + "metadata": {"agent_name": "rlm_agent", "agent_depth": 0}, + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Agent group should be created + agent_group = page.locator(".agent-group") + expect(agent_group.first).to_be_visible() + + def test_agent_group_expandable( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Agent groups should be expandable/collapsible.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + { + **create_mock_event( + "rlm.iteration.start", iteration=1, event_id=0 + ), + "metadata": {"agent_name": "rlm_agent", "agent_depth": 0}, + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + agent_group = page.locator(".agent-group").first + agent_header = agent_group.locator(".agent-header") + + # Should be expanded by default + expect(agent_group).to_have_class(re.compile(r"expanded")) + + # Click to collapse + agent_header.click() + + # Should be collapsed + expect(agent_group).not_to_have_class(re.compile(r"expanded")) + + +class TestIterationGroups: + """Tests for iteration group display within agents.""" + + def test_iteration_groups_created( + self, page: Page, mock_ws: 
WebSocketInterceptor, live_server: str + ): + """Iteration groups should be created within agents.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + { + **create_mock_event( + "rlm.iteration.start", iteration=1, event_id=0 + ), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 1, + }, + }, + { + **create_mock_event("rlm.llm.start", iteration=1, event_id=1), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 1, + }, + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Iteration group should be created + iteration_group = page.locator(".agent-iteration") + expect(iteration_group.first).to_be_visible() + + def test_multiple_iterations_displayed( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Multiple iterations should be displayed separately.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + { + **create_mock_event( + "rlm.iteration.start", iteration=1, event_id=0 + ), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 1, + }, + }, + { + **create_mock_event( + "rlm.iteration.end", iteration=1, event_id=1 + ), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 1, + }, + }, + { + **create_mock_event( + "rlm.iteration.start", iteration=2, event_id=2 + ), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 2, + }, + }, + { + **create_mock_event( + "rlm.iteration.end", iteration=2, event_id=3 + ), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 2, + }, + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Should have 2 iteration groups + iteration_groups = page.locator(".agent-iteration") + expect(iteration_groups).to_have_count(2) + + +class TestEventClick: + """Tests for clicking on event items.""" + + def test_event_click_opens_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking an event should open the detail modal.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event item + event_item = page.locator(".event-item").first + event_item.click() + + # Modal should be visible + modal = page.locator("#event-modal") + expect(modal).not_to_have_class(re.compile(r"hidden")) + + +class TestEventLogScroll: + """Tests for event log scrolling behavior.""" + 
+ def test_event_log_scrolls_to_bottom( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event log should auto-scroll to bottom on new events.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + # Create many events to trigger scroll + events = [{"type": "query_start", "prompt": "Test"}] + for i in range(20): + events.append({ + **create_mock_event("rlm.llm.end", iteration=1, event_id=i), + "metadata": { + "agent_name": "rlm_agent", + "agent_depth": 0, + "iteration": 1, + "response_preview": f"Response {i}", + }, + }) + + mock_ws.set_auto_response("query", events) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(500) + + # Check that event log content is scrolled + event_log_content = page.locator("#event-log-content") + scroll_height = event_log_content.evaluate("el => el.scrollHeight") + scroll_top = event_log_content.evaluate("el => el.scrollTop") + client_height = event_log_content.evaluate("el => el.clientHeight") + + # Should be scrolled near bottom (allow some tolerance) + assert scroll_top + client_height >= scroll_height - 100 diff --git a/contributing/samples/rlm/tests/ui/test_modals.py b/contributing/samples/rlm/tests/ui/test_modals.py new file mode 100644 index 0000000000..8af77cf0e6 --- /dev/null +++ b/contributing/samples/rlm/tests/ui/test_modals.py @@ -0,0 +1,627 @@ +""" +UI tests for modal dialogs. + +These tests verify the settings modal and event detail modal behavior. +""" + +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import create_mock_event +from .conftest import create_mock_session +from .conftest import create_mock_sessions_list +from .conftest import WebSocketInterceptor + +pytestmark = pytest.mark.ui + + +class TestSettingsModal: + """Tests for the settings/config modal.""" + + def test_settings_button_opens_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking settings button should open config modal.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + config_modal = page.locator("#config-modal") + settings_btn = page.locator("#config-btn") + + # Modal should be hidden initially + expect(config_modal).to_have_class(re.compile(r"hidden")) + + # Click settings button + settings_btn.click() + + # Modal should be visible + expect(config_modal).not_to_have_class(re.compile(r"hidden")) + + def test_modal_displays_session_title( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Settings modal should display current session title.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(title="My Session")] + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + page.locator("#config-btn").click() + + title_input = page.locator("#config-title") + expect(title_input).to_have_value("My Session") + + def test_modal_displays_model( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Settings modal should 
display current model.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(model="gemini-3-flash-preview")] + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + page.locator("#config-btn").click() + + model_input = page.locator("#config-model") + expect(model_input).to_have_value("gemini-3-flash-preview") + + def test_close_button_closes_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Close button should close the modal.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + config_modal = page.locator("#config-modal") + page.locator("#config-btn").click() + expect(config_modal).not_to_have_class(re.compile(r"hidden")) + + # Click close button + page.locator("#config-modal-close").click() + + expect(config_modal).to_have_class(re.compile(r"hidden")) + + def test_cancel_button_closes_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Cancel button should close the modal.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + config_modal = page.locator("#config-modal") + page.locator("#config-btn").click() + expect(config_modal).not_to_have_class(re.compile(r"hidden")) + + # Click cancel button + page.locator("#config-cancel").click() + + expect(config_modal).to_have_class(re.compile(r"hidden")) + + def test_click_outside_closes_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking outside modal should close it.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + config_modal = page.locator("#config-modal") + page.locator("#config-btn").click() + expect(config_modal).not_to_have_class(re.compile(r"hidden")) + + # Click on modal overlay (outside the modal content) + config_modal.click(position={"x": 10, "y": 10}) + + expect(config_modal).to_have_class(re.compile(r"hidden")) + + def test_save_sends_config_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Saving config should send config action via WebSocket.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "config", [{"type": "status", "message": "Configuration updated"}] + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + # Open settings + page.locator("#config-btn").click() + + # Modify values + page.locator("#config-title").fill("New Title") + page.locator("#config-model").fill("gemini-3-flash-preview") + page.locator("#config-iterations").fill("50") + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Submit form + page.locator("#config-form button[type='submit']").click() + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + config_msgs = [m for m in received if m.get("action") == 
"config"] + + assert len(config_msgs) == 1 + assert config_msgs[0]["title"] == "New Title" + assert config_msgs[0]["model"] == "gemini-3-flash-preview" + assert config_msgs[0]["max_iterations"] == 50 + + def test_save_closes_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Saving should close the modal.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "config", [{"type": "status", "message": "Configuration updated"}] + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + config_modal = page.locator("#config-modal") + page.locator("#config-btn").click() + expect(config_modal).not_to_have_class(re.compile(r"hidden")) + + # Submit form + page.locator("#config-form button[type='submit']").click() + + expect(config_modal).to_have_class(re.compile(r"hidden")) + + def test_clear_session_button( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clear session button should clear conversation with confirmation.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", + [ + create_mock_session( + conversation=[ + { + "role": "user", + "content": "Hello", + "timestamp": "2024-01-15T10:00:00", + }, + ] + ) + ], + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "clear", [{"type": "status", "message": "Session cleared"}] + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + # Should have message initially + messages = page.locator(".message") + expect(messages).to_have_count(1) + + # Set up dialog handler to accept confirmation + page.on("dialog", lambda dialog: dialog.accept()) + + # Open settings and click clear + page.locator("#config-btn").click() + page.locator("#config-clear").click() + + page.wait_for_timeout(200) + + # Messages should be cleared + empty_state = page.locator("#empty-state") + expect(empty_state).to_be_visible() + + def test_clear_cancelled_no_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Cancelling clear confirmation should not clear.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", + [ + create_mock_session( + conversation=[ + { + "role": "user", + "content": "Hello", + "timestamp": "2024-01-15T10:00:00", + }, + ] + ) + ], + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + # Set up dialog handler to dismiss confirmation + page.on("dialog", lambda dialog: dialog.dismiss()) + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Open settings and click clear + page.locator("#config-btn").click() + page.locator("#config-clear").click() + + page.wait_for_timeout(100) + + # Should NOT have sent clear action + received = mock_ws.get_received_messages() + clear_msgs = [m for m in received if m.get("action") == "clear"] + assert len(clear_msgs) == 0 + + # Message should still exist + messages = page.locator(".message") + expect(messages).to_have_count(1) + + +class TestFilesInSettings: + """Tests for file handling in settings modal.""" + + def test_files_input_present( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Files input field should be present in settings.""" + mock_ws.setup() + 
mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + page.locator("#config-btn").click() + + files_input = page.locator("#config-files") + expect(files_input).to_be_visible() + + def test_adding_files_sends_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Adding files should send add_files action.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "config", [{"type": "status", "message": "Configuration updated"}] + ) + mock_ws.set_auto_response( + "add_files", + [{ + "type": "files_added", + "patterns": ["./docs/*.md"], + "count": 5, + "names": [ + "file1.md", + "file2.md", + "file3.md", + "file4.md", + "file5.md", + ], + "total": 5, + }], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + page.locator("#config-btn").click() + page.locator("#config-files").fill("./docs/*.md ./data/*.csv") + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Submit form + page.locator("#config-form button[type='submit']").click() + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + add_files_msgs = [m for m in received if m.get("action") == "add_files"] + + assert len(add_files_msgs) == 1 + assert "./docs/*.md" in add_files_msgs[0]["patterns"] + assert "./data/*.csv" in add_files_msgs[0]["patterns"] + + +class TestFilesDisplay: + """Tests for files display section.""" + + def test_files_section_hidden_when_empty( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Files section should be hidden when no files loaded.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session(files=[])]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + files_section = page.locator("#files-section") + expect(files_section).to_have_class(re.compile(r"hidden")) + + def test_files_section_visible_with_files( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Files section should be visible when files are loaded.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(files=["./docs/*.md"])] + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + files_section = page.locator("#files-section") + expect(files_section).not_to_have_class(re.compile(r"hidden")) + + def test_file_chips_displayed( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """File chips should be displayed for each file.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(files=["file1.md", "file2.md"])] + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + file_chips = page.locator(".file-chip") + expect(file_chips).to_have_count(2) + expect(file_chips.first).to_contain_text("file1.md") + + +class TestEventDetailModal: + """Tests for the event detail modal.""" + + def test_modal_shows_event_type( + self, page: Page, mock_ws: WebSocketInterceptor, 
live_server: str + ): + """Event modal should show the event type.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event + page.locator(".event-item").first.click() + + # Modal should show event type + modal_body = page.locator("#modal-body") + expect(modal_body).to_contain_text("rlm.run.start") + + def test_modal_shows_timestamp( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event modal should show the timestamp.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event + page.locator(".event-item").first.click() + + # Modal should show timestamp section + modal_body = page.locator("#modal-body") + expect(modal_body).to_contain_text("Timestamp") + + def test_modal_shows_code_block( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event modal should show code block for code events.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event( + "rlm.code.found", + iteration=1, + event_id=0, + code="result = 2 + 2\nprint(result)", + ), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event + page.locator(".event-item").first.click() + + # Modal should show code + modal_body = page.locator("#modal-body") + expect(modal_body).to_contain_text("Code") + expect(modal_body).to_contain_text("result = 2 + 2") + + def test_modal_shows_output( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event modal should show output for execution events.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event( + "rlm.code.end", + iteration=1, + event_id=0, + output="4\n", + ), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event + page.locator(".event-item").first.click() + + # Modal should show output + modal_body = page.locator("#modal-body") + expect(modal_body).to_contain_text("Output") + + def 
test_modal_close_button( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Close button should close the event modal.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event to open modal + page.locator(".event-item").first.click() + + event_modal = page.locator("#event-modal") + expect(event_modal).not_to_have_class(re.compile(r"hidden")) + + # Click close button + page.locator("#modal-close").click() + + expect(event_modal).to_have_class(re.compile(r"hidden")) + + def test_click_outside_closes_modal( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking outside event modal should close it.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Click on event to open modal + page.locator(".event-item").first.click() + + event_modal = page.locator("#event-modal") + expect(event_modal).not_to_have_class(re.compile(r"hidden")) + + # Click on modal overlay (outside the modal content) + event_modal.click(position={"x": 10, "y": 10}) + + expect(event_modal).to_have_class(re.compile(r"hidden")) diff --git a/contributing/samples/rlm/tests/ui/test_page_load.py b/contributing/samples/rlm/tests/ui/test_page_load.py new file mode 100644 index 0000000000..dd826c462d --- /dev/null +++ b/contributing/samples/rlm/tests/ui/test_page_load.py @@ -0,0 +1,314 @@ +""" +UI tests for page load and WebSocket connection. + +These tests verify the initial page load behavior and WebSocket connection handling. 
+""" + +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import create_mock_session +from .conftest import create_mock_sessions_list +from .conftest import WebSocketInterceptor + +pytestmark = pytest.mark.ui + + +class TestPageLoad: + """Tests for initial page load.""" + + def test_page_loads_with_correct_title( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Page should load with 'ADK-RLM' title.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + expect(page).to_have_title("ADK-RLM") + + def test_header_displays_logo( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Header should display logo icon and 'ADK-RLM' text.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + logo_icon = page.locator(".logo-icon") + logo_text = page.locator(".logo-text") + + expect(logo_icon).to_be_visible() + expect(logo_text).to_have_text("ADK-RLM") + + def test_empty_state_shown_initially( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Empty state message should be shown when no conversation exists.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(conversation=[])] + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + empty_state = page.locator("#empty-state") + expect(empty_state).to_be_visible() + expect(empty_state).to_contain_text("Recursive Language Model") + + def test_session_sidebar_visible( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Session sidebar should be visible with 'Sessions' title.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + sidebar = page.locator("#session-sidebar") + sidebar_title = page.locator(".sidebar-title") + + expect(sidebar).to_be_visible() + expect(sidebar_title).to_have_text("Sessions") + + def test_settings_button_visible( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Settings button should be visible in header.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + settings_btn = page.locator("#config-btn") + expect(settings_btn).to_be_visible() + expect(settings_btn).to_contain_text("Settings") + + def test_input_area_visible( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Input textarea and send button should be visible.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + input_area = page.locator("#prompt-input") + send_btn = page.locator("#send-btn") + + expect(input_area).to_be_visible() + expect(send_btn).to_be_visible() + + def test_event_log_panel_visible( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Event log panel should be visible by default.""" + 
mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + event_log = page.locator("#event-log-panel") + event_log_title = page.locator(".event-log-title") + + expect(event_log).to_be_visible() + expect(event_log_title).to_have_text("Event Log") + + +class TestWebSocketConnection: + """Tests for WebSocket connection behavior.""" + + @pytest.mark.skip( + reason="Mock WebSocket connects too fast to test transient state" + ) + def test_status_shows_connecting_initially( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Status badge should show 'Connecting...' initially.""" + mock_ws.setup() + # Don't set auto-responses yet to test initial state + + page.goto(live_server) + + status_badge = page.locator("#status-badge") + # Initially should show Connecting... + expect(status_badge).to_have_text("Connecting...") + + def test_status_shows_connected_after_websocket_connects( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Status badge should show 'Connected' after WebSocket connects.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + status_badge = page.locator("#status-badge") + expect(status_badge).to_have_text("Connected") + expect(status_badge).to_have_class(re.compile(r"connected")) + + @pytest.mark.skip( + reason="Mock WebSocket connects too fast to test disconnected state" + ) + def test_send_button_disabled_when_disconnected( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Send button should be disabled when not connected.""" + # Don't set up mock_ws to simulate no connection + page.goto(live_server) + + # Wait briefly for initial state + page.wait_for_timeout(100) + + send_btn = page.locator("#send-btn") + # Button should be disabled initially before connection + expect(send_btn).to_be_disabled() + + def test_send_button_enabled_when_connected( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Send button should be enabled when connected.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + send_btn = page.locator("#send-btn") + expect(send_btn).to_be_enabled() + + def test_initial_get_status_sent( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Client should send get_status action on connection.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Wait for messages to be sent + page.wait_for_timeout(200) + + received = mock_ws.get_received_messages() + actions = [msg.get("action") for msg in received] + + assert "get_status" in actions + + def test_initial_list_sessions_sent( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Client should send list_sessions action on connection.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + 
mock_ws.wait_for_connection() + + # Wait for messages to be sent + page.wait_for_timeout(200) + + received = mock_ws.get_received_messages() + actions = [msg.get("action") for msg in received] + + assert "list_sessions" in actions + + def test_session_title_populated_from_status_response( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Session title should be populated from status_response.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(title="My Test Session")] + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Wait for status response to be processed + page.wait_for_timeout(200) + + session_title = page.locator("#session-title") + expect(session_title).to_have_text("My Test Session") + + +class TestSessionListPopulation: + """Tests for session list population from WebSocket.""" + + def test_sessions_list_populated( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Session list should be populated from sessions_list response.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "sess-1", + "title": "Session One", + "updated_at": "2024-01-15T10:00:00", + "message_count": 3, + }, + { + "session_id": "sess-2", + "title": "Session Two", + "updated_at": "2024-01-14T09:00:00", + "message_count": 7, + }, + ] + ) + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Wait for sessions list to be rendered + page.wait_for_timeout(300) + + session_items = page.locator(".session-item") + expect(session_items).to_have_count(2) + + first_session = session_items.first + expect(first_session).to_contain_text("Session One") + + def test_empty_sessions_message_when_no_sessions( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Should show 'No sessions yet' when session list is empty.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response( + "list_sessions", [create_mock_sessions_list(sessions=[])] + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Wait for sessions list to be rendered + page.wait_for_timeout(300) + + empty_message = page.locator(".empty-sessions") + expect(empty_message).to_have_text("No sessions yet") diff --git a/contributing/samples/rlm/tests/ui/test_query_submission.py b/contributing/samples/rlm/tests/ui/test_query_submission.py new file mode 100644 index 0000000000..7a19f9ac0d --- /dev/null +++ b/contributing/samples/rlm/tests/ui/test_query_submission.py @@ -0,0 +1,551 @@ +""" +UI tests for query submission and processing. + +These tests verify the query input, submission, and response handling behavior. 
+""" + +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import create_mock_event +from .conftest import create_mock_session +from .conftest import create_mock_sessions_list +from .conftest import WebSocketInterceptor + +pytestmark = pytest.mark.ui + + +class TestInputArea: + """Tests for the input textarea and send button.""" + + def test_input_accepts_text( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Input textarea should accept text.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("What is 2+2?") + + expect(prompt_input).to_have_value("What is 2+2?") + + def test_input_placeholder( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Input should have placeholder text.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + prompt_input = page.locator("#prompt-input") + expect(prompt_input).to_have_attribute("placeholder", "Ask a question...") + + def test_enter_submits_message( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Pressing Enter should submit the message.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test query") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + query_msgs = [m for m in received if m.get("action") == "query"] + + assert len(query_msgs) == 1 + assert query_msgs[0]["prompt"] == "Test query" + + def test_shift_enter_adds_newline( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Pressing Shift+Enter should add a newline, not submit.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Line 1") + prompt_input.press("Shift+Enter") + prompt_input.type("Line 2") + + page.wait_for_timeout(100) + + # Value should contain newline + value = prompt_input.input_value() + assert "Line 1" in value + assert "Line 2" in value + + # Should NOT have sent a query + received = mock_ws.get_received_messages() + query_msgs = [m for m in received if m.get("action") == "query"] + assert len(query_msgs) == 0 + + def test_send_button_submits_message( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking send button should submit the message.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Clear received messages 
+ page.evaluate("() => window._mockWsReceived = []") + + prompt_input = page.locator("#prompt-input") + send_btn = page.locator("#send-btn") + + prompt_input.fill("Button test query") + send_btn.click() + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + query_msgs = [m for m in received if m.get("action") == "query"] + + assert len(query_msgs) == 1 + assert query_msgs[0]["prompt"] == "Button test query" + + def test_input_cleared_after_submit( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Input should be cleared after submitting.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test query") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + expect(prompt_input).to_have_value("") + + def test_empty_input_not_submitted( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Empty input should not be submitted.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + prompt_input = page.locator("#prompt-input") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + query_msgs = [m for m in received if m.get("action") == "query"] + + assert len(query_msgs) == 0 + + +class TestUserMessage: + """Tests for user message display.""" + + def test_user_message_displayed( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Submitted query should appear as user message.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("What is the capital of France?") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + # User message should appear + user_message = page.locator(".message.user") + expect(user_message).to_be_visible() + expect(user_message).to_contain_text("What is the capital of France?") + + def test_empty_state_hidden_after_message( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Empty state should be hidden after sending a message.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Empty state should be visible initially + empty_state = page.locator("#empty-state") + expect(empty_state).to_be_visible() + + # Send a message + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Hello") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + # Empty state should be hidden + expect(empty_state).not_to_be_visible() + + +class TestProcessingState: + """Tests for processing state during query execution.""" + + def test_processing_indicator_shows( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Processing indicator should appear during 
query execution.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test query") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + # Processing indicator should be visible + processing = page.locator("#processing") + expect(processing).not_to_have_class(re.compile(r"hidden")) + + def test_send_button_disabled_during_processing( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Send button should be disabled during processing.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + send_btn = page.locator("#send-btn") + prompt_input = page.locator("#prompt-input") + + prompt_input.fill("Test query") + prompt_input.press("Enter") + + page.wait_for_timeout(100) + + # Send button should be disabled + expect(send_btn).to_be_disabled() + + def test_processing_text_updates( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Processing text should update based on events.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + create_mock_event("rlm.iteration.start", iteration=1, event_id=1), + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test query") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Processing text should show iteration + processing_text = page.locator("#processing-text") + expect(processing_text).to_contain_text("Iteration 1") + + +class TestQueryResponse: + """Tests for query response handling.""" + + def test_final_answer_displayed( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Final answer should be displayed in answer panel.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "What is 2+2?"}, + create_mock_event("rlm.run.start", iteration=0, event_id=0), + create_mock_event("rlm.iteration.start", iteration=1, event_id=1), + create_mock_event("rlm.final.detected", iteration=1, event_id=2), + create_mock_event("rlm.run.end", iteration=1, event_id=3), + { + "type": "query_complete", + "elapsed_seconds": 1.5, + "total_events": 4, + "final_answer": "The answer is 4", + "title": "What is 2+2?", + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("What is 2+2?") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Answer panel should be visible with answer + answer_panel = page.locator(".answer-panel") + 
expect(answer_panel).to_be_visible() + expect(answer_panel).to_contain_text("The answer is 4") + + def test_processing_hidden_after_completion( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Processing indicator should be hidden after completion.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + { + "type": "query_complete", + "elapsed_seconds": 1.0, + "total_events": 0, + "final_answer": "Done", + "title": "Test", + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Processing should be hidden + processing = page.locator("#processing") + expect(processing).to_have_class(re.compile(r"hidden")) + + def test_send_button_enabled_after_completion( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Send button should be re-enabled after completion.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + { + "type": "query_complete", + "elapsed_seconds": 1.0, + "total_events": 0, + "final_answer": "Done", + "title": "Test", + }, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + send_btn = page.locator("#send-btn") + prompt_input = page.locator("#prompt-input") + + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(300) + + # Send button should be enabled + expect(send_btn).to_be_enabled() + + +class TestErrorHandling: + """Tests for error handling during query execution.""" + + def test_error_message_displayed( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Error message should be displayed when query fails.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + {"type": "error", "message": "Something went wrong!"}, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Error message should be displayed + error_message = page.locator(".message.assistant").last + expect(error_message).to_contain_text("Something went wrong!") + + def test_processing_ends_on_error( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Processing should end when error occurs.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "query", + [ + {"type": "query_start", "prompt": "Test"}, + {"type": "error", "message": "Error occurred"}, + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + prompt_input = page.locator("#prompt-input") + prompt_input.fill("Test") + prompt_input.press("Enter") + + page.wait_for_timeout(200) + + # Processing should be hidden + processing = 
page.locator("#processing") + expect(processing).to_have_class(re.compile(r"hidden")) + + # Send button should be enabled + send_btn = page.locator("#send-btn") + expect(send_btn).to_be_enabled() + + +class TestConversationRestore: + """Tests for conversation restoration.""" + + def test_conversation_restored_on_load( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Existing conversation should be restored on page load.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", + [ + create_mock_session( + conversation=[ + { + "role": "user", + "content": "Hello", + "timestamp": "2024-01-15T10:00:00", + }, + { + "role": "assistant", + "content": "Hi there!", + "timestamp": "2024-01-15T10:00:05", + }, + { + "role": "user", + "content": "How are you?", + "timestamp": "2024-01-15T10:01:00", + }, + { + "role": "assistant", + "content": "I'm doing well!", + "timestamp": "2024-01-15T10:01:05", + }, + ] + ) + ], + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + # Should have 4 messages + messages = page.locator(".message") + expect(messages).to_have_count(4) + + # Verify content + user_messages = page.locator(".message.user") + expect(user_messages).to_have_count(2) + expect(user_messages.first).to_contain_text("Hello") + + assistant_messages = page.locator(".message.assistant") + expect(assistant_messages).to_have_count(2) + expect(assistant_messages.first).to_contain_text("Hi there!") diff --git a/contributing/samples/rlm/tests/ui/test_session_management.py b/contributing/samples/rlm/tests/ui/test_session_management.py new file mode 100644 index 0000000000..ee494ae38b --- /dev/null +++ b/contributing/samples/rlm/tests/ui/test_session_management.py @@ -0,0 +1,486 @@ +""" +UI tests for session management. + +These tests verify session creation, loading, deletion, and sidebar interactions. 
+""" + +import re + +from playwright.sync_api import expect +from playwright.sync_api import Page +import pytest + +from .conftest import create_mock_session +from .conftest import create_mock_sessions_list +from .conftest import WebSocketInterceptor + +pytestmark = pytest.mark.ui + + +class TestSessionSidebar: + """Tests for session sidebar interactions.""" + + def test_collapse_sidebar( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Sidebar should collapse when close button is clicked.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + sidebar = page.locator("#session-sidebar") + close_btn = page.locator("#toggle-sidebar-close") + + # Sidebar should be visible initially + expect(sidebar).not_to_have_class(re.compile(r"collapsed")) + + # Click close button + close_btn.click() + + # Sidebar should be collapsed + expect(sidebar).to_have_class(re.compile(r"collapsed")) + + def test_expand_sidebar( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Sidebar should expand when open button is clicked.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + mock_ws.wait_for_connection() + + sidebar = page.locator("#session-sidebar") + close_btn = page.locator("#toggle-sidebar-close") + open_btn = page.locator("#toggle-sidebar-open") + + # Collapse first + close_btn.click() + expect(sidebar).to_have_class(re.compile(r"collapsed")) + + # Click open button + open_btn.click() + + # Sidebar should be expanded + expect(sidebar).not_to_have_class(re.compile(r"collapsed")) + + +class TestNewSession: + """Tests for creating new sessions.""" + + def test_new_session_button_visible( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """New session button should be visible in sidebar.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + + page.goto(live_server) + + new_session_btn = page.locator("#new-session-btn") + expect(new_session_btn).to_be_visible() + + def test_new_session_sends_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking new session button should send new_session action.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "new_session", + [{ + "type": "session_created", + "session_id": "new-session-id", + "title": "Session 2024-01-15 12:00", + }], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Click new session button + new_session_btn = page.locator("#new-session-btn") + new_session_btn.click() + + # Wait for message to be sent + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + actions = [msg.get("action") for msg in received] + + assert "new_session" in actions + + def test_new_session_clears_ui( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Creating new session should clear the conversation UI.""" + mock_ws.setup() + # Start with a session that 
has conversation + mock_ws.set_auto_response( + "get_status", + [ + create_mock_session( + conversation=[ + { + "role": "user", + "content": "Hello", + "timestamp": "2024-01-15T10:00:00", + }, + { + "role": "assistant", + "content": "Hi there!", + "timestamp": "2024-01-15T10:00:05", + }, + ] + ) + ], + ) + mock_ws.set_auto_response("list_sessions", [create_mock_sessions_list()]) + mock_ws.set_auto_response( + "new_session", + [{ + "type": "session_created", + "session_id": "new-session-id", + "title": "New Session", + }], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(200) + + # Should have messages initially + messages = page.locator(".message") + expect(messages).to_have_count(2) + + # Click new session + page.locator("#new-session-btn").click() + page.wait_for_timeout(200) + + # Messages should be cleared, empty state should show + empty_state = page.locator("#empty-state") + expect(empty_state).to_be_visible() + + +class TestLoadSession: + """Tests for loading existing sessions.""" + + def test_click_session_sends_load_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking a session item should send load_session action.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(session_id="current")] + ) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "other-session", + "title": "Other Session", + "updated_at": "2024-01-15T10:00:00", + "message_count": 3, + }, + ] + ) + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Click the session item + session_item = page.locator(".session-item").first + session_item.click() + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + + # Should have sent load_session with the session_id + load_msgs = [m for m in received if m.get("action") == "load_session"] + assert len(load_msgs) == 1 + assert load_msgs[0]["session_id"] == "other-session" + + def test_session_loaded_updates_ui( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Loading a session should update the UI with session data.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", + [ + create_mock_session( + session_id="current", + title="Current Session", + conversation=[], + ) + ], + ) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "other-session", + "title": "Loaded Session", + "updated_at": "2024-01-15T10:00:00", + "message_count": 2, + }, + ] + ) + ], + ) + mock_ws.set_auto_response( + "load_session", + [{ + "type": "session_loaded", + "session_id": "other-session", + "title": "Loaded Session", + "model": "gemini-3-pro-preview", + "sub_model": "gemini-3-pro-preview", + "max_iterations": 30, + "files": [], + "conversation": [ + { + "role": "user", + "content": "Question?", + "timestamp": "2024-01-15T10:00:00", + }, + { + "role": "assistant", + "content": "Answer!", + "timestamp": "2024-01-15T10:00:05", + }, + ], + "events": [], + }], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + # Click the session item + session_item = page.locator(".session-item").first + session_item.click() + + page.wait_for_timeout(300) + + # Session title should be updated + session_title = page.locator("#session-title") + 
expect(session_title).to_have_text("Loaded Session") + + # Conversation should be restored + messages = page.locator(".message") + expect(messages).to_have_count(2) + + +class TestDeleteSession: + """Tests for deleting sessions.""" + + def test_delete_button_visible_on_hover( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Delete button should become visible on session item hover.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "test-session", + "title": "Test Session", + "updated_at": "2024-01-15T10:00:00", + "message_count": 1, + }, + ] + ) + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + session_item = page.locator(".session-item").first + delete_btn = session_item.locator(".session-item-delete") + + # Delete button should be hidden initially (opacity 0) + expect(delete_btn).to_have_css("opacity", "0") + + # Hover over session item + session_item.hover() + + # Delete button should be visible + expect(delete_btn).to_have_css("opacity", "1") + + def test_delete_session_sends_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Clicking delete button should send delete_session action after confirmation.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(session_id="current")] + ) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "to-delete", + "title": "Session To Delete", + "updated_at": "2024-01-15T10:00:00", + "message_count": 1, + }, + ] + ) + ], + ) + mock_ws.set_auto_response( + "delete_session", + [{ + "type": "session_deleted", + "session_id": "to-delete", + "success": True, + }], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + # Set up dialog handler to accept confirmation + page.on("dialog", lambda dialog: dialog.accept()) + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Hover and click delete + session_item = page.locator(".session-item").first + delete_btn = session_item.locator(".session-item-delete") + session_item.hover() + delete_btn.click() + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + + # Should have sent delete_session + delete_msgs = [m for m in received if m.get("action") == "delete_session"] + assert len(delete_msgs) == 1 + assert delete_msgs[0]["session_id"] == "to-delete" + + def test_delete_cancelled_no_action( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Cancelling delete confirmation should not send action.""" + mock_ws.setup() + mock_ws.set_auto_response("get_status", [create_mock_session()]) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "test-session", + "title": "Test Session", + "updated_at": "2024-01-15T10:00:00", + "message_count": 1, + }, + ] + ) + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + # Set up dialog handler to dismiss confirmation + page.on("dialog", lambda dialog: dialog.dismiss()) + + # Clear received messages + page.evaluate("() => window._mockWsReceived = []") + + # Hover and click delete + session_item = page.locator(".session-item").first + delete_btn = session_item.locator(".session-item-delete") + 
session_item.hover() + delete_btn.click() + + page.wait_for_timeout(100) + + received = mock_ws.get_received_messages() + + # Should NOT have sent delete_session + delete_msgs = [m for m in received if m.get("action") == "delete_session"] + assert len(delete_msgs) == 0 + + +class TestSessionActive: + """Tests for active session highlighting.""" + + def test_current_session_highlighted( + self, page: Page, mock_ws: WebSocketInterceptor, live_server: str + ): + """Current session should be highlighted with 'active' class.""" + mock_ws.setup() + mock_ws.set_auto_response( + "get_status", [create_mock_session(session_id="active-session")] + ) + mock_ws.set_auto_response( + "list_sessions", + [ + create_mock_sessions_list( + sessions=[ + { + "session_id": "active-session", + "title": "Active Session", + "updated_at": "2024-01-15T10:00:00", + "message_count": 1, + }, + { + "session_id": "other-session", + "title": "Other Session", + "updated_at": "2024-01-14T10:00:00", + "message_count": 2, + }, + ] + ) + ], + ) + + page.goto(live_server) + mock_ws.wait_for_connection() + page.wait_for_timeout(300) + + # Find session items + active_item = page.locator( + ".session-item[data-session-id='active-session']" + ) + other_item = page.locator(".session-item[data-session-id='other-session']") + + # Active session should have 'active' class + expect(active_item).to_have_class(re.compile(r"active")) + expect(other_item).not_to_have_class(re.compile(r"active"))