From 49e7f339bc49f0526158a79ddb073fd908eb1ae3 Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 03:02:28 +0000 Subject: [PATCH 01/37] feat: workflow improvements and MCP enhancements Workflow Improvements: - Update release workflow to trigger on successful CI after main branch merge - Add workflow_dispatch for manual releases with version input - Auto-create and push tags when releasing via workflow_run - Add dependabot.yml for Cargo, GitHub Actions, and Docker updates - Add sdk-sync.yml for weekly Augment SDK sync checks MCP Enhancements: - Add prompt templates support (prompts/list, prompts/get methods) - Implement PromptRegistry with built-in prompts: - code_review: Review code for quality, bugs, and best practices - explain_code: Explain what a piece of code does - write_tests: Generate test cases for code - Declare prompts capability in server initialization Documentation: - Add docs/MCP_IMPROVEMENTS.md with feature roadmap --- .github/dependabot.yml | 67 ++++++++ .github/workflows/release.yml | 96 +++++++++++- .github/workflows/sdk-sync.yml | 115 ++++++++++++++ Cargo.lock | 2 +- docs/MCP_IMPROVEMENTS.md | 228 +++++++++++++++++++++++++++ src/mcp/mod.rs | 3 + src/mcp/prompts.rs | 277 +++++++++++++++++++++++++++++++++ src/mcp/server.rs | 59 ++++++- 8 files changed, 842 insertions(+), 5 deletions(-) create mode 100644 .github/dependabot.yml create mode 100644 .github/workflows/sdk-sync.yml create mode 100644 docs/MCP_IMPROVEMENTS.md create mode 100644 src/mcp/prompts.rs diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..da6bdea --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,67 @@ +version: 2 +updates: + # Rust dependencies (Cargo) + - package-ecosystem: "cargo" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "09:00" + timezone: "America/Los_Angeles" + open-pull-requests-limit: 10 + commit-message: + prefix: "deps" + labels: + - "dependencies" + - "rust" + reviewers: + - "rishitank" + groups: + # Group minor/patch updates to reduce PR noise + rust-minor-updates: + patterns: + - "*" + update-types: + - "minor" + - "patch" + # Keep major updates separate for careful review + rust-major-updates: + patterns: + - "*" + update-types: + - "major" + + # GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "09:00" + timezone: "America/Los_Angeles" + open-pull-requests-limit: 5 + commit-message: + prefix: "ci" + labels: + - "dependencies" + - "github-actions" + reviewers: + - "rishitank" + + # Docker (if Dockerfile exists) + - package-ecosystem: "docker" + directory: "/" + schedule: + interval: "weekly" + day: "monday" + time: "09:00" + timezone: "America/Los_Angeles" + open-pull-requests-limit: 5 + commit-message: + prefix: "docker" + labels: + - "dependencies" + - "docker" + reviewers: + - "rishitank" + diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index eae9503..f9a5b82 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,15 +1,91 @@ name: Release on: + # Manual trigger with version input + workflow_dispatch: + inputs: + version: + description: 'Release version (e.g., 2.0.1)' + required: true + type: string + prerelease: + description: 'Is this a pre-release?' + required: false + type: boolean + default: false + # Triggered by tag push (manual releases) push: tags: - 'v*' + # Triggered after successful CI on main + workflow_run: + workflows: ["CI"] + types: + - completed + branches: + - main permissions: contents: write jobs: + # Check if release should proceed + check: + runs-on: ubuntu-latest + outputs: + should_release: ${{ steps.check.outputs.should_release }} + version: ${{ steps.check.outputs.version }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Check release conditions + id: check + run: | + # For workflow_dispatch, always release with provided version + if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then + echo "should_release=true" >> $GITHUB_OUTPUT + echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT + exit 0 + fi + + # For tag push, always release + if [ "${{ github.event_name }}" == "push" ] && [[ "${{ github.ref }}" == refs/tags/v* ]]; then + echo "should_release=true" >> $GITHUB_OUTPUT + echo "version=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT + exit 0 + fi + + # For workflow_run, check if CI succeeded and version changed + if [ "${{ github.event_name }}" == "workflow_run" ]; then + if [ "${{ github.event.workflow_run.conclusion }}" != "success" ]; then + echo "CI did not succeed, skipping release" + echo "should_release=false" >> $GITHUB_OUTPUT + exit 0 + fi + + # Get version from Cargo.toml + VERSION=$(grep '^version = ' Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/') + + # Check if this version tag already exists + if git tag -l "v$VERSION" | grep -q .; then + echo "Tag v$VERSION already exists, skipping release" + echo "should_release=false" >> $GITHUB_OUTPUT + else + echo "New version v$VERSION detected, will release" + echo "should_release=true" >> $GITHUB_OUTPUT + echo "version=$VERSION" >> $GITHUB_OUTPUT + fi + exit 0 + fi + + echo "Unknown trigger, skipping release" + echo "should_release=false" >> $GITHUB_OUTPUT + build: + needs: check + if: needs.check.outputs.should_release == 'true' strategy: matrix: include: @@ -46,24 +122,38 @@ jobs: path: ${{ matrix.artifact }} release: - needs: build + needs: [check, build] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + with: + fetch-depth: 0 - name: Download all artifacts uses: actions/download-artifact@v4 with: path: artifacts + - name: Create and push tag + if: github.event_name == 'workflow_run' || github.event_name == 'workflow_dispatch' + run: | + VERSION="${{ needs.check.outputs.version }}" + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git tag -a "v$VERSION" -m "Release v$VERSION" + git push origin "v$VERSION" + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + - name: Create Release uses: softprops/action-gh-release@v1 with: + tag_name: v${{ needs.check.outputs.version }} + name: v${{ needs.check.outputs.version }} files: | artifacts/**/* generate_release_notes: true draft: false - prerelease: false + prerelease: ${{ github.event_name == 'workflow_dispatch' && inputs.prerelease || false }} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - diff --git a/.github/workflows/sdk-sync.yml b/.github/workflows/sdk-sync.yml new file mode 100644 index 0000000..1de2c64 --- /dev/null +++ b/.github/workflows/sdk-sync.yml @@ -0,0 +1,115 @@ +name: SDK Sync Check + +# This workflow helps track when the Augment SDK may need updates +# Since the Augment SDK is implemented locally (not from a package registry), +# this workflow periodically checks for API changes and creates issues/reminders + +on: + schedule: + # Run every Monday at 10am UTC + - cron: '0 10 * * 1' + workflow_dispatch: + inputs: + create_issue: + description: 'Create a tracking issue' + required: false + type: boolean + default: true + +permissions: + contents: read + issues: write + +jobs: + check-sdk: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache cargo + uses: actions/cache@v4 + with: + path: | + ~/.cargo/bin/ + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + target/ + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + + - name: Check SDK builds successfully + run: cargo build --release + + - name: Run SDK tests + run: cargo test --lib -- sdk + + - name: Check for SDK-related TODOs + id: todos + run: | + # Find any TODOs related to SDK updates + TODOS=$(grep -r "TODO.*SDK\|TODO.*Augment\|FIXME.*SDK\|FIXME.*Augment" src/sdk/ 2>/dev/null || echo "") + if [ -n "$TODOS" ]; then + echo "Found SDK-related TODOs:" + echo "$TODOS" + echo "has_todos=true" >> $GITHUB_OUTPUT + echo "todos<> $GITHUB_OUTPUT + echo "$TODOS" >> $GITHUB_OUTPUT + echo "EOF" >> $GITHUB_OUTPUT + else + echo "No SDK-related TODOs found" + echo "has_todos=false" >> $GITHUB_OUTPUT + fi + + - name: Create tracking issue + if: steps.todos.outputs.has_todos == 'true' || github.event.inputs.create_issue == 'true' + uses: actions/github-script@v7 + with: + script: | + const todos = `${{ steps.todos.outputs.todos }}`; + const title = `[SDK Sync] Weekly Augment SDK Review - ${new Date().toISOString().split('T')[0]}`; + + // Check if issue already exists this week + const { data: issues } = await github.rest.issues.listForRepo({ + owner: context.repo.owner, + repo: context.repo.repo, + state: 'open', + labels: 'sdk-sync', + per_page: 5 + }); + + const weekStart = new Date(); + weekStart.setDate(weekStart.getDate() - weekStart.getDay()); + + const existingIssue = issues.find(i => + new Date(i.created_at) >= weekStart + ); + + if (existingIssue) { + console.log(`Issue already exists: #${existingIssue.number}`); + return; + } + + let body = `## Weekly SDK Sync Check\n\n`; + body += `This is an automated reminder to review the Augment SDK implementation.\n\n`; + + if (todos) { + body += `### Found TODOs\n\n\`\`\`\n${todos}\n\`\`\`\n\n`; + } + + body += `### Checklist\n\n`; + body += `- [ ] Check if Augment API has new endpoints\n`; + body += `- [ ] Review any SDK-related issues or feedback\n`; + body += `- [ ] Update type definitions if needed\n`; + body += `- [ ] Run integration tests with latest API\n`; + + await github.rest.issues.create({ + owner: context.repo.owner, + repo: context.repo.repo, + title: title, + body: body, + labels: ['sdk-sync', 'maintenance'] + }); + diff --git a/Cargo.lock b/Cargo.lock index efc028d..61fa0a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -414,7 +414,7 @@ checksum = "75984efb6ed102a0d42db99afb6c1948f0380d1d91808d5529916e6c08b49d8d" [[package]] name = "context-engine-rs" -version = "2.0.0" +version = "2.0.1" dependencies = [ "anyhow", "async-compression", diff --git a/docs/MCP_IMPROVEMENTS.md b/docs/MCP_IMPROVEMENTS.md new file mode 100644 index 0000000..2b8ecfb --- /dev/null +++ b/docs/MCP_IMPROVEMENTS.md @@ -0,0 +1,228 @@ +# MCP Server Improvement Roadmap + +This document outlines potential improvements to make the Context Engine MCP Server more powerful and fully utilize the MCP specification. + +## Current Implementation Status + +### ✅ Fully Implemented +- **Tools** - All 49 tools for retrieval, indexing, memory, planning, and review +- **JSON-RPC 2.0** - Full request/response/notification handling +- **Stdio Transport** - Standard input/output for MCP clients +- **HTTP Transport** - Axum-based HTTP server with SSE +- **Logging Capability** - Structured logging support +- **Tools List Changed** - Dynamic tool list notifications + +### 🔶 Partially Implemented +- **Resources** - Capability declared but not actively used +- **Prompts** - Capability declared but no prompts defined + +### ❌ Not Yet Implemented +- **Resource Subscriptions** - Subscribe to file/resource changes +- **Prompt Templates** - Pre-defined prompt templates with arguments +- **Completions API** - Autocomplete suggestions for prompts/resources +- **Progress Notifications** - Long-running operation progress +- **Cancellation** - Cancel in-progress operations + +--- + +## High-Value Improvements + +### 1. Resource Subscriptions (High Priority) + +Enable clients to subscribe to file changes in the codebase. + +**Use Case:** Real-time code updates as files change + +```json +// Subscribe to a file +{"method": "resources/subscribe", "params": {"uri": "file:///src/main.rs"}} + +// Server sends notification when file changes +{"method": "notifications/resources/updated", "params": {"uri": "file:///src/main.rs"}} +``` + +**Implementation:** +- Integrate with existing `watcher` module for file system monitoring +- Track subscribed URIs per client session +- Emit notifications on file changes + +### 2. Prompt Templates (High Priority) + +Pre-defined prompts that guide AI assistants in common tasks. + +**Proposed Prompts:** + +| Prompt Name | Description | Arguments | +|-------------|-------------|-----------| +| `code_review` | Review code changes | `file_path`, `focus_areas` | +| `explain_code` | Explain a code section | `code`, `level` (beginner/advanced) | +| `write_tests` | Generate test cases | `file_path`, `function_name` | +| `debug_issue` | Help debug an issue | `error_message`, `stack_trace` | +| `refactor` | Suggest refactoring | `code`, `goals` | +| `document` | Generate documentation | `code`, `style` (jsdoc/rustdoc) | + +**Implementation:** +- Add `prompts/list` and `prompts/get` handlers +- Store prompts as structured templates +- Support argument substitution + +### 3. Progress Notifications (Medium Priority) + +Report progress for long-running operations like indexing. + +**Use Case:** Show progress during full codebase indexing + +```json +// Server sends progress updates +{ + "method": "notifications/progress", + "params": { + "progressToken": "index-123", + "progress": 45, + "total": 100, + "message": "Indexing src/..." + } +} +``` + +**Implementation:** +- Add progress token to long-running tool calls +- Emit periodic progress notifications +- Track active operations for cancellation + +### 4. Completions API (Medium Priority) + +Provide autocomplete suggestions for tool arguments. + +**Use Case:** Autocomplete file paths, function names + +```json +// Request completions for file path +{ + "method": "completion/complete", + "params": { + "ref": {"type": "ref/resource", "uri": "file:///src/"}, + "argument": {"name": "path", "value": "src/m"} + } +} + +// Response +{ + "result": { + "completion": { + "values": ["src/main.rs", "src/mcp/", "src/metrics/"], + "hasMore": true + } + } +} +``` + +**Implementation:** +- Integrate with index for file/symbol completion +- Cache recent completions for performance +- Support fuzzy matching + +### 5. Request Cancellation (Low Priority) + +Allow clients to cancel in-progress operations. + +**Implementation:** +- Track active requests with cancellation tokens +- Check cancellation token during long operations +- Clean up resources on cancellation + +--- + +## Performance Improvements + +### 1. Caching Layer +- Cache semantic search results with LRU eviction +- Cache file content hashes for change detection +- Memoize expensive computations + +### 2. Batch Operations +- Support batch tool calls in single request +- Parallel execution for independent operations + +### 3. Streaming Responses +- Stream large search results +- Progressive rendering for code reviews + +--- + +## Enhanced Tool Capabilities + +### Current Tools (49) +- **Retrieval (6):** semantic_search, grep_search, file_search, etc. +- **Index (5):** index_status, index_directory, clear_index, etc. +- **Memory (4):** memory_store, memory_retrieve, memory_list, memory_delete +- **Planning (20):** create_review, analyze_changes, etc. +- **Review (14):** review_code, suggest_fixes, etc. + +### Potential New Tools + +| Tool | Description | Priority | +|------|-------------|----------| +| `diff_files` | Compare two files | High | +| `find_references` | Find all references to a symbol | High | +| `go_to_definition` | Find definition of a symbol | High | +| `call_hierarchy` | Show call graph for a function | Medium | +| `type_hierarchy` | Show class/type inheritance | Medium | +| `ast_query` | Query AST with tree-sitter | Medium | +| `git_blame` | Show git blame for a file | Low | +| `git_history` | Show commit history | Low | +| `dependency_graph` | Show module dependencies | Low | + +--- + +## Architecture Improvements + +### 1. Plugin System +Allow extending the server with custom tools without modifying core code. + +```rust +// Plugin trait +trait McpPlugin { + fn tools(&self) -> Vec; + fn resources(&self) -> Vec; + fn prompts(&self) -> Vec; +} +``` + +### 2. Multi-Workspace Support +Support multiple workspace roots simultaneously. + +### 3. Language Server Protocol Integration +Bridge with LSP servers for richer code intelligence. + +--- + +## Implementation Priority + +### Phase 1 (Next Release) +1. ✅ Workflow improvements (PR-based releases) +2. ✅ Dependabot configuration +3. 🔲 Prompt templates (basic set) +4. 🔲 find_references tool +5. 🔲 go_to_definition tool + +### Phase 2 +1. 🔲 Resource subscriptions +2. 🔲 Progress notifications +3. 🔲 diff_files tool +4. 🔲 Caching layer + +### Phase 3 +1. 🔲 Completions API +2. 🔲 Plugin system +3. 🔲 AST query tool +4. 🔲 Request cancellation + +--- + +## References + +- [MCP Specification](https://modelcontextprotocol.io/specification) +- [MCP TypeScript SDK](https://github.com/modelcontextprotocol/typescript-sdk) +- [MCP Python SDK](https://github.com/modelcontextprotocol/python-sdk) + diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 2b1688e..39afb0d 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -9,13 +9,16 @@ //! - `server` - MCP server implementation //! - `transport` - Transport layer (stdio, HTTP/SSE) //! - `handler` - Request/notification handlers +//! - `prompts` - Prompt templates for common tasks pub mod handler; +pub mod prompts; pub mod protocol; pub mod server; pub mod transport; pub use handler::McpHandler; +pub use prompts::PromptRegistry; pub use protocol::*; pub use server::McpServer; pub use transport::{StdioTransport, Transport}; diff --git a/src/mcp/prompts.rs b/src/mcp/prompts.rs new file mode 100644 index 0000000..ef84738 --- /dev/null +++ b/src/mcp/prompts.rs @@ -0,0 +1,277 @@ +//! MCP Prompt Templates +//! +//! Pre-defined prompts that guide AI assistants in common tasks. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// A prompt argument definition. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PromptArgument { + pub name: String, + pub description: String, + pub required: bool, +} + +/// A prompt template. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Prompt { + pub name: String, + pub description: String, + #[serde(default)] + pub arguments: Vec, +} + +/// A prompt message (the actual content). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PromptMessage { + pub role: String, + pub content: PromptContent, +} + +/// Prompt content types. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum PromptContent { + Text { + text: String, + }, + Resource { + uri: String, + mime_type: Option, + }, +} + +/// Result of prompts/list. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListPromptsResult { + pub prompts: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +/// Result of prompts/get. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GetPromptResult { + pub description: Option, + pub messages: Vec, +} + +/// Prompt registry. +#[derive(Debug, Clone, Default)] +pub struct PromptRegistry { + prompts: HashMap, +} + +/// Template for generating prompt messages. +#[derive(Debug, Clone)] +pub struct PromptTemplate { + pub template: String, +} + +impl PromptRegistry { + /// Create a new registry with built-in prompts. + pub fn new() -> Self { + let mut registry = Self::default(); + registry.register_builtin_prompts(); + registry + } + + /// Register built-in prompts. + fn register_builtin_prompts(&mut self) { + // Code Review Prompt + self.register( + Prompt { + name: "code_review".to_string(), + description: "Review code for quality, bugs, and best practices".to_string(), + arguments: vec![ + PromptArgument { + name: "code".to_string(), + description: "The code to review".to_string(), + required: true, + }, + PromptArgument { + name: "language".to_string(), + description: "Programming language (optional, auto-detected)".to_string(), + required: false, + }, + PromptArgument { + name: "focus".to_string(), + description: "Areas to focus on (security, performance, style)".to_string(), + required: false, + }, + ], + }, + PromptTemplate { + template: r#"Please review the following code: + +```{{language}} +{{code}} +``` + +{{#if focus}}Focus areas: {{focus}}{{/if}} + +Analyze for: +1. Potential bugs or errors +2. Security vulnerabilities +3. Performance issues +4. Code style and best practices +5. Suggestions for improvement"# + .to_string(), + }, + ); + + // Explain Code Prompt + self.register( + Prompt { + name: "explain_code".to_string(), + description: "Explain what a piece of code does".to_string(), + arguments: vec![ + PromptArgument { + name: "code".to_string(), + description: "The code to explain".to_string(), + required: true, + }, + PromptArgument { + name: "level".to_string(), + description: "Explanation level: beginner, intermediate, advanced" + .to_string(), + required: false, + }, + ], + }, + PromptTemplate { + template: + r#"Please explain the following code{{#if level}} at a {{level}} level{{/if}}: + +``` +{{code}} +``` + +Explain: +1. What the code does overall +2. How it works step by step +3. Any important patterns or techniques used"# + .to_string(), + }, + ); + + // Write Tests Prompt + self.register( + Prompt { + name: "write_tests".to_string(), + description: "Generate test cases for code".to_string(), + arguments: vec![ + PromptArgument { + name: "code".to_string(), + description: "The code to test".to_string(), + required: true, + }, + PromptArgument { + name: "framework".to_string(), + description: "Test framework (jest, pytest, cargo test, etc.)".to_string(), + required: false, + }, + ], + }, + PromptTemplate { + template: r#"Generate comprehensive tests for the following code{{#if framework}} using {{framework}}{{/if}}: + +``` +{{code}} +``` + +Include: +1. Unit tests for each function/method +2. Edge cases and boundary conditions +3. Error handling tests +4. Integration tests if applicable"#.to_string(), + }, + ); + } + + /// Register a prompt. + pub fn register(&mut self, prompt: Prompt, template: PromptTemplate) { + self.prompts.insert(prompt.name.clone(), (prompt, template)); + } + + /// List all prompts. + pub fn list(&self) -> Vec { + self.prompts.values().map(|(p, _)| p.clone()).collect() + } + + /// Get a prompt by name with arguments substituted. + pub fn get(&self, name: &str, arguments: &HashMap) -> Option { + self.prompts.get(name).map(|(prompt, template)| { + let mut text = template.template.clone(); + + // Simple template substitution + for (key, value) in arguments { + text = text.replace(&format!("{{{{{}}}}}", key), value); + } + + // Handle conditionals (very simple implementation) + // {{#if var}}content{{/if}} + for (key, value) in arguments { + let if_pattern = format!("{{{{#if {}}}}}", key); + let endif_pattern = "{{/if}}"; + + if let Some(start) = text.find(&if_pattern) { + if let Some(end) = text[start..].find(endif_pattern) { + let content = &text[start + if_pattern.len()..start + end]; + if !value.is_empty() { + text = text + .replace(&text[start..start + end + endif_pattern.len()], content); + } else { + text = + text.replace(&text[start..start + end + endif_pattern.len()], ""); + } + } + } + } + + // Clean up remaining template markers + text = text + .lines() + .filter(|line| !line.contains("{{#if") && !line.contains("{{/if}}")) + .collect::>() + .join("\n"); + + GetPromptResult { + description: Some(prompt.description.clone()), + messages: vec![PromptMessage { + role: "user".to_string(), + content: PromptContent::Text { text }, + }], + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_list_prompts() { + let registry = PromptRegistry::new(); + let prompts = registry.list(); + assert!(!prompts.is_empty()); + assert!(prompts.iter().any(|p| p.name == "code_review")); + } + + #[test] + fn test_get_prompt() { + let registry = PromptRegistry::new(); + let mut args = HashMap::new(); + args.insert("code".to_string(), "fn main() {}".to_string()); + args.insert("language".to_string(), "rust".to_string()); + + let result = registry.get("code_review", &args); + assert!(result.is_some()); + + let result = result.unwrap(); + assert_eq!(result.messages.len(), 1); + } +} diff --git a/src/mcp/server.rs b/src/mcp/server.rs index d5cb239..9ad15cb 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -1,11 +1,13 @@ //! MCP server implementation. use serde_json::Value; +use std::collections::HashMap; use std::sync::Arc; use tracing::{debug, error, info, warn}; use crate::error::{Error, Result}; use crate::mcp::handler::McpHandler; +use crate::mcp::prompts::PromptRegistry; use crate::mcp::protocol::*; use crate::mcp::transport::{Message, Transport}; use crate::VERSION; @@ -13,6 +15,7 @@ use crate::VERSION; /// MCP server. pub struct McpServer { handler: Arc, + prompts: Arc, name: String, version: String, } @@ -22,6 +25,21 @@ impl McpServer { pub fn new(handler: McpHandler, name: impl Into) -> Self { Self { handler: Arc::new(handler), + prompts: Arc::new(PromptRegistry::new()), + name: name.into(), + version: VERSION.to_string(), + } + } + + /// Create a new MCP server with custom prompt registry. + pub fn with_prompts( + handler: McpHandler, + prompts: PromptRegistry, + name: impl Into, + ) -> Self { + Self { + handler: Arc::new(handler), + prompts: Arc::new(prompts), name: name.into(), version: VERSION.to_string(), } @@ -64,6 +82,8 @@ impl McpServer { "initialize" => self.handle_initialize(req.params).await, "tools/list" => self.handle_list_tools().await, "tools/call" => self.handle_call_tool(req.params).await, + "prompts/list" => self.handle_list_prompts().await, + "prompts/get" => self.handle_get_prompt(req.params).await, "ping" => Ok(serde_json::json!({})), _ => Err(Error::McpProtocol(format!( "Unknown method: {}", @@ -115,7 +135,9 @@ impl McpServer { capabilities: ServerCapabilities { tools: Some(ToolsCapability { list_changed: true }), resources: None, - prompts: None, + prompts: Some(PromptsCapability { + list_changed: false, + }), logging: Some(LoggingCapability {}), }, server_info: ServerInfo { @@ -150,4 +172,39 @@ impl McpServer { let result = handler.execute(params.arguments).await?; Ok(serde_json::to_value(result)?) } + + /// Handle list prompts request. + async fn handle_list_prompts(&self) -> Result { + use crate::mcp::prompts::ListPromptsResult; + + let prompts = self.prompts.list(); + let result = ListPromptsResult { + prompts, + next_cursor: None, + }; + Ok(serde_json::to_value(result)?) + } + + /// Handle get prompt request. + async fn handle_get_prompt(&self, params: Option) -> Result { + #[derive(serde::Deserialize)] + struct GetPromptParams { + name: String, + #[serde(default)] + arguments: HashMap, + } + + let params: GetPromptParams = params + .ok_or_else(|| Error::InvalidToolArguments("Missing params".to_string())) + .and_then(|v| { + serde_json::from_value(v).map_err(|e| Error::InvalidToolArguments(e.to_string())) + })?; + + let result = self + .prompts + .get(¶ms.name, ¶ms.arguments) + .ok_or_else(|| Error::McpProtocol(format!("Prompt not found: {}", params.name)))?; + + Ok(serde_json::to_value(result)?) + } } From 494e0c0ccf12c3c11ae12744666d4fab649ccab5 Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 03:04:41 +0000 Subject: [PATCH 02/37] fix: use Europe/London timezone in dependabot config --- .github/dependabot.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index da6bdea..0fb9741 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -7,7 +7,7 @@ updates: interval: "weekly" day: "monday" time: "09:00" - timezone: "America/Los_Angeles" + timezone: "Europe/London" open-pull-requests-limit: 10 commit-message: prefix: "deps" @@ -38,7 +38,7 @@ updates: interval: "weekly" day: "monday" time: "09:00" - timezone: "America/Los_Angeles" + timezone: "Europe/London" open-pull-requests-limit: 5 commit-message: prefix: "ci" @@ -55,7 +55,7 @@ updates: interval: "weekly" day: "monday" time: "09:00" - timezone: "America/Los_Angeles" + timezone: "Europe/London" open-pull-requests-limit: 5 commit-message: prefix: "docker" From f0e1666c0c8f5ef99d99e828fd7907cea99d6bad Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 03:19:47 +0000 Subject: [PATCH 03/37] feat: implement full MCP spec support Resources: - Add resources/list, resources/read, resources/subscribe handlers - Create ResourceRegistry for browsing indexed files - Support file:// URI scheme for resources Progress Notifications: - Add ProgressManager and ProgressReporter for long-running ops - Support notifications/progress with token tracking Completions API: - Add completion/complete handler - Support file path and prompt name completions Request Cancellation: - Track active requests with HashSet - Handle notifications/cancelled to cancel in-progress ops Roots Support: - Parse client roots from initialize params - Store workspace roots for file path resolution New Navigation Tools (3): - find_references: Find all usages of a symbol - go_to_definition: Find where a symbol is defined - diff_files: Compare two files with unified diff Total tools: 52 (was 49) --- src/mcp/mod.rs | 6 + src/mcp/progress.rs | 173 +++++++++++++ src/mcp/resources.rs | 287 ++++++++++++++++++++++ src/mcp/server.rs | 294 +++++++++++++++++++++- src/tools/mod.rs | 9 +- src/tools/navigation.rs | 523 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 1284 insertions(+), 8 deletions(-) create mode 100644 src/mcp/progress.rs create mode 100644 src/mcp/resources.rs create mode 100644 src/tools/navigation.rs diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs index 39afb0d..a74c0bd 100644 --- a/src/mcp/mod.rs +++ b/src/mcp/mod.rs @@ -10,15 +10,21 @@ //! - `transport` - Transport layer (stdio, HTTP/SSE) //! - `handler` - Request/notification handlers //! - `prompts` - Prompt templates for common tasks +//! - `resources` - File resources for browsing codebase +//! - `progress` - Progress notifications for long-running operations pub mod handler; +pub mod progress; pub mod prompts; pub mod protocol; +pub mod resources; pub mod server; pub mod transport; pub use handler::McpHandler; +pub use progress::{ProgressManager, ProgressReporter, ProgressToken}; pub use prompts::PromptRegistry; pub use protocol::*; +pub use resources::ResourceRegistry; pub use server::McpServer; pub use transport::{StdioTransport, Transport}; diff --git a/src/mcp/progress.rs b/src/mcp/progress.rs new file mode 100644 index 0000000..f5f8454 --- /dev/null +++ b/src/mcp/progress.rs @@ -0,0 +1,173 @@ +//! MCP Progress Notifications +//! +//! Support for emitting progress updates during long-running operations. + +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::mpsc; + +/// Progress token for tracking operations. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(untagged)] +pub enum ProgressToken { + String(String), + Number(i64), +} + +/// Progress notification params. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProgressParams { + pub progress_token: ProgressToken, + pub progress: u64, + pub total: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +/// Progress notification message. +#[derive(Debug, Clone, Serialize)] +pub struct ProgressNotification { + pub jsonrpc: String, + pub method: String, + pub params: ProgressParams, +} + +impl ProgressNotification { + /// Create a new progress notification. + pub fn new( + token: ProgressToken, + progress: u64, + total: Option, + message: Option, + ) -> Self { + Self { + jsonrpc: "2.0".to_string(), + method: "notifications/progress".to_string(), + params: ProgressParams { + progress_token: token, + progress, + total, + message, + }, + } + } +} + +/// Progress reporter for emitting updates. +#[derive(Clone)] +pub struct ProgressReporter { + token: ProgressToken, + sender: mpsc::Sender, + total: Option, +} + +impl ProgressReporter { + /// Create a new progress reporter. + pub fn new( + token: ProgressToken, + sender: mpsc::Sender, + total: Option, + ) -> Self { + Self { + token, + sender, + total, + } + } + + /// Report progress. + pub async fn report(&self, progress: u64, message: Option<&str>) { + let notification = ProgressNotification::new( + self.token.clone(), + progress, + self.total, + message.map(String::from), + ); + let _ = self.sender.send(notification).await; + } + + /// Report progress with percentage. + pub async fn report_percent(&self, percent: u64, message: Option<&str>) { + let progress = if let Some(total) = self.total { + (percent * total) / 100 + } else { + percent + }; + self.report(progress, message).await; + } + + /// Complete the progress. + pub async fn complete(&self, message: Option<&str>) { + if let Some(total) = self.total { + self.report(total, message).await; + } + } +} + +/// Progress manager for creating and tracking progress reporters. +pub struct ProgressManager { + sender: mpsc::Sender, + receiver: Arc>>, + next_id: std::sync::atomic::AtomicI64, +} + +impl ProgressManager { + /// Create a new progress manager. + pub fn new() -> Self { + let (sender, receiver) = mpsc::channel(100); + Self { + sender, + receiver: Arc::new(tokio::sync::Mutex::new(receiver)), + next_id: std::sync::atomic::AtomicI64::new(1), + } + } + + /// Create a new progress reporter with a generated token. + pub fn create_reporter(&self, total: Option) -> ProgressReporter { + let id = self + .next_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let token = ProgressToken::Number(id); + ProgressReporter::new(token, self.sender.clone(), total) + } + + /// Create a progress reporter with a specific token. + pub fn create_reporter_with_token( + &self, + token: ProgressToken, + total: Option, + ) -> ProgressReporter { + ProgressReporter::new(token, self.sender.clone(), total) + } + + /// Get the receiver for progress notifications. + pub fn receiver(&self) -> Arc>> { + self.receiver.clone() + } +} + +impl Default for ProgressManager { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_progress_reporter() { + let (tx, mut rx) = mpsc::channel(10); + let reporter = + ProgressReporter::new(ProgressToken::String("test".to_string()), tx, Some(100)); + + reporter.report(50, Some("Halfway")).await; + + let notification = rx.recv().await.unwrap(); + assert_eq!(notification.params.progress, 50); + assert_eq!(notification.params.total, Some(100)); + assert_eq!(notification.params.message, Some("Halfway".to_string())); + } +} diff --git a/src/mcp/resources.rs b/src/mcp/resources.rs new file mode 100644 index 0000000..f4a5438 --- /dev/null +++ b/src/mcp/resources.rs @@ -0,0 +1,287 @@ +//! MCP Resources Support +//! +//! Expose indexed files as MCP resources that AI clients can browse and read. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use tokio::fs; +use tokio::sync::RwLock; + +use crate::error::{Error, Result}; +use crate::service::ContextService; + +/// A resource exposed by the server. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Resource { + pub uri: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +/// Resource contents. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceContents { + pub uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub blob: Option, // base64 encoded +} + +/// Result of resources/list. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListResourcesResult { + pub resources: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +/// Result of resources/read. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadResourceResult { + pub contents: Vec, +} + +/// Resource registry and manager. +pub struct ResourceRegistry { + context_service: Arc, + subscriptions: Arc>>>, // uri -> session_ids +} + +impl ResourceRegistry { + /// Create a new resource registry. + pub fn new(context_service: Arc) -> Self { + Self { + context_service, + subscriptions: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// List available resources (files in workspace). + pub async fn list(&self, cursor: Option<&str>) -> Result { + let workspace = self.context_service.workspace(); + let files = self.discover_files(workspace, 100, cursor).await?; + + let resources: Vec = files + .iter() + .map(|path| { + let relative = path + .strip_prefix(workspace) + .unwrap_or(path) + .to_string_lossy() + .to_string(); + + let uri = format!("file://{}", path.display()); + let mime_type = Self::guess_mime_type(path); + + Resource { + uri, + name: relative.clone(), + description: Some(format!("File: {}", relative)), + mime_type, + } + }) + .collect(); + + // Simple pagination - if we got max results, there might be more + let next_cursor = if resources.len() >= 100 { + resources.last().map(|r| r.name.clone()) + } else { + None + }; + + Ok(ListResourcesResult { + resources, + next_cursor, + }) + } + + /// Read a resource by URI. + pub async fn read(&self, uri: &str) -> Result { + // Parse file:// URI + let path = if let Some(path) = uri.strip_prefix("file://") { + PathBuf::from(path) + } else { + return Err(Error::InvalidToolArguments(format!( + "Invalid URI scheme: {}", + uri + ))); + }; + + // Security: ensure path is within workspace + let workspace = self.context_service.workspace(); + let canonical = path + .canonicalize() + .map_err(|e| Error::InvalidToolArguments(format!("Cannot resolve path: {}", e)))?; + + if !canonical.starts_with(workspace) { + return Err(Error::InvalidToolArguments( + "Access denied: path outside workspace".to_string(), + )); + } + + // Read file + let content = fs::read_to_string(&canonical) + .await + .map_err(|e| Error::InvalidToolArguments(format!("Cannot read file: {}", e)))?; + + let mime_type = Self::guess_mime_type(&canonical); + + Ok(ReadResourceResult { + contents: vec![ResourceContents { + uri: uri.to_string(), + mime_type, + text: Some(content), + blob: None, + }], + }) + } + + /// Subscribe to resource changes. + pub async fn subscribe(&self, uri: &str, session_id: &str) -> Result<()> { + let mut subs = self.subscriptions.write().await; + subs.entry(uri.to_string()) + .or_default() + .push(session_id.to_string()); + Ok(()) + } + + /// Unsubscribe from resource changes. + pub async fn unsubscribe(&self, uri: &str, session_id: &str) -> Result<()> { + let mut subs = self.subscriptions.write().await; + if let Some(sessions) = subs.get_mut(uri) { + sessions.retain(|s| s != session_id); + } + Ok(()) + } + + /// Discover files in directory (with pagination). + async fn discover_files( + &self, + dir: &std::path::Path, + limit: usize, + after: Option<&str>, + ) -> Result> { + use tokio::fs::read_dir; + + let mut files = Vec::new(); + let mut stack = vec![dir.to_path_buf()]; + let mut past_cursor = after.is_none(); + + while let Some(current) = stack.pop() { + if files.len() >= limit { + break; + } + + let mut entries = match read_dir(¤t).await { + Ok(e) => e, + Err(_) => continue, + }; + + while let Ok(Some(entry)) = entries.next_entry().await { + if files.len() >= limit { + break; + } + + let path = entry.path(); + let name = path.file_name().unwrap_or_default().to_string_lossy(); + + // Skip hidden files and common ignore patterns + if name.starts_with('.') || Self::should_ignore(&name) { + continue; + } + + if path.is_dir() { + stack.push(path); + } else if path.is_file() { + let relative = path + .strip_prefix(dir) + .unwrap_or(&path) + .to_string_lossy() + .to_string(); + + // Handle cursor pagination + if !past_cursor { + if Some(relative.as_str()) == after { + past_cursor = true; + } + continue; + } + + files.push(path); + } + } + } + + Ok(files) + } + + /// Check if a file should be ignored. + fn should_ignore(name: &str) -> bool { + matches!( + name, + "node_modules" | "target" | "dist" | "build" | "__pycache__" | ".git" + ) || name.ends_with(".lock") + || name.ends_with(".pyc") + } + + /// Guess MIME type from file extension. + fn guess_mime_type(path: &std::path::Path) -> Option { + let ext = path.extension()?.to_str()?; + let mime = match ext { + "rs" => "text/x-rust", + "py" => "text/x-python", + "js" => "text/javascript", + "ts" => "text/typescript", + "tsx" | "jsx" => "text/javascript", + "json" => "application/json", + "yaml" | "yml" => "text/yaml", + "toml" => "text/x-toml", + "md" => "text/markdown", + "html" => "text/html", + "css" => "text/css", + "sh" | "bash" => "text/x-shellscript", + "sql" => "text/x-sql", + "go" => "text/x-go", + "java" => "text/x-java", + "c" | "h" => "text/x-c", + "cpp" | "hpp" | "cc" => "text/x-c++", + "rb" => "text/x-ruby", + "php" => "text/x-php", + "swift" => "text/x-swift", + "kt" => "text/x-kotlin", + "scala" => "text/x-scala", + "txt" => "text/plain", + "xml" => "application/xml", + _ => "text/plain", + }; + Some(mime.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_guess_mime_type() { + assert_eq!( + ResourceRegistry::guess_mime_type(std::path::Path::new("test.rs")), + Some("text/x-rust".to_string()) + ); + assert_eq!( + ResourceRegistry::guess_mime_type(std::path::Path::new("test.py")), + Some("text/x-python".to_string()) + ); + } +} diff --git a/src/mcp/server.rs b/src/mcp/server.rs index 9ad15cb..aef30d9 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -1,23 +1,32 @@ //! MCP server implementation. use serde_json::Value; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; use std::sync::Arc; +use tokio::sync::RwLock; use tracing::{debug, error, info, warn}; use crate::error::{Error, Result}; use crate::mcp::handler::McpHandler; use crate::mcp::prompts::PromptRegistry; use crate::mcp::protocol::*; +use crate::mcp::resources::ResourceRegistry; use crate::mcp::transport::{Message, Transport}; +use crate::service::ContextService; use crate::VERSION; /// MCP server. pub struct McpServer { handler: Arc, prompts: Arc, + resources: Option>, name: String, version: String, + /// Workspace roots provided by the client. + roots: Arc>>, + /// Active request IDs for cancellation support. + active_requests: Arc>>, } impl McpServer { @@ -26,25 +35,42 @@ impl McpServer { Self { handler: Arc::new(handler), prompts: Arc::new(PromptRegistry::new()), + resources: None, name: name.into(), version: VERSION.to_string(), + roots: Arc::new(RwLock::new(Vec::new())), + active_requests: Arc::new(RwLock::new(HashSet::new())), } } - /// Create a new MCP server with custom prompt registry. - pub fn with_prompts( + /// Create a new MCP server with all features. + pub fn with_features( handler: McpHandler, prompts: PromptRegistry, + context_service: Arc, name: impl Into, ) -> Self { Self { handler: Arc::new(handler), prompts: Arc::new(prompts), + resources: Some(Arc::new(ResourceRegistry::new(context_service))), name: name.into(), version: VERSION.to_string(), + roots: Arc::new(RwLock::new(Vec::new())), + active_requests: Arc::new(RwLock::new(HashSet::new())), } } + /// Get the client-provided workspace roots. + pub async fn roots(&self) -> Vec { + self.roots.read().await.clone() + } + + /// Check if a request has been cancelled. + pub async fn is_cancelled(&self, id: &RequestId) -> bool { + !self.active_requests.read().await.contains(id) + } + /// Run the server with the given transport. pub async fn run(&self, mut transport: T) -> Result<()> { info!("Starting MCP server: {} v{}", self.name, self.version); @@ -78,19 +104,36 @@ impl McpServer { async fn handle_request(&self, req: JsonRpcRequest) -> JsonRpcResponse { debug!("Handling request: {} (id: {:?})", req.method, req.id); + // Track active request for cancellation + self.active_requests.write().await.insert(req.id.clone()); + let result = match req.method.as_str() { + // Core "initialize" => self.handle_initialize(req.params).await, + "ping" => Ok(serde_json::json!({})), + // Tools "tools/list" => self.handle_list_tools().await, "tools/call" => self.handle_call_tool(req.params).await, + // Prompts "prompts/list" => self.handle_list_prompts().await, "prompts/get" => self.handle_get_prompt(req.params).await, - "ping" => Ok(serde_json::json!({})), + // Resources + "resources/list" => self.handle_list_resources(req.params).await, + "resources/read" => self.handle_read_resource(req.params).await, + "resources/subscribe" => self.handle_subscribe_resource(req.params).await, + "resources/unsubscribe" => self.handle_unsubscribe_resource(req.params).await, + // Completions + "completion/complete" => self.handle_completion(req.params).await, + // Unknown _ => Err(Error::McpProtocol(format!( "Unknown method: {}", req.method ))), }; + // Remove from active requests + self.active_requests.write().await.remove(&req.id); + match result { Ok(value) => JsonRpcResponse { jsonrpc: JSONRPC_VERSION.to_string(), @@ -120,7 +163,24 @@ impl McpServer { info!("Client initialized"); } "notifications/cancelled" => { - debug!("Request cancelled"); + // Extract the request ID from params and cancel it + if let Some(params) = notif.params { + #[derive(serde::Deserialize)] + struct CancelledParams { + #[serde(rename = "requestId")] + request_id: RequestId, + } + if let Ok(cancel) = serde_json::from_value::(params) { + info!("Cancelling request: {:?}", cancel.request_id); + self.active_requests + .write() + .await + .remove(&cancel.request_id); + } + } + } + "notifications/roots/listChanged" => { + info!("Client roots changed"); } _ => { debug!("Unknown notification: {}", notif.method); @@ -129,12 +189,47 @@ impl McpServer { } /// Handle initialize request. - async fn handle_initialize(&self, _params: Option) -> Result { + async fn handle_initialize(&self, params: Option) -> Result { + // Extract roots from client if provided + if let Some(ref params) = params { + #[derive(serde::Deserialize)] + struct InitParams { + #[serde(default)] + roots: Vec, + } + #[derive(serde::Deserialize)] + struct RootInfo { + uri: String, + #[serde(default)] + name: Option, + } + + if let Ok(init) = serde_json::from_value::(params.clone()) { + let mut roots = self.roots.write().await; + for root in init.roots { + if let Some(path) = root.uri.strip_prefix("file://") { + roots.push(PathBuf::from(path)); + info!("Added client root: {} ({:?})", path, root.name); + } + } + } + } + + // Build capabilities based on what's configured + let resources_cap = if self.resources.is_some() { + Some(ResourcesCapability { + subscribe: true, + list_changed: true, + }) + } else { + None + }; + let result = InitializeResult { protocol_version: MCP_VERSION.to_string(), capabilities: ServerCapabilities { tools: Some(ToolsCapability { list_changed: true }), - resources: None, + resources: resources_cap, prompts: Some(PromptsCapability { list_changed: false, }), @@ -207,4 +302,189 @@ impl McpServer { Ok(serde_json::to_value(result)?) } + + /// Handle list resources request. + async fn handle_list_resources(&self, params: Option) -> Result { + let resources = self + .resources + .as_ref() + .ok_or_else(|| Error::McpProtocol("Resources not enabled".to_string()))?; + + #[derive(serde::Deserialize, Default)] + struct ListParams { + cursor: Option, + } + + let list_params: ListParams = params + .map(|v| serde_json::from_value(v).unwrap_or_default()) + .unwrap_or_default(); + + let result = resources.list(list_params.cursor.as_deref()).await?; + Ok(serde_json::to_value(result)?) + } + + /// Handle read resource request. + async fn handle_read_resource(&self, params: Option) -> Result { + let resources = self + .resources + .as_ref() + .ok_or_else(|| Error::McpProtocol("Resources not enabled".to_string()))?; + + #[derive(serde::Deserialize)] + struct ReadParams { + uri: String, + } + + let read_params: ReadParams = params + .ok_or_else(|| Error::InvalidToolArguments("Missing params".to_string())) + .and_then(|v| { + serde_json::from_value(v).map_err(|e| Error::InvalidToolArguments(e.to_string())) + })?; + + let result = resources.read(&read_params.uri).await?; + Ok(serde_json::to_value(result)?) + } + + /// Handle subscribe to resource. + async fn handle_subscribe_resource(&self, params: Option) -> Result { + let resources = self + .resources + .as_ref() + .ok_or_else(|| Error::McpProtocol("Resources not enabled".to_string()))?; + + #[derive(serde::Deserialize)] + struct SubscribeParams { + uri: String, + } + + let sub_params: SubscribeParams = params + .ok_or_else(|| Error::InvalidToolArguments("Missing params".to_string())) + .and_then(|v| { + serde_json::from_value(v).map_err(|e| Error::InvalidToolArguments(e.to_string())) + })?; + + // Use a placeholder session ID for now + resources.subscribe(&sub_params.uri, "default").await?; + Ok(serde_json::json!({})) + } + + /// Handle unsubscribe from resource. + async fn handle_unsubscribe_resource(&self, params: Option) -> Result { + let resources = self + .resources + .as_ref() + .ok_or_else(|| Error::McpProtocol("Resources not enabled".to_string()))?; + + #[derive(serde::Deserialize)] + struct UnsubscribeParams { + uri: String, + } + + let unsub_params: UnsubscribeParams = params + .ok_or_else(|| Error::InvalidToolArguments("Missing params".to_string())) + .and_then(|v| { + serde_json::from_value(v).map_err(|e| Error::InvalidToolArguments(e.to_string())) + })?; + + resources.unsubscribe(&unsub_params.uri, "default").await?; + Ok(serde_json::json!({})) + } + + /// Handle completion request. + async fn handle_completion(&self, params: Option) -> Result { + #[derive(serde::Deserialize)] + struct CompletionParams { + r#ref: CompletionRef, + argument: CompletionArgument, + } + + #[derive(serde::Deserialize)] + #[allow(dead_code)] + struct CompletionRef { + r#type: String, + #[serde(default)] + uri: Option, + #[serde(default)] + name: Option, + } + + #[derive(serde::Deserialize)] + struct CompletionArgument { + name: String, + value: String, + } + + let comp_params: CompletionParams = params + .ok_or_else(|| Error::InvalidToolArguments("Missing params".to_string())) + .and_then(|v| { + serde_json::from_value(v).map_err(|e| Error::InvalidToolArguments(e.to_string())) + })?; + + // Provide completions based on argument type + let values = match comp_params.argument.name.as_str() { + "path" | "file" | "uri" => { + // File path completion + self.complete_file_path(&comp_params.argument.value).await + } + "prompt" | "name" if comp_params.r#ref.r#type == "ref/prompt" => { + // Prompt name completion + self.prompts + .list() + .into_iter() + .filter(|p| p.name.starts_with(&comp_params.argument.value)) + .map(|p| p.name) + .collect() + } + _ => Vec::new(), + }; + + Ok(serde_json::json!({ + "completion": { + "values": values, + "hasMore": false + } + })) + } + + /// Complete file paths. + async fn complete_file_path(&self, prefix: &str) -> Vec { + let roots = self.roots.read().await; + let mut completions = Vec::new(); + + // If we have resources, use that + if let Some(ref resources) = self.resources { + if let Ok(result) = resources.list(None).await { + for resource in result.resources { + if resource.name.starts_with(prefix) { + completions.push(resource.name); + } + } + } + } + + // Also check client-provided roots + for root in roots.iter() { + let search_path = root.join(prefix); + if let Some(parent) = search_path.parent() { + if let Ok(mut entries) = tokio::fs::read_dir(parent).await { + while let Ok(Some(entry)) = entries.next_entry().await { + let name = entry.file_name().to_string_lossy().to_string(); + let full = format!( + "{}{}", + prefix + .rsplit_once('/') + .map(|(p, _)| format!("{}/", p)) + .unwrap_or_default(), + name + ); + if full.starts_with(prefix) && !completions.contains(&full) { + completions.push(full); + } + } + } + } + } + + completions.into_iter().take(20).collect() + } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 5f4d6fa..eed1e3b 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,15 +1,17 @@ //! MCP tool implementations. //! -//! This module contains all 49 MCP tools organized by category: +//! This module contains all 52 MCP tools organized by category: //! //! - `retrieval` - Codebase search and context retrieval (6 tools) //! - `index` - Index management tools (5 tools) //! - `planning` - AI-powered task planning (20 tools) //! - `memory` - Persistent memory storage (4 tools) //! - `review` - Code review tools (14 tools) +//! - `navigation` - Code navigation tools (3 tools) pub mod index; pub mod memory; +pub mod navigation; pub mod planning; pub mod retrieval; pub mod review; @@ -88,4 +90,9 @@ pub fn register_all_tools( handler.register(review::PauseReviewTool::new()); handler.register(review::ResumeReviewTool::new()); handler.register(review::GetReviewTelemetryTool::new()); + + // Navigation tools (3) + handler.register(navigation::FindReferencesTool::new(context_service.clone())); + handler.register(navigation::GoToDefinitionTool::new(context_service.clone())); + handler.register(navigation::DiffFilesTool::new(context_service)); } diff --git a/src/tools/navigation.rs b/src/tools/navigation.rs new file mode 100644 index 0000000..3c51656 --- /dev/null +++ b/src/tools/navigation.rs @@ -0,0 +1,523 @@ +//! Code navigation tools for finding references and definitions. + +use async_trait::async_trait; +use serde_json::Value; +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use tokio::fs; +use tokio::io::AsyncBufReadExt; + +use crate::error::Result; +use crate::mcp::handler::{error_result, get_string_arg, success_result, ToolHandler}; +use crate::mcp::protocol::{Tool, ToolResult}; +use crate::service::ContextService; + +/// Find all references to a symbol in the codebase. +pub struct FindReferencesTool { + service: Arc, +} + +impl FindReferencesTool { + pub fn new(service: Arc) -> Self { + Self { service } + } +} + +#[async_trait] +impl ToolHandler for FindReferencesTool { + fn definition(&self) -> Tool { + Tool { + name: "find_references".to_string(), + description: "Find all references to a symbol (function, class, variable) in the codebase. Returns file paths and line numbers where the symbol is used.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The symbol name to search for" + }, + "file_pattern": { + "type": "string", + "description": "Optional glob pattern to filter files (e.g., '*.rs', 'src/**/*.ts')" + }, + "max_results": { + "type": "integer", + "description": "Maximum number of results to return (default: 50)" + } + }, + "required": ["symbol"] + }), + } + } + + async fn execute(&self, args: HashMap) -> Result { + let symbol = get_string_arg(&args, "symbol")?; + let file_pattern = args.get("file_pattern").and_then(|v| v.as_str()); + let max_results = args + .get("max_results") + .and_then(|v| v.as_u64()) + .unwrap_or(50) as usize; + + let workspace = self.service.workspace(); + let references = find_symbol_in_files(workspace, &symbol, file_pattern, max_results).await; + + if references.is_empty() { + return Ok(success_result(format!( + "No references found for symbol: `{}`", + symbol + ))); + } + + let mut output = format!( + "# References to `{}`\n\nFound {} references:\n\n", + symbol, + references.len() + ); + + for reference in references { + output.push_str(&format!( + "- **{}:{}**: `{}`\n", + reference.file, + reference.line, + reference.context.trim() + )); + } + + Ok(success_result(output)) + } +} + +/// Go to definition - find where a symbol is defined. +pub struct GoToDefinitionTool { + service: Arc, +} + +impl GoToDefinitionTool { + pub fn new(service: Arc) -> Self { + Self { service } + } +} + +#[async_trait] +impl ToolHandler for GoToDefinitionTool { + fn definition(&self) -> Tool { + Tool { + name: "go_to_definition".to_string(), + description: "Find the definition of a symbol (function, class, struct, type). Returns the file and line where the symbol is defined.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "symbol": { + "type": "string", + "description": "The symbol name to find the definition of" + }, + "language": { + "type": "string", + "description": "Programming language hint (rust, python, typescript, etc.)" + } + }, + "required": ["symbol"] + }), + } + } + + async fn execute(&self, args: HashMap) -> Result { + let symbol = get_string_arg(&args, "symbol")?; + let language = args.get("language").and_then(|v| v.as_str()); + + let workspace = self.service.workspace(); + let definitions = find_definition(workspace, &symbol, language).await; + + if definitions.is_empty() { + return Ok(success_result(format!( + "No definition found for symbol: `{}`", + symbol + ))); + } + + let mut output = format!("# Definition of `{}`\n\n", symbol); + + for def in definitions { + output.push_str(&format!("## {}\n\n", def.file)); + output.push_str(&format!("Line {}\n\n", def.line)); + output.push_str(&format!("```{}\n{}\n```\n\n", def.language, def.context)); + } + + Ok(success_result(output)) + } +} + +/// Diff two files or show changes. +pub struct DiffFilesTool { + service: Arc, +} + +impl DiffFilesTool { + pub fn new(service: Arc) -> Self { + Self { service } + } +} + +#[async_trait] +impl ToolHandler for DiffFilesTool { + fn definition(&self) -> Tool { + Tool { + name: "diff_files".to_string(), + description: + "Compare two files and show the differences. Returns a unified diff format." + .to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "file1": { + "type": "string", + "description": "Path to the first file" + }, + "file2": { + "type": "string", + "description": "Path to the second file" + }, + "context_lines": { + "type": "integer", + "description": "Number of context lines around changes (default: 3)" + } + }, + "required": ["file1", "file2"] + }), + } + } + + async fn execute(&self, args: HashMap) -> Result { + let file1 = get_string_arg(&args, "file1")?; + let file2 = get_string_arg(&args, "file2")?; + let context = args + .get("context_lines") + .and_then(|v| v.as_u64()) + .unwrap_or(3) as usize; + + let workspace = self.service.workspace(); + let path1 = workspace.join(&file1); + let path2 = workspace.join(&file2); + + let content1 = match fs::read_to_string(&path1).await { + Ok(c) => c, + Err(e) => return Ok(error_result(format!("Cannot read {}: {}", file1, e))), + }; + + let content2 = match fs::read_to_string(&path2).await { + Ok(c) => c, + Err(e) => return Ok(error_result(format!("Cannot read {}: {}", file2, e))), + }; + + let diff = generate_diff(&file1, &file2, &content1, &content2, context); + + if diff.is_empty() { + Ok(success_result("Files are identical.".to_string())) + } else { + Ok(success_result(format!("```diff\n{}\n```", diff))) + } + } +} + +// ===== Helper types and functions ===== + +struct Reference { + file: String, + line: usize, + context: String, +} + +struct Definition { + file: String, + line: usize, + context: String, + language: String, +} + +/// Find symbol references in files. +async fn find_symbol_in_files( + workspace: &Path, + symbol: &str, + file_pattern: Option<&str>, + max_results: usize, +) -> Vec { + let mut references = Vec::new(); + let mut stack = vec![workspace.to_path_buf()]; + + while let Some(dir) = stack.pop() { + if references.len() >= max_results { + break; + } + + let mut entries = match fs::read_dir(&dir).await { + Ok(e) => e, + Err(_) => continue, + }; + + while let Ok(Some(entry)) = entries.next_entry().await { + if references.len() >= max_results { + break; + } + + let path = entry.path(); + let name = path.file_name().unwrap_or_default().to_string_lossy(); + + // Skip hidden and common ignore patterns + if name.starts_with('.') + || matches!(name.as_ref(), "node_modules" | "target" | "dist" | "build") + { + continue; + } + + if path.is_dir() { + stack.push(path); + } else if path.is_file() { + // Check file pattern if provided + if let Some(pattern) = file_pattern { + if !matches_pattern(&name, pattern) { + continue; + } + } + + // Search file for symbol + if let Ok(file) = fs::File::open(&path).await { + let reader = tokio::io::BufReader::new(file); + let mut lines = reader.lines(); + let mut line_num = 0; + + while let Ok(Some(line)) = lines.next_line().await { + line_num += 1; + if line.contains(symbol) { + let rel_path = path + .strip_prefix(workspace) + .unwrap_or(&path) + .to_string_lossy() + .to_string(); + + references.push(Reference { + file: rel_path, + line: line_num, + context: line, + }); + + if references.len() >= max_results { + break; + } + } + } + } + } + } + } + + references +} + +/// Find symbol definition. +async fn find_definition( + workspace: &Path, + symbol: &str, + language: Option<&str>, +) -> Vec { + let mut definitions = Vec::new(); + + // Build definition patterns based on language + let patterns = get_definition_patterns(symbol, language); + + let mut stack = vec![workspace.to_path_buf()]; + + while let Some(dir) = stack.pop() { + let mut entries = match fs::read_dir(&dir).await { + Ok(e) => e, + Err(_) => continue, + }; + + while let Ok(Some(entry)) = entries.next_entry().await { + let path = entry.path(); + let name = path.file_name().unwrap_or_default().to_string_lossy(); + + if name.starts_with('.') + || matches!(name.as_ref(), "node_modules" | "target" | "dist" | "build") + { + continue; + } + + if path.is_dir() { + stack.push(path); + } else if path.is_file() { + let ext = path.extension().and_then(|e| e.to_str()).unwrap_or(""); + let file_lang = get_language(ext); + + // Skip if language hint provided and doesn't match + if let Some(lang) = language { + if !file_lang.contains(lang) && lang != file_lang { + continue; + } + } + + if let Ok(content) = fs::read_to_string(&path).await { + for (line_num, line) in content.lines().enumerate() { + for pattern in &patterns { + if line.contains(pattern) { + let rel_path = path + .strip_prefix(workspace) + .unwrap_or(&path) + .to_string_lossy() + .to_string(); + + // Get a few lines of context + let start = line_num.saturating_sub(1); + let context: String = content + .lines() + .skip(start) + .take(5) + .collect::>() + .join("\n"); + + definitions.push(Definition { + file: rel_path, + line: line_num + 1, + context, + language: file_lang.to_string(), + }); + } + } + } + } + } + } + } + + definitions +} + +/// Get definition patterns for a symbol. +fn get_definition_patterns(symbol: &str, language: Option<&str>) -> Vec { + let mut patterns = Vec::new(); + + match language { + Some("rust" | "rs") => { + patterns.push(format!("fn {}(", symbol)); + patterns.push(format!("struct {} ", symbol)); + patterns.push(format!("struct {}", symbol)); + patterns.push(format!("enum {} ", symbol)); + patterns.push(format!("trait {} ", symbol)); + patterns.push(format!("type {} ", symbol)); + patterns.push(format!("const {}", symbol)); + patterns.push(format!("static {}", symbol)); + } + Some("python" | "py") => { + patterns.push(format!("def {}(", symbol)); + patterns.push(format!("class {}:", symbol)); + patterns.push(format!("class {}(", symbol)); + } + Some("typescript" | "javascript" | "ts" | "js") => { + patterns.push(format!("function {}(", symbol)); + patterns.push(format!("const {} =", symbol)); + patterns.push(format!("let {} =", symbol)); + patterns.push(format!("class {} ", symbol)); + patterns.push(format!("interface {} ", symbol)); + patterns.push(format!("type {} =", symbol)); + } + _ => { + // Generic patterns + patterns.push(format!("fn {}(", symbol)); + patterns.push(format!("function {}(", symbol)); + patterns.push(format!("def {}(", symbol)); + patterns.push(format!("class {} ", symbol)); + patterns.push(format!("struct {} ", symbol)); + patterns.push(format!("interface {} ", symbol)); + } + } + + patterns +} + +/// Get language from file extension. +fn get_language(ext: &str) -> &'static str { + match ext { + "rs" => "rust", + "py" => "python", + "ts" | "tsx" => "typescript", + "js" | "jsx" => "javascript", + "go" => "go", + "java" => "java", + "rb" => "ruby", + "c" | "h" => "c", + "cpp" | "hpp" | "cc" => "cpp", + _ => "text", + } +} + +/// Simple pattern matching. +fn matches_pattern(name: &str, pattern: &str) -> bool { + if let Some(ext) = pattern.strip_prefix("*.") { + name.ends_with(&format!(".{}", ext)) + } else { + name.contains(pattern) + } +} + +/// Generate a simple unified diff. +fn generate_diff( + name1: &str, + name2: &str, + content1: &str, + content2: &str, + context: usize, +) -> String { + let lines1: Vec<&str> = content1.lines().collect(); + let lines2: Vec<&str> = content2.lines().collect(); + + if lines1 == lines2 { + return String::new(); + } + + let mut output = format!("--- {}\n+++ {}\n", name1, name2); + + // Simple line-by-line comparison + let max_len = lines1.len().max(lines2.len()); + let mut i = 0; + + while i < max_len { + let l1 = lines1.get(i).copied(); + let l2 = lines2.get(i).copied(); + + if l1 != l2 { + // Found a difference - output hunk + let start = i.saturating_sub(context); + let end = (i + context + 1).min(max_len); + + output.push_str(&format!( + "@@ -{},{} +{},{} @@\n", + start + 1, + end - start, + start + 1, + end - start + )); + + for j in start..end { + let l1 = lines1.get(j).copied().unwrap_or(""); + let l2 = lines2.get(j).copied().unwrap_or(""); + + if l1 == l2 { + output.push_str(&format!(" {}\n", l1)); + } else { + if j < lines1.len() { + output.push_str(&format!("-{}\n", l1)); + } + if j < lines2.len() { + output.push_str(&format!("+{}\n", l2)); + } + } + } + + i = end; + } else { + i += 1; + } + } + + output +} From 49b2d031657a78d80494b9f04507dfaf5b9d7f99 Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 03:22:26 +0000 Subject: [PATCH 04/37] feat: add automatic version bumping to release workflow - Add bump_type option (patch/minor/major) for workflow_dispatch - Auto-calculate next version when version input is empty - Add bump-version job to update Cargo.toml - Ensure build uses latest code after version bump - Add duplicate tag protection --- .github/workflows/release.yml | 125 +++++++++++++++++++++++++++++----- 1 file changed, 108 insertions(+), 17 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f9a5b82..8926417 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,9 +5,18 @@ on: workflow_dispatch: inputs: version: - description: 'Release version (e.g., 2.0.1)' - required: true + description: 'Release version (e.g., 2.0.2). Leave empty to auto-bump patch version.' + required: false type: string + bump_type: + description: 'Version bump type (only used if version is empty)' + required: false + type: choice + options: + - patch + - minor + - major + default: patch prerelease: description: 'Is this a pre-release?' required: false @@ -35,6 +44,7 @@ jobs: outputs: should_release: ${{ steps.check.outputs.should_release }} version: ${{ steps.check.outputs.version }} + needs_bump: ${{ steps.check.outputs.needs_bump }} steps: - uses: actions/checkout@v4 with: @@ -43,10 +53,40 @@ jobs: - name: Check release conditions id: check run: | - # For workflow_dispatch, always release with provided version + # Get current version from Cargo.toml + CURRENT_VERSION=$(grep '^version = ' Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/') + echo "Current version: $CURRENT_VERSION" + + # For workflow_dispatch if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then - echo "should_release=true" >> $GITHUB_OUTPUT - echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT + if [ -n "${{ inputs.version }}" ]; then + # Explicit version provided + echo "should_release=true" >> $GITHUB_OUTPUT + echo "version=${{ inputs.version }}" >> $GITHUB_OUTPUT + if [ "${{ inputs.version }}" != "$CURRENT_VERSION" ]; then + echo "needs_bump=true" >> $GITHUB_OUTPUT + else + echo "needs_bump=false" >> $GITHUB_OUTPUT + fi + else + # Auto-bump version based on bump_type + IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT_VERSION" + case "${{ inputs.bump_type }}" in + major) + NEW_VERSION="$((MAJOR + 1)).0.0" + ;; + minor) + NEW_VERSION="${MAJOR}.$((MINOR + 1)).0" + ;; + patch|*) + NEW_VERSION="${MAJOR}.${MINOR}.$((PATCH + 1))" + ;; + esac + echo "Auto-bumped to: $NEW_VERSION" + echo "should_release=true" >> $GITHUB_OUTPUT + echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT + echo "needs_bump=true" >> $GITHUB_OUTPUT + fi exit 0 fi @@ -54,6 +94,7 @@ jobs: if [ "${{ github.event_name }}" == "push" ] && [[ "${{ github.ref }}" == refs/tags/v* ]]; then echo "should_release=true" >> $GITHUB_OUTPUT echo "version=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT + echo "needs_bump=false" >> $GITHUB_OUTPUT exit 0 fi @@ -65,17 +106,15 @@ jobs: exit 0 fi - # Get version from Cargo.toml - VERSION=$(grep '^version = ' Cargo.toml | head -1 | sed 's/version = "\(.*\)"/\1/') - # Check if this version tag already exists - if git tag -l "v$VERSION" | grep -q .; then - echo "Tag v$VERSION already exists, skipping release" + if git tag -l "v$CURRENT_VERSION" | grep -q .; then + echo "Tag v$CURRENT_VERSION already exists, skipping release" echo "should_release=false" >> $GITHUB_OUTPUT else - echo "New version v$VERSION detected, will release" + echo "New version v$CURRENT_VERSION detected, will release" echo "should_release=true" >> $GITHUB_OUTPUT - echo "version=$VERSION" >> $GITHUB_OUTPUT + echo "version=$CURRENT_VERSION" >> $GITHUB_OUTPUT + echo "needs_bump=false" >> $GITHUB_OUTPUT fi exit 0 fi @@ -83,9 +122,41 @@ jobs: echo "Unknown trigger, skipping release" echo "should_release=false" >> $GITHUB_OUTPUT - build: + # Bump version in Cargo.toml if needed + bump-version: needs: check - if: needs.check.outputs.should_release == 'true' + if: needs.check.outputs.should_release == 'true' && needs.check.outputs.needs_bump == 'true' + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Update Cargo.toml version + run: | + VERSION="${{ needs.check.outputs.version }}" + echo "Updating Cargo.toml to version $VERSION" + sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml + cat Cargo.toml | head -5 + + - name: Commit version bump + run: | + VERSION="${{ needs.check.outputs.version }}" + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add Cargo.toml + git commit -m "chore: bump version to $VERSION" + git push + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + build: + needs: [check, bump-version] + # Run if should_release is true, and either no bump needed OR bump-version succeeded + if: | + always() && + needs.check.outputs.should_release == 'true' && + (needs.check.outputs.needs_bump != 'true' || needs.bump-version.result == 'success') strategy: matrix: include: @@ -102,6 +173,13 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + # Fetch latest to get version bump commit + fetch-depth: 0 + + - name: Pull latest changes + run: git pull origin ${{ github.ref_name }} || true - name: Install Rust uses: dtolnay/rust-toolchain@stable @@ -122,13 +200,20 @@ jobs: path: ${{ matrix.artifact }} release: - needs: [check, build] + needs: [check, bump-version, build] + if: | + always() && + needs.check.outputs.should_release == 'true' && + needs.build.result == 'success' runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 with: fetch-depth: 0 + - name: Pull latest changes + run: git pull origin ${{ github.ref_name }} || true + - name: Download all artifacts uses: actions/download-artifact@v4 with: @@ -140,8 +225,14 @@ jobs: VERSION="${{ needs.check.outputs.version }}" git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - git tag -a "v$VERSION" -m "Release v$VERSION" - git push origin "v$VERSION" + + # Check if tag already exists + if git tag -l "v$VERSION" | grep -q .; then + echo "Tag v$VERSION already exists, skipping" + else + git tag -a "v$VERSION" -m "Release v$VERSION" + git push origin "v$VERSION" + fi env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} From b962a010f0ea8994e0d1cc9ba80ae7252193cc28 Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 03:28:36 +0000 Subject: [PATCH 05/37] feat: add comprehensive tests and logging/setLevel handler Tests: - Add 23 new tests for resources.rs (MIME types, serialization) - Add 8 new tests for progress.rs (reporter, manager, tokens) - Add 13 new tests for navigation.rs (patterns, diff, structs) - Add 3 new tests for server.rs (LogLevel) MCP Features: - Add logging/setLevel handler per MCP spec - Add LogLevel enum with proper level handling - Total tests: 137 (was 111) --- src/mcp/progress.rs | 93 +++++++++++++++++++++++++++++++ src/mcp/resources.rs | 85 +++++++++++++++++++++++++++++ src/mcp/server.rs | 113 ++++++++++++++++++++++++++++++++++++++ src/tools/navigation.rs | 118 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 409 insertions(+) diff --git a/src/mcp/progress.rs b/src/mcp/progress.rs index f5f8454..7d5b364 100644 --- a/src/mcp/progress.rs +++ b/src/mcp/progress.rs @@ -170,4 +170,97 @@ mod tests { assert_eq!(notification.params.total, Some(100)); assert_eq!(notification.params.message, Some("Halfway".to_string())); } + + #[tokio::test] + async fn test_progress_reporter_percent() { + let (tx, mut rx) = mpsc::channel(10); + let reporter = ProgressReporter::new(ProgressToken::Number(1), tx, Some(200)); + + reporter.report_percent(50, Some("Half done")).await; + + let notification = rx.recv().await.unwrap(); + assert_eq!(notification.params.progress, 100); // 50% of 200 + assert_eq!(notification.params.total, Some(200)); + } + + #[tokio::test] + async fn test_progress_reporter_complete() { + let (tx, mut rx) = mpsc::channel(10); + let reporter = ProgressReporter::new(ProgressToken::Number(2), tx, Some(100)); + + reporter.complete(Some("Done!")).await; + + let notification = rx.recv().await.unwrap(); + assert_eq!(notification.params.progress, 100); + assert_eq!(notification.params.message, Some("Done!".to_string())); + } + + #[test] + fn test_progress_token_serialization() { + let token_str = ProgressToken::String("test-token".to_string()); + let token_num = ProgressToken::Number(42); + + let json_str = serde_json::to_string(&token_str).unwrap(); + let json_num = serde_json::to_string(&token_num).unwrap(); + + assert_eq!(json_str, "\"test-token\""); + assert_eq!(json_num, "42"); + + let parsed_str: ProgressToken = serde_json::from_str(&json_str).unwrap(); + let parsed_num: ProgressToken = serde_json::from_str(&json_num).unwrap(); + + assert_eq!(parsed_str, token_str); + assert_eq!(parsed_num, token_num); + } + + #[test] + fn test_progress_notification_structure() { + let notification = ProgressNotification::new( + ProgressToken::String("op-1".to_string()), + 25, + Some(100), + Some("Processing...".to_string()), + ); + + assert_eq!(notification.jsonrpc, "2.0"); + assert_eq!(notification.method, "notifications/progress"); + assert_eq!(notification.params.progress, 25); + assert_eq!(notification.params.total, Some(100)); + } + + #[test] + fn test_progress_manager_create_reporter() { + let manager = ProgressManager::new(); + + let reporter1 = manager.create_reporter(Some(100)); + let reporter2 = manager.create_reporter(Some(200)); + + // Reporters should have different tokens + assert_ne!(reporter1.token, reporter2.token); + } + + #[test] + fn test_progress_manager_with_custom_token() { + let manager = ProgressManager::new(); + let custom_token = ProgressToken::String("custom".to_string()); + + let reporter = manager.create_reporter_with_token(custom_token.clone(), Some(50)); + assert_eq!(reporter.token, custom_token); + } + + #[test] + fn test_progress_params_serialization() { + let params = ProgressParams { + progress_token: ProgressToken::Number(1), + progress: 50, + total: Some(100), + message: Some("Working...".to_string()), + }; + + let json = serde_json::to_string(¶ms).unwrap(); + assert!(json.contains("\"progressToken\":1")); + assert!(json.contains("\"progress\":50")); + assert!(json.contains("\"total\":100")); + assert!(json.contains("\"message\":\"Working...\"")); + } } diff --git a/src/mcp/resources.rs b/src/mcp/resources.rs index f4a5438..5583f37 100644 --- a/src/mcp/resources.rs +++ b/src/mcp/resources.rs @@ -283,5 +283,90 @@ mod tests { ResourceRegistry::guess_mime_type(std::path::Path::new("test.py")), Some("text/x-python".to_string()) ); + assert_eq!( + ResourceRegistry::guess_mime_type(std::path::Path::new("test.ts")), + Some("text/typescript".to_string()) + ); + assert_eq!( + ResourceRegistry::guess_mime_type(std::path::Path::new("test.json")), + Some("application/json".to_string()) + ); + assert_eq!( + ResourceRegistry::guess_mime_type(std::path::Path::new("test.unknown")), + Some("text/plain".to_string()) + ); + } + + #[test] + fn test_resource_serialization() { + let resource = Resource { + uri: "file:///test/file.rs".to_string(), + name: "file.rs".to_string(), + description: Some("A test file".to_string()), + mime_type: Some("text/x-rust".to_string()), + }; + + let json = serde_json::to_string(&resource).unwrap(); + assert!(json.contains("\"uri\":\"file:///test/file.rs\"")); + assert!(json.contains("\"name\":\"file.rs\"")); + assert!(json.contains("\"mimeType\":\"text/x-rust\"")); + + let parsed: Resource = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.uri, resource.uri); + assert_eq!(parsed.name, resource.name); + } + + #[test] + fn test_resource_contents_serialization() { + let contents = ResourceContents { + uri: "file:///test/file.rs".to_string(), + mime_type: Some("text/x-rust".to_string()), + text: Some("fn main() {}".to_string()), + blob: None, + }; + + let json = serde_json::to_string(&contents).unwrap(); + assert!(json.contains("\"text\":\"fn main() {}\"")); + assert!(!json.contains("\"blob\"")); // blob should be skipped when None + } + + #[test] + fn test_list_resources_result_serialization() { + let result = ListResourcesResult { + resources: vec![Resource { + uri: "file:///test.rs".to_string(), + name: "test.rs".to_string(), + description: None, + mime_type: None, + }], + next_cursor: Some("cursor123".to_string()), + }; + + let json = serde_json::to_string(&result).unwrap(); + assert!(json.contains("\"nextCursor\":\"cursor123\"")); + + let result_no_cursor = ListResourcesResult { + resources: vec![], + next_cursor: None, + }; + let json2 = serde_json::to_string(&result_no_cursor).unwrap(); + assert!(!json2.contains("nextCursor")); + } + + #[test] + fn test_read_resource_result_serialization() { + let result = ReadResourceResult { + contents: vec![ResourceContents { + uri: "file:///test.rs".to_string(), + mime_type: Some("text/x-rust".to_string()), + text: Some("code".to_string()), + blob: None, + }], + }; + + let json = serde_json::to_string(&result).unwrap(); + let parsed: ReadResourceResult = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.contents.len(), 1); + assert_eq!(parsed.contents[0].text, Some("code".to_string())); } } diff --git a/src/mcp/server.rs b/src/mcp/server.rs index aef30d9..68af7a8 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -16,6 +16,49 @@ use crate::mcp::transport::{Message, Transport}; use crate::service::ContextService; use crate::VERSION; +/// Log level for the MCP server. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum LogLevel { + Debug, + #[default] + Info, + Notice, + Warning, + Error, + Critical, + Alert, + Emergency, +} + +impl LogLevel { + fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "debug" => Self::Debug, + "info" => Self::Info, + "notice" => Self::Notice, + "warning" | "warn" => Self::Warning, + "error" => Self::Error, + "critical" => Self::Critical, + "alert" => Self::Alert, + "emergency" => Self::Emergency, + _ => Self::Info, + } + } + + fn as_str(&self) -> &'static str { + match self { + Self::Debug => "debug", + Self::Info => "info", + Self::Notice => "notice", + Self::Warning => "warning", + Self::Error => "error", + Self::Critical => "critical", + Self::Alert => "alert", + Self::Emergency => "emergency", + } + } +} + /// MCP server. pub struct McpServer { handler: Arc, @@ -27,6 +70,8 @@ pub struct McpServer { roots: Arc>>, /// Active request IDs for cancellation support. active_requests: Arc>>, + /// Current log level. + log_level: Arc>, } impl McpServer { @@ -40,6 +85,7 @@ impl McpServer { version: VERSION.to_string(), roots: Arc::new(RwLock::new(Vec::new())), active_requests: Arc::new(RwLock::new(HashSet::new())), + log_level: Arc::new(RwLock::new(LogLevel::default())), } } @@ -58,9 +104,20 @@ impl McpServer { version: VERSION.to_string(), roots: Arc::new(RwLock::new(Vec::new())), active_requests: Arc::new(RwLock::new(HashSet::new())), + log_level: Arc::new(RwLock::new(LogLevel::default())), } } + /// Get the current log level. + pub async fn log_level(&self) -> LogLevel { + *self.log_level.read().await + } + + /// Set the log level. + pub async fn set_log_level(&self, level: LogLevel) { + *self.log_level.write().await = level; + } + /// Get the client-provided workspace roots. pub async fn roots(&self) -> Vec { self.roots.read().await.clone() @@ -124,6 +181,8 @@ impl McpServer { "resources/unsubscribe" => self.handle_unsubscribe_resource(req.params).await, // Completions "completion/complete" => self.handle_completion(req.params).await, + // Logging + "logging/setLevel" => self.handle_set_log_level(req.params).await, // Unknown _ => Err(Error::McpProtocol(format!( "Unknown method: {}", @@ -487,4 +546,58 @@ impl McpServer { completions.into_iter().take(20).collect() } + + /// Handle logging/setLevel request. + async fn handle_set_log_level(&self, params: Option) -> Result { + #[derive(serde::Deserialize)] + struct SetLevelParams { + level: String, + } + + let level_str = if let Some(params) = params { + let p: SetLevelParams = serde_json::from_value(params)?; + p.level + } else { + return Err(Error::McpProtocol( + "Missing level parameter".to_string(), + )); + }; + + let level = LogLevel::from_str(&level_str); + self.set_log_level(level).await; + + info!("Log level set to: {}", level.as_str()); + Ok(serde_json::json!({})) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_log_level_from_str() { + assert_eq!(LogLevel::from_str("debug"), LogLevel::Debug); + assert_eq!(LogLevel::from_str("DEBUG"), LogLevel::Debug); + assert_eq!(LogLevel::from_str("info"), LogLevel::Info); + assert_eq!(LogLevel::from_str("warning"), LogLevel::Warning); + assert_eq!(LogLevel::from_str("warn"), LogLevel::Warning); + assert_eq!(LogLevel::from_str("error"), LogLevel::Error); + assert_eq!(LogLevel::from_str("critical"), LogLevel::Critical); + assert_eq!(LogLevel::from_str("unknown"), LogLevel::Info); // Default + } + + #[test] + fn test_log_level_as_str() { + assert_eq!(LogLevel::Debug.as_str(), "debug"); + assert_eq!(LogLevel::Info.as_str(), "info"); + assert_eq!(LogLevel::Warning.as_str(), "warning"); + assert_eq!(LogLevel::Error.as_str(), "error"); + assert_eq!(LogLevel::Emergency.as_str(), "emergency"); + } + + #[test] + fn test_log_level_default() { + assert_eq!(LogLevel::default(), LogLevel::Info); + } } diff --git a/src/tools/navigation.rs b/src/tools/navigation.rs index 3c51656..dacf215 100644 --- a/src/tools/navigation.rs +++ b/src/tools/navigation.rs @@ -521,3 +521,121 @@ fn generate_diff( output } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_matches_pattern_extension() { + assert!(matches_pattern("file.rs", "*.rs")); + assert!(matches_pattern("test.py", "*.py")); + assert!(!matches_pattern("file.rs", "*.py")); + assert!(!matches_pattern("file.txt", "*.rs")); + } + + #[test] + fn test_matches_pattern_contains() { + assert!(matches_pattern("test_file.rs", "test")); + assert!(matches_pattern("my_test.py", "test")); + assert!(!matches_pattern("file.rs", "test")); + } + + #[test] + fn test_get_language() { + assert_eq!(get_language("rs"), "rust"); + assert_eq!(get_language("py"), "python"); + assert_eq!(get_language("ts"), "typescript"); + assert_eq!(get_language("tsx"), "typescript"); + assert_eq!(get_language("js"), "javascript"); + assert_eq!(get_language("go"), "go"); + assert_eq!(get_language("unknown"), "text"); + } + + #[test] + fn test_get_definition_patterns_rust() { + let patterns = get_definition_patterns("MyStruct", Some("rust")); + assert!(patterns.contains(&"struct MyStruct ".to_string())); + assert!(patterns.contains(&"fn MyStruct(".to_string())); + assert!(patterns.contains(&"enum MyStruct ".to_string())); + } + + #[test] + fn test_get_definition_patterns_python() { + let patterns = get_definition_patterns("my_func", Some("python")); + assert!(patterns.contains(&"def my_func(".to_string())); + assert!(patterns.contains(&"class my_func:".to_string())); + } + + #[test] + fn test_get_definition_patterns_typescript() { + let patterns = get_definition_patterns("MyClass", Some("typescript")); + assert!(patterns.contains(&"class MyClass ".to_string())); + assert!(patterns.contains(&"interface MyClass ".to_string())); + assert!(patterns.contains(&"function MyClass(".to_string())); + } + + #[test] + fn test_get_definition_patterns_generic() { + let patterns = get_definition_patterns("Symbol", None); + assert!(!patterns.is_empty()); + // Should have generic patterns for multiple languages + assert!(patterns.contains(&"fn Symbol(".to_string())); + assert!(patterns.contains(&"def Symbol(".to_string())); + assert!(patterns.contains(&"class Symbol ".to_string())); + } + + #[test] + fn test_generate_diff_identical() { + let content = "line1\nline2\nline3"; + let diff = generate_diff("a.txt", "b.txt", content, content, 3); + assert!(diff.is_empty()); + } + + #[test] + fn test_generate_diff_different() { + let content1 = "line1\nline2\nline3"; + let content2 = "line1\nmodified\nline3"; + let diff = generate_diff("a.txt", "b.txt", content1, content2, 1); + + assert!(diff.contains("--- a.txt")); + assert!(diff.contains("+++ b.txt")); + assert!(diff.contains("-line2")); + assert!(diff.contains("+modified")); + } + + #[test] + fn test_generate_diff_with_context() { + let content1 = "a\nb\nc\nd\ne"; + let content2 = "a\nb\nX\nd\ne"; + let diff = generate_diff("f1", "f2", content1, content2, 1); + + // Should include context lines around the change + assert!(diff.contains("@@")); + } + + #[test] + fn test_reference_struct() { + let reference = Reference { + file: "src/main.rs".to_string(), + line: 42, + context: "fn main() {}".to_string(), + }; + + assert_eq!(reference.file, "src/main.rs"); + assert_eq!(reference.line, 42); + } + + #[test] + fn test_definition_struct() { + let definition = Definition { + file: "src/lib.rs".to_string(), + line: 10, + context: "pub struct MyStruct {}".to_string(), + language: "rust".to_string(), + }; + + assert_eq!(definition.file, "src/lib.rs"); + assert_eq!(definition.language, "rust"); + } +} From 5daf33d87fc560a211d0bc5e399b65f0b3cf6efd Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 03:35:51 +0000 Subject: [PATCH 06/37] feat: add workspace analysis tools with comprehensive tests New workspace tools (3): - workspace_stats: Get file counts, lines of code, language breakdown - git_status: Get staged, unstaged, untracked files with optional diff - extract_symbols: Extract functions, classes, structs from source files Features: - Symbol detection for Rust, Python, TypeScript/JavaScript, Go - Recursive directory traversal with hidden file filtering - Skips common non-code directories (node_modules, target, etc.) Tests added (15): - Extension to language mapping - Symbol detection for each language - Symbol extraction from content - Name extraction helper Total tools: 55 (was 52) Total tests: 152 (was 137) --- src/mcp/server.rs | 4 +- src/tools/mod.rs | 11 +- src/tools/workspace.rs | 723 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 733 insertions(+), 5 deletions(-) create mode 100644 src/tools/workspace.rs diff --git a/src/mcp/server.rs b/src/mcp/server.rs index 68af7a8..4bdd32b 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -558,9 +558,7 @@ impl McpServer { let p: SetLevelParams = serde_json::from_value(params)?; p.level } else { - return Err(Error::McpProtocol( - "Missing level parameter".to_string(), - )); + return Err(Error::McpProtocol("Missing level parameter".to_string())); }; let level = LogLevel::from_str(&level_str); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index eed1e3b..20a6d74 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,6 +1,6 @@ //! MCP tool implementations. //! -//! This module contains all 52 MCP tools organized by category: +//! This module contains all 55 MCP tools organized by category: //! //! - `retrieval` - Codebase search and context retrieval (6 tools) //! - `index` - Index management tools (5 tools) @@ -8,6 +8,7 @@ //! - `memory` - Persistent memory storage (4 tools) //! - `review` - Code review tools (14 tools) //! - `navigation` - Code navigation tools (3 tools) +//! - `workspace` - Workspace analysis tools (3 tools) pub mod index; pub mod memory; @@ -15,6 +16,7 @@ pub mod navigation; pub mod planning; pub mod retrieval; pub mod review; +pub mod workspace; use std::sync::Arc; @@ -94,5 +96,10 @@ pub fn register_all_tools( // Navigation tools (3) handler.register(navigation::FindReferencesTool::new(context_service.clone())); handler.register(navigation::GoToDefinitionTool::new(context_service.clone())); - handler.register(navigation::DiffFilesTool::new(context_service)); + handler.register(navigation::DiffFilesTool::new(context_service.clone())); + + // Workspace tools (3) + handler.register(workspace::WorkspaceStatsTool::new(context_service.clone())); + handler.register(workspace::GitStatusTool::new(context_service.clone())); + handler.register(workspace::ExtractSymbolsTool::new(context_service)); } diff --git a/src/tools/workspace.rs b/src/tools/workspace.rs new file mode 100644 index 0000000..2b84fa5 --- /dev/null +++ b/src/tools/workspace.rs @@ -0,0 +1,723 @@ +//! Workspace analysis and statistics tools. + +use async_trait::async_trait; +use serde_json::Value; +use std::collections::HashMap; +use std::path::Path; +use std::sync::Arc; +use tokio::fs; +use tokio::process::Command; + +use crate::error::Result; +use crate::mcp::handler::{error_result, get_string_arg, success_result, ToolHandler}; +use crate::mcp::protocol::{Tool, ToolResult}; +use crate::service::ContextService; + +/// Get workspace statistics (file counts, language breakdown, etc.). +pub struct WorkspaceStatsTool { + service: Arc, +} + +impl WorkspaceStatsTool { + pub fn new(service: Arc) -> Self { + Self { service } + } +} + +#[async_trait] +impl ToolHandler for WorkspaceStatsTool { + fn definition(&self) -> Tool { + Tool { + name: "workspace_stats".to_string(), + description: "Get workspace statistics including file counts by language, total lines of code, and directory structure overview.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "include_hidden": { + "type": "boolean", + "description": "Include hidden files/directories (default: false)" + } + }, + "required": [] + }), + } + } + + async fn execute(&self, args: HashMap) -> Result { + let include_hidden = args + .get("include_hidden") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let workspace = self.service.workspace_path(); + match collect_workspace_stats(workspace, include_hidden).await { + Ok(stats) => Ok(success_result( + serde_json::to_string_pretty(&stats).unwrap(), + )), + Err(e) => Ok(error_result(format!("Failed to collect stats: {}", e))), + } + } +} + +#[derive(serde::Serialize)] +struct WorkspaceStats { + total_files: usize, + total_lines: usize, + languages: HashMap, + directories: usize, +} + +#[derive(serde::Serialize, Default)] +struct LanguageStats { + files: usize, + lines: usize, +} + +async fn collect_workspace_stats(root: &Path, include_hidden: bool) -> Result { + let mut stats = WorkspaceStats { + total_files: 0, + total_lines: 0, + languages: HashMap::new(), + directories: 0, + }; + + collect_stats_recursive(root, &mut stats, include_hidden).await; + Ok(stats) +} + +fn collect_stats_recursive<'a>( + path: &'a Path, + stats: &'a mut WorkspaceStats, + include_hidden: bool, +) -> std::pin::Pin + Send + 'a>> { + Box::pin(async move { + let mut entries = match fs::read_dir(path).await { + Ok(e) => e, + Err(_) => return, + }; + + while let Ok(Some(entry)) = entries.next_entry().await { + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + + // Skip hidden files/dirs unless requested + if !include_hidden && name_str.starts_with('.') { + continue; + } + + // Skip common non-code directories + if matches!( + name_str.as_ref(), + "node_modules" | "target" | "dist" | "build" | ".git" | "__pycache__" | "venv" + ) { + continue; + } + + let file_type = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + let entry_path = entry.path(); + + if file_type.is_dir() { + stats.directories += 1; + collect_stats_recursive(&entry_path, stats, include_hidden).await; + } else if file_type.is_file() { + if let Some(ext) = entry_path.extension() { + let ext_str = ext.to_string_lossy().to_lowercase(); + let lang = extension_to_language(&ext_str); + + if lang != "binary" { + stats.total_files += 1; + let lines = count_lines(&entry_path).await.unwrap_or(0); + stats.total_lines += lines; + + let lang_stats = stats.languages.entry(lang.to_string()).or_default(); + lang_stats.files += 1; + lang_stats.lines += lines; + } + } + } + } + }) +} + +async fn count_lines(path: &Path) -> Result { + let content = fs::read_to_string(path).await?; + Ok(content.lines().count()) +} + +fn extension_to_language(ext: &str) -> &'static str { + match ext { + "rs" => "rust", + "py" => "python", + "js" => "javascript", + "ts" => "typescript", + "tsx" | "jsx" => "react", + "go" => "go", + "java" => "java", + "rb" => "ruby", + "c" | "h" => "c", + "cpp" | "cc" | "hpp" => "cpp", + "cs" => "csharp", + "swift" => "swift", + "kt" => "kotlin", + "scala" => "scala", + "php" => "php", + "sh" | "bash" => "shell", + "sql" => "sql", + "html" => "html", + "css" | "scss" | "sass" => "css", + "json" => "json", + "yaml" | "yml" => "yaml", + "toml" => "toml", + "md" | "markdown" => "markdown", + "xml" => "xml", + _ => "binary", + } +} + +/// Get git status for the workspace. +pub struct GitStatusTool { + service: Arc, +} + +impl GitStatusTool { + pub fn new(service: Arc) -> Self { + Self { service } + } +} + +#[async_trait] +impl ToolHandler for GitStatusTool { + fn definition(&self) -> Tool { + Tool { + name: "git_status".to_string(), + description: "Get the current git status of the workspace including staged, unstaged, and untracked files.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "include_diff": { + "type": "boolean", + "description": "Include diff of changes (default: false)" + } + }, + "required": [] + }), + } + } + + async fn execute(&self, args: HashMap) -> Result { + let include_diff = args + .get("include_diff") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let workspace = self.service.workspace_path(); + + // Get git status + let status_output = Command::new("git") + .arg("status") + .arg("--porcelain") + .current_dir(workspace) + .output() + .await; + + let status = match status_output { + Ok(output) if output.status.success() => { + String::from_utf8_lossy(&output.stdout).to_string() + } + _ => return Ok(error_result("Not a git repository or git command failed")), + }; + + // Parse status + let mut result = GitStatus { + staged: Vec::new(), + unstaged: Vec::new(), + untracked: Vec::new(), + diff: None, + }; + + for line in status.lines() { + if line.len() < 3 { + continue; + } + let index_status = line.chars().next().unwrap_or(' '); + let work_status = line.chars().nth(1).unwrap_or(' '); + let file = line[3..].to_string(); + + match (index_status, work_status) { + ('?', '?') => result.untracked.push(file), + (' ', _) => result.unstaged.push(file), + (_, ' ') => result.staged.push(file), + (_, _) => { + result.staged.push(file.clone()); + result.unstaged.push(file); + } + } + } + + // Get diff if requested + if include_diff { + let diff_output = Command::new("git") + .arg("diff") + .current_dir(workspace) + .output() + .await; + + if let Ok(output) = diff_output { + if output.status.success() { + result.diff = Some(String::from_utf8_lossy(&output.stdout).to_string()); + } + } + } + + Ok(success_result( + serde_json::to_string_pretty(&result).unwrap(), + )) + } +} + +#[derive(serde::Serialize)] +struct GitStatus { + staged: Vec, + unstaged: Vec, + untracked: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + diff: Option, +} + +/// Extract symbols (functions, classes, structs) from a file. +pub struct ExtractSymbolsTool { + service: Arc, +} + +impl ExtractSymbolsTool { + pub fn new(service: Arc) -> Self { + Self { service } + } +} + +#[async_trait] +impl ToolHandler for ExtractSymbolsTool { + fn definition(&self) -> Tool { + Tool { + name: "extract_symbols".to_string(), + description: "Extract function, class, struct, and other symbol definitions from a source file. Returns a structured list of symbols with their line numbers.".to_string(), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Path to the source file" + } + }, + "required": ["file_path"] + }), + } + } + + async fn execute(&self, args: HashMap) -> Result { + let file_path = get_string_arg(&args, "file_path")?; + + let full_path = self.service.workspace_path().join(&file_path); + let content = match fs::read_to_string(&full_path).await { + Ok(c) => c, + Err(e) => return Ok(error_result(format!("Failed to read file: {}", e))), + }; + + let ext = full_path.extension().and_then(|e| e.to_str()).unwrap_or(""); + let symbols = extract_symbols_from_content(&content, ext); + + let result = serde_json::json!({ + "file": file_path, + "symbols": symbols + }); + Ok(success_result( + serde_json::to_string_pretty(&result).unwrap(), + )) + } +} + +#[derive(serde::Serialize)] +struct Symbol { + name: String, + kind: String, + line: usize, + #[serde(skip_serializing_if = "Option::is_none")] + signature: Option, +} + +fn extract_symbols_from_content(content: &str, ext: &str) -> Vec { + let mut symbols = Vec::new(); + let lines: Vec<&str> = content.lines().collect(); + + for (i, line) in lines.iter().enumerate() { + let trimmed = line.trim(); + if let Some(sym) = detect_symbol(trimmed, ext, i + 1) { + symbols.push(sym); + } + } + + symbols +} + +fn detect_symbol(line: &str, ext: &str, line_num: usize) -> Option { + match ext { + "rs" => detect_rust_symbol(line, line_num), + "py" => detect_python_symbol(line, line_num), + "ts" | "tsx" | "js" | "jsx" => detect_ts_symbol(line, line_num), + "go" => detect_go_symbol(line, line_num), + _ => None, + } +} + +fn detect_rust_symbol(line: &str, line_num: usize) -> Option { + if line.starts_with("pub fn ") || line.starts_with("fn ") { + let name = extract_name(line, "fn "); + return Some(Symbol { + name, + kind: "function".to_string(), + line: line_num, + signature: Some(line.to_string()), + }); + } + if line.starts_with("pub struct ") || line.starts_with("struct ") { + let name = extract_name(line, "struct "); + return Some(Symbol { + name, + kind: "struct".to_string(), + line: line_num, + signature: None, + }); + } + if line.starts_with("pub enum ") || line.starts_with("enum ") { + let name = extract_name(line, "enum "); + return Some(Symbol { + name, + kind: "enum".to_string(), + line: line_num, + signature: None, + }); + } + if line.starts_with("pub trait ") || line.starts_with("trait ") { + let name = extract_name(line, "trait "); + return Some(Symbol { + name, + kind: "trait".to_string(), + line: line_num, + signature: None, + }); + } + if line.starts_with("impl ") { + let name = line.strip_prefix("impl ").unwrap_or(line); + let name = name.split_whitespace().next().unwrap_or("").to_string(); + return Some(Symbol { + name, + kind: "impl".to_string(), + line: line_num, + signature: Some(line.to_string()), + }); + } + None +} + +fn detect_python_symbol(line: &str, line_num: usize) -> Option { + if line.starts_with("def ") { + let name = extract_name(line, "def "); + return Some(Symbol { + name, + kind: "function".to_string(), + line: line_num, + signature: Some(line.to_string()), + }); + } + if line.starts_with("class ") { + let name = extract_name(line, "class "); + return Some(Symbol { + name, + kind: "class".to_string(), + line: line_num, + signature: None, + }); + } + if line.starts_with("async def ") { + let name = extract_name(line, "async def "); + return Some(Symbol { + name, + kind: "async_function".to_string(), + line: line_num, + signature: Some(line.to_string()), + }); + } + None +} + +fn detect_ts_symbol(line: &str, line_num: usize) -> Option { + // Function declarations + if line.contains("function ") { + let parts: Vec<&str> = line.split("function ").collect(); + if parts.len() > 1 { + let name = parts[1].split('(').next().unwrap_or("").trim().to_string(); + if !name.is_empty() { + return Some(Symbol { + name, + kind: "function".to_string(), + line: line_num, + signature: Some(line.to_string()), + }); + } + } + } + // Class declarations + if line.starts_with("class ") || line.starts_with("export class ") { + let name = if line.contains("export class ") { + extract_name(line, "export class ") + } else { + extract_name(line, "class ") + }; + return Some(Symbol { + name, + kind: "class".to_string(), + line: line_num, + signature: None, + }); + } + // Interface declarations + if line.starts_with("interface ") || line.starts_with("export interface ") { + let name = if line.contains("export interface ") { + extract_name(line, "export interface ") + } else { + extract_name(line, "interface ") + }; + return Some(Symbol { + name, + kind: "interface".to_string(), + line: line_num, + signature: None, + }); + } + // Type declarations + if line.starts_with("type ") || line.starts_with("export type ") { + let name = if line.contains("export type ") { + extract_name(line, "export type ") + } else { + extract_name(line, "type ") + }; + return Some(Symbol { + name, + kind: "type".to_string(), + line: line_num, + signature: None, + }); + } + None +} + +fn detect_go_symbol(line: &str, line_num: usize) -> Option { + if line.starts_with("func ") { + let rest = line.strip_prefix("func ").unwrap_or(line); + let name = if rest.starts_with('(') { + // Method: func (r *Receiver) MethodName(...) + rest.split(')') + .nth(1) + .and_then(|s| s.trim().split('(').next()) + } else { + // Function: func FuncName(...) + rest.split('(').next() + }; + if let Some(name) = name { + return Some(Symbol { + name: name.trim().to_string(), + kind: "function".to_string(), + line: line_num, + signature: Some(line.to_string()), + }); + } + } + if line.starts_with("type ") && line.contains(" struct") { + let name = extract_name(line, "type "); + return Some(Symbol { + name, + kind: "struct".to_string(), + line: line_num, + signature: None, + }); + } + if line.starts_with("type ") && line.contains(" interface") { + let name = extract_name(line, "type "); + return Some(Symbol { + name, + kind: "interface".to_string(), + line: line_num, + signature: None, + }); + } + None +} + +fn extract_name(line: &str, prefix: &str) -> String { + line.split(prefix) + .nth(1) + .unwrap_or("") + .split(|c: char| !c.is_alphanumeric() && c != '_') + .next() + .unwrap_or("") + .to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extension_to_language() { + assert_eq!(extension_to_language("rs"), "rust"); + assert_eq!(extension_to_language("py"), "python"); + assert_eq!(extension_to_language("ts"), "typescript"); + assert_eq!(extension_to_language("go"), "go"); + assert_eq!(extension_to_language("unknown"), "binary"); + } + + #[test] + fn test_detect_rust_symbol_function() { + let sym = detect_rust_symbol("pub fn hello_world() -> Result<()> {", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "hello_world"); + assert_eq!(sym.kind, "function"); + } + + #[test] + fn test_detect_rust_symbol_struct() { + let sym = detect_rust_symbol("pub struct MyStruct {", 10); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "MyStruct"); + assert_eq!(sym.kind, "struct"); + assert_eq!(sym.line, 10); + } + + #[test] + fn test_detect_rust_symbol_enum() { + let sym = detect_rust_symbol("enum Color { Red, Green, Blue }", 5); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "Color"); + assert_eq!(sym.kind, "enum"); + } + + #[test] + fn test_detect_rust_symbol_trait() { + let sym = detect_rust_symbol("pub trait Handler {", 15); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "Handler"); + assert_eq!(sym.kind, "trait"); + } + + #[test] + fn test_detect_python_symbol_function() { + let sym = detect_python_symbol("def process_data(data: dict) -> list:", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "process_data"); + assert_eq!(sym.kind, "function"); + } + + #[test] + fn test_detect_python_symbol_class() { + let sym = detect_python_symbol("class MyClass:", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "MyClass"); + assert_eq!(sym.kind, "class"); + } + + #[test] + fn test_detect_python_symbol_async() { + let sym = detect_python_symbol("async def fetch_data():", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "fetch_data"); + assert_eq!(sym.kind, "async_function"); + } + + #[test] + fn test_detect_ts_symbol_function() { + let sym = detect_ts_symbol("function processData(data: any): void {", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "processData"); + assert_eq!(sym.kind, "function"); + } + + #[test] + fn test_detect_ts_symbol_class() { + let sym = detect_ts_symbol("export class UserService {", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "UserService"); + assert_eq!(sym.kind, "class"); + } + + #[test] + fn test_detect_ts_symbol_interface() { + let sym = detect_ts_symbol("interface UserData {", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "UserData"); + assert_eq!(sym.kind, "interface"); + } + + #[test] + fn test_detect_go_symbol_function() { + let sym = detect_go_symbol( + "func HandleRequest(w http.ResponseWriter, r *http.Request) {", + 1, + ); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "HandleRequest"); + assert_eq!(sym.kind, "function"); + } + + #[test] + fn test_detect_go_symbol_struct() { + let sym = detect_go_symbol("type Config struct {", 1); + assert!(sym.is_some()); + let sym = sym.unwrap(); + assert_eq!(sym.name, "Config"); + assert_eq!(sym.kind, "struct"); + } + + #[test] + fn test_extract_symbols_from_content() { + let rust_code = r#" +pub struct Server { + port: u16, +} + +impl Server { + pub fn new(port: u16) -> Self { + Self { port } + } +} +"#; + let symbols = extract_symbols_from_content(rust_code, "rs"); + assert!(!symbols.is_empty()); + assert!(symbols + .iter() + .any(|s| s.name == "Server" && s.kind == "struct")); + assert!(symbols + .iter() + .any(|s| s.name == "new" && s.kind == "function")); + } + + #[test] + fn test_extract_name() { + assert_eq!(extract_name("fn hello() {", "fn "), "hello"); + assert_eq!(extract_name("struct MyStruct {", "struct "), "MyStruct"); + assert_eq!(extract_name("def process():", "def "), "process"); + } +} From b1518b5aef48c424727ce5d9a619efafe4fd89e0 Mon Sep 17 00:00:00 2001 From: Rishi Tank Date: Fri, 2 Jan 2026 04:00:26 +0000 Subject: [PATCH 07/37] fix: address all PR review comments Security fixes: - Add path traversal protection in extract_symbols tool - Add path traversal protection in resources/read handler - Canonicalize both workspace and target paths before comparison - Fix file URI construction for Windows paths Bug fixes: - Fix generic impl block parsing (impl Foo now extracts 'Foo') - Fix prompt template mid-line conditionals (no longer drops entire lines) - Fix is_cancelled logic with separate cancelled_requests tracking - Use environment variable for SDK sync template injection Workflow improvements: - Update softprops/action-gh-release from v1 to v2 - Add [skip ci] to version bump commit to prevent recursive CI Documentation: - Update MCP_IMPROVEMENTS.md to reflect all implemented features Tests: - Add edge case tests for prompts (missing args, empty values, non-existent) - Total tests: 157 (was 152) --- .github/workflows/release.yml | 4 +- .github/workflows/sdk-sync.yml | 4 +- docs/MCP_IMPROVEMENTS.md | 60 ++++++++------ src/mcp/prompts.rs | 147 +++++++++++++++++++++++++++------ src/mcp/resources.rs | 33 +++++++- src/mcp/server.rs | 30 +++++-- src/tools/workspace.rs | 39 ++++++++- 7 files changed, 247 insertions(+), 70 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8926417..d97440c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -145,7 +145,7 @@ jobs: git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" git add Cargo.toml - git commit -m "chore: bump version to $VERSION" + git commit -m "chore: bump version to $VERSION [skip ci]" git push env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -237,7 +237,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Create Release - uses: softprops/action-gh-release@v1 + uses: softprops/action-gh-release@v2 with: tag_name: v${{ needs.check.outputs.version }} name: v${{ needs.check.outputs.version }} diff --git a/.github/workflows/sdk-sync.yml b/.github/workflows/sdk-sync.yml index 1de2c64..0c0f3e1 100644 --- a/.github/workflows/sdk-sync.yml +++ b/.github/workflows/sdk-sync.yml @@ -66,9 +66,11 @@ jobs: - name: Create tracking issue if: steps.todos.outputs.has_todos == 'true' || github.event.inputs.create_issue == 'true' uses: actions/github-script@v7 + env: + TODOS_OUTPUT: ${{ steps.todos.outputs.todos }} with: script: | - const todos = `${{ steps.todos.outputs.todos }}`; + const todos = process.env.TODOS_OUTPUT || ''; const title = `[SDK Sync] Weekly Augment SDK Review - ${new Date().toISOString().split('T')[0]}`; // Check if issue already exists this week diff --git a/docs/MCP_IMPROVEMENTS.md b/docs/MCP_IMPROVEMENTS.md index 2b8ecfb..218e091 100644 --- a/docs/MCP_IMPROVEMENTS.md +++ b/docs/MCP_IMPROVEMENTS.md @@ -5,23 +5,27 @@ This document outlines potential improvements to make the Context Engine MCP Ser ## Current Implementation Status ### ✅ Fully Implemented -- **Tools** - All 49 tools for retrieval, indexing, memory, planning, and review +- **Tools** - All 55 tools for retrieval, indexing, memory, planning, review, navigation, and workspace analysis - **JSON-RPC 2.0** - Full request/response/notification handling - **Stdio Transport** - Standard input/output for MCP clients - **HTTP Transport** - Axum-based HTTP server with SSE -- **Logging Capability** - Structured logging support +- **Logging Capability** - Structured logging support with `logging/setLevel` handler - **Tools List Changed** - Dynamic tool list notifications +- **Resources** - Full `resources/list` and `resources/read` with file:// URI scheme +- **Resource Subscriptions** - Subscribe/unsubscribe to file changes +- **Prompts** - 5 pre-defined prompt templates with argument substitution +- **Completions API** - Autocomplete suggestions for tool/prompt arguments +- **Progress Notifications** - Long-running operation progress with ProgressReporter +- **Cancellation** - Cancel in-progress operations via `notifications/cancelled` +- **Roots Support** - Client-provided workspace roots via `roots/list` +- **Navigation Tools** - `find_references`, `go_to_definition`, `diff_files` +- **Workspace Tools** - `workspace_stats`, `git_status`, `extract_symbols` ### 🔶 Partially Implemented -- **Resources** - Capability declared but not actively used -- **Prompts** - Capability declared but no prompts defined +- **Resource Templates** - URI templates for dynamic resources (planned) ### ❌ Not Yet Implemented -- **Resource Subscriptions** - Subscribe to file/resource changes -- **Prompt Templates** - Pre-defined prompt templates with arguments -- **Completions API** - Autocomplete suggestions for prompts/resources -- **Progress Notifications** - Long-running operation progress -- **Cancellation** - Cancel in-progress operations +- **Sampling** - Server-initiated LLM requests (requires client support) --- @@ -199,24 +203,30 @@ Bridge with LSP servers for richer code intelligence. ## Implementation Priority -### Phase 1 (Next Release) +### Phase 1 (v2.0.0 - Complete ✅) 1. ✅ Workflow improvements (PR-based releases) 2. ✅ Dependabot configuration -3. 🔲 Prompt templates (basic set) -4. 🔲 find_references tool -5. 🔲 go_to_definition tool - -### Phase 2 -1. 🔲 Resource subscriptions -2. 🔲 Progress notifications -3. 🔲 diff_files tool -4. 🔲 Caching layer - -### Phase 3 -1. 🔲 Completions API -2. 🔲 Plugin system -3. 🔲 AST query tool -4. 🔲 Request cancellation +3. ✅ Prompt templates (5 templates with conditionals) +4. ✅ find_references tool +5. ✅ go_to_definition tool +6. ✅ Resource subscriptions +7. ✅ Progress notifications +8. ✅ diff_files tool +9. ✅ Completions API +10. ✅ Request cancellation +11. ✅ Workspace analysis tools (workspace_stats, git_status, extract_symbols) +12. ✅ logging/setLevel handler + +### Phase 2 (Next) +1. 🔲 Caching layer for expensive operations +2. 🔲 Plugin system for extensibility +3. 🔲 AST query tool (tree-sitter integration) +4. 🔲 Dependency graph analysis + +### Phase 3 (Future) +1. 🔲 LSP integration for richer code intelligence +2. 🔲 Sampling support (server-initiated LLM requests) +3. 🔲 Resource templates for dynamic URIs --- diff --git a/src/mcp/prompts.rs b/src/mcp/prompts.rs index ef84738..f8e326c 100644 --- a/src/mcp/prompts.rs +++ b/src/mcp/prompts.rs @@ -206,37 +206,58 @@ Include: self.prompts.get(name).map(|(prompt, template)| { let mut text = template.template.clone(); - // Simple template substitution - for (key, value) in arguments { - text = text.replace(&format!("{{{{{}}}}}", key), value); + // First, handle all conditionals (before simple substitution) + // Find all {{#if var}}...{{/if}} blocks and process them + loop { + let if_start = text.find("{{#if "); + if if_start.is_none() { + break; + } + let start = if_start.unwrap(); + + // Find the variable name + let var_start = start + 6; // "{{#if " is 6 chars + let var_end = match text[var_start..].find("}}") { + Some(pos) => var_start + pos, + None => break, + }; + let var_name = text[var_start..var_end].trim(); + + // Find the matching {{/if}} + let block_start = var_end + 2; // skip "}}" + let endif_pos = match text[block_start..].find("{{/if}}") { + Some(pos) => block_start + pos, + None => break, + }; + let content = &text[block_start..endif_pos]; + let block_end = endif_pos + 7; // "{{/if}}" is 7 chars + + // Check if the variable is provided and non-empty + let should_include = arguments + .get(var_name) + .map(|v| !v.is_empty()) + .unwrap_or(false); + + if should_include { + // Keep the content, remove the markers + text = format!("{}{}{}", &text[..start], content, &text[block_end..]); + } else { + // Remove the entire block including markers + text = format!("{}{}", &text[..start], &text[block_end..]); + } } - // Handle conditionals (very simple implementation) - // {{#if var}}content{{/if}} + // Simple template substitution for {{variable}} for (key, value) in arguments { - let if_pattern = format!("{{{{#if {}}}}}", key); - let endif_pattern = "{{/if}}"; - - if let Some(start) = text.find(&if_pattern) { - if let Some(end) = text[start..].find(endif_pattern) { - let content = &text[start + if_pattern.len()..start + end]; - if !value.is_empty() { - text = text - .replace(&text[start..start + end + endif_pattern.len()], content); - } else { - text = - text.replace(&text[start..start + end + endif_pattern.len()], ""); - } - } - } + text = text.replace(&format!("{{{{{}}}}}", key), value); } - // Clean up remaining template markers - text = text - .lines() - .filter(|line| !line.contains("{{#if") && !line.contains("{{/if}}")) - .collect::>() - .join("\n"); + // Replace any remaining unsubstituted placeholders with empty string + // This handles optional arguments that weren't provided + let placeholder_re = regex::Regex::new(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}").ok(); + if let Some(re) = placeholder_re { + text = re.replace_all(&text, "").to_string(); + } GetPromptResult { description: Some(prompt.description.clone()), @@ -274,4 +295,78 @@ mod tests { let result = result.unwrap(); assert_eq!(result.messages.len(), 1); } + + #[test] + fn test_get_nonexistent_prompt() { + let registry = PromptRegistry::new(); + let args = HashMap::new(); + let result = registry.get("nonexistent_prompt", &args); + assert!(result.is_none()); + } + + #[test] + fn test_get_prompt_missing_required_args() { + let registry = PromptRegistry::new(); + // code_review requires 'code' and 'language' but we provide neither + let args = HashMap::new(); + let result = registry.get("code_review", &args); + assert!(result.is_some()); + // The template placeholders should be removed (empty string replacement) + let text = match &result.unwrap().messages[0].content { + PromptContent::Text { text } => text.clone(), + _ => panic!("Expected text content"), + }; + // Should not contain unsubstituted placeholders + assert!(!text.contains("{{code}}")); + assert!(!text.contains("{{language}}")); + } + + #[test] + fn test_get_prompt_empty_arg_values() { + let registry = PromptRegistry::new(); + let mut args = HashMap::new(); + args.insert("code".to_string(), "".to_string()); + args.insert("language".to_string(), "".to_string()); + + let result = registry.get("code_review", &args); + assert!(result.is_some()); + // Empty values should still work (conditionals should hide content) + } + + #[test] + fn test_conditional_with_value() { + let registry = PromptRegistry::new(); + let mut args = HashMap::new(); + args.insert("code".to_string(), "fn test() {}".to_string()); + args.insert("language".to_string(), "rust".to_string()); + args.insert("focus".to_string(), "security".to_string()); + + let result = registry.get("code_review", &args); + assert!(result.is_some()); + let text = match &result.unwrap().messages[0].content { + PromptContent::Text { text } => text.clone(), + _ => panic!("Expected text content"), + }; + // With focus provided, the conditional content should be included + assert!(text.contains("security")); + } + + #[test] + fn test_conditional_without_value() { + let registry = PromptRegistry::new(); + let mut args = HashMap::new(); + args.insert("code".to_string(), "fn test() {}".to_string()); + args.insert("language".to_string(), "rust".to_string()); + // Don't provide 'focus' - conditional should be removed + + let result = registry.get("code_review", &args); + assert!(result.is_some()); + let text = match &result.unwrap().messages[0].content { + PromptContent::Text { text } => text.clone(), + _ => panic!("Expected text content"), + }; + // Without focus, the conditional content should be removed + assert!(!text.contains("{{#if")); + assert!(!text.contains("{{/if}}")); + } } diff --git a/src/mcp/resources.rs b/src/mcp/resources.rs index 5583f37..b7363fd 100644 --- a/src/mcp/resources.rs +++ b/src/mcp/resources.rs @@ -81,7 +81,8 @@ impl ResourceRegistry { .to_string_lossy() .to_string(); - let uri = format!("file://{}", path.display()); + // Construct proper file:// URI (handle Windows paths) + let uri = Self::path_to_file_uri(path); let mime_type = Self::guess_mime_type(path); Resource { @@ -118,13 +119,16 @@ impl ResourceRegistry { ))); }; - // Security: ensure path is within workspace + // Security: canonicalize both workspace and path, then verify path is within workspace let workspace = self.context_service.workspace(); + let workspace_canonical = workspace + .canonicalize() + .map_err(|e| Error::InvalidToolArguments(format!("Cannot resolve workspace: {}", e)))?; let canonical = path .canonicalize() .map_err(|e| Error::InvalidToolArguments(format!("Cannot resolve path: {}", e)))?; - if !canonical.starts_with(workspace) { + if !canonical.starts_with(&workspace_canonical) { return Err(Error::InvalidToolArguments( "Access denied: path outside workspace".to_string(), )); @@ -235,6 +239,29 @@ impl ResourceRegistry { || name.ends_with(".pyc") } + /// Convert a path to a proper file:// URI. + /// Handles Windows paths by converting backslashes and adding leading slash. + fn path_to_file_uri(path: &std::path::Path) -> String { + let path_str = path.to_string_lossy(); + + // On Windows, paths like C:\foo\bar need to become file:///C:/foo/bar + #[cfg(windows)] + { + let normalized = path_str.replace('\\', "/"); + if normalized.chars().nth(1) == Some(':') { + // Absolute Windows path like C:/foo + format!("file:///{}", normalized) + } else { + format!("file://{}", normalized) + } + } + + #[cfg(not(windows))] + { + format!("file://{}", path_str) + } + } + /// Guess MIME type from file extension. fn guess_mime_type(path: &std::path::Path) -> Option { let ext = path.extension()?.to_str()?; diff --git a/src/mcp/server.rs b/src/mcp/server.rs index 4bdd32b..21cb44d 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -70,6 +70,8 @@ pub struct McpServer { roots: Arc>>, /// Active request IDs for cancellation support. active_requests: Arc>>, + /// Explicitly cancelled request IDs. + cancelled_requests: Arc>>, /// Current log level. log_level: Arc>, } @@ -85,6 +87,7 @@ impl McpServer { version: VERSION.to_string(), roots: Arc::new(RwLock::new(Vec::new())), active_requests: Arc::new(RwLock::new(HashSet::new())), + cancelled_requests: Arc::new(RwLock::new(HashSet::new())), log_level: Arc::new(RwLock::new(LogLevel::default())), } } @@ -104,6 +107,7 @@ impl McpServer { version: VERSION.to_string(), roots: Arc::new(RwLock::new(Vec::new())), active_requests: Arc::new(RwLock::new(HashSet::new())), + cancelled_requests: Arc::new(RwLock::new(HashSet::new())), log_level: Arc::new(RwLock::new(LogLevel::default())), } } @@ -123,9 +127,20 @@ impl McpServer { self.roots.read().await.clone() } - /// Check if a request has been cancelled. + /// Check if a request has been explicitly cancelled. pub async fn is_cancelled(&self, id: &RequestId) -> bool { - !self.active_requests.read().await.contains(id) + self.cancelled_requests.read().await.contains(id) + } + + /// Mark a request as cancelled. + pub async fn cancel_request(&self, id: &RequestId) { + self.cancelled_requests.write().await.insert(id.clone()); + } + + /// Clean up a completed request from tracking sets. + pub async fn complete_request(&self, id: &RequestId) { + self.active_requests.write().await.remove(id); + self.cancelled_requests.write().await.remove(id); } /// Run the server with the given transport. @@ -190,8 +205,8 @@ impl McpServer { ))), }; - // Remove from active requests - self.active_requests.write().await.remove(&req.id); + // Clean up request tracking + self.complete_request(&req.id).await; match result { Ok(value) => JsonRpcResponse { @@ -222,7 +237,7 @@ impl McpServer { info!("Client initialized"); } "notifications/cancelled" => { - // Extract the request ID from params and cancel it + // Extract the request ID from params and mark it as cancelled if let Some(params) = notif.params { #[derive(serde::Deserialize)] struct CancelledParams { @@ -231,10 +246,7 @@ impl McpServer { } if let Ok(cancel) = serde_json::from_value::(params) { info!("Cancelling request: {:?}", cancel.request_id); - self.active_requests - .write() - .await - .remove(&cancel.request_id); + self.cancel_request(&cancel.request_id).await; } } } diff --git a/src/tools/workspace.rs b/src/tools/workspace.rs index 2b84fa5..5aa1220 100644 --- a/src/tools/workspace.rs +++ b/src/tools/workspace.rs @@ -320,8 +320,26 @@ impl ToolHandler for ExtractSymbolsTool { async fn execute(&self, args: HashMap) -> Result { let file_path = get_string_arg(&args, "file_path")?; - let full_path = self.service.workspace_path().join(&file_path); - let content = match fs::read_to_string(&full_path).await { + let workspace = self.service.workspace_path(); + let full_path = workspace.join(&file_path); + + // Security: canonicalize and verify path stays within workspace + let workspace_canonical = match workspace.canonicalize() { + Ok(p) => p, + Err(e) => return Ok(error_result(format!("Cannot resolve workspace: {}", e))), + }; + let path_canonical = match full_path.canonicalize() { + Ok(p) => p, + Err(e) => return Ok(error_result(format!("Cannot resolve {}: {}", file_path, e))), + }; + if !path_canonical.starts_with(&workspace_canonical) { + return Ok(error_result(format!( + "Path escapes workspace: {}", + file_path + ))); + } + + let content = match fs::read_to_string(&path_canonical).await { Ok(c) => c, Err(e) => return Ok(error_result(format!("Failed to read file: {}", e))), }; @@ -410,8 +428,21 @@ fn detect_rust_symbol(line: &str, line_num: usize) -> Option { }); } if line.starts_with("impl ") { - let name = line.strip_prefix("impl ").unwrap_or(line); - let name = name.split_whitespace().next().unwrap_or("").to_string(); + let rest = line.strip_prefix("impl ").unwrap_or(line); + // Skip generic parameters if present (e.g., impl Foo) + let rest = if rest.starts_with('<') { + rest.split_once('>') + .map(|(_, after)| after.trim_start()) + .unwrap_or(rest) + } else { + rest + }; + // Extract the type/trait name (first identifier before '<', ' ', '{', or 'for') + let name = rest + .split(|c: char| !c.is_alphanumeric() && c != '_') + .next() + .unwrap_or("") + .to_string(); return Some(Symbol { name, kind: "impl".to_string(), From 662828fc5753beff49c3a337b32dc1b0931cf8ba Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Fri, 2 Jan 2026 04:03:08 +0000 Subject: [PATCH 08/37] =?UTF-8?q?=F0=9F=93=9D=20Add=20docstrings=20to=20`f?= =?UTF-8?q?eature/workflow-improvements`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Docstrings generation was requested by @rishitank. * https://github.com/rishitank/context-engine/pull/1#issuecomment-3704422544 The following files were modified: * `src/mcp/progress.rs` * `src/mcp/prompts.rs` * `src/mcp/resources.rs` * `src/mcp/server.rs` * `src/tools/mod.rs` * `src/tools/navigation.rs` * `src/tools/workspace.rs` --- src/mcp/progress.rs | 157 ++++++++++++++- src/mcp/prompts.rs | 71 ++++++- src/mcp/resources.rs | 148 +++++++++++++- src/mcp/server.rs | 414 +++++++++++++++++++++++++++++++++++++--- src/tools/mod.rs | 21 +- src/tools/navigation.rs | 258 ++++++++++++++++++++++++- src/tools/workspace.rs | 400 +++++++++++++++++++++++++++++++++++++- 7 files changed, 1411 insertions(+), 58 deletions(-) diff --git a/src/mcp/progress.rs b/src/mcp/progress.rs index 7d5b364..ebbbfb0 100644 --- a/src/mcp/progress.rs +++ b/src/mcp/progress.rs @@ -34,7 +34,23 @@ pub struct ProgressNotification { } impl ProgressNotification { - /// Create a new progress notification. + /// Constructs a JSON-RPC progress notification containing the provided token, progress value, optional total, and optional message. + /// + /// # Examples + /// + /// ``` + /// let note = ProgressNotification::new( + /// ProgressToken::String("op-1".into()), + /// 50, + /// Some(100), + /// Some("in progress".into()), + /// ); + /// assert_eq!(note.jsonrpc, "2.0"); + /// assert_eq!(note.method, "notifications/progress"); + /// assert_eq!(note.params.progress, 50); + /// assert_eq!(note.params.total, Some(100)); + /// assert_eq!(note.params.message.as_deref(), Some("in progress")); + /// ``` pub fn new( token: ProgressToken, progress: u64, @@ -63,7 +79,17 @@ pub struct ProgressReporter { } impl ProgressReporter { - /// Create a new progress reporter. + /// Constructs a ProgressReporter bound to a progress token, a sender channel, and an optional total. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use crate::mcp::progress::{ProgressReporter, ProgressToken}; + /// + /// let (tx, _rx) = mpsc::channel(1); + /// let reporter = ProgressReporter::new(ProgressToken::Number(1), tx, Some(100)); + /// ``` pub fn new( token: ProgressToken, sender: mpsc::Sender, @@ -76,7 +102,21 @@ impl ProgressReporter { } } - /// Report progress. + /// Send a progress notification for this reporter. + /// + /// The optional `message`, if provided, is included in the notification. Send failures are ignored. + /// + /// # Examples + /// + /// ``` + /// # use futures::executor::block_on; + /// # use crate::mcp::progress::{ProgressManager, ProgressToken}; + /// let manager = ProgressManager::new(); + /// let reporter = manager.create_reporter(Some(100)); + /// block_on(async { + /// reporter.report(42, Some("halfway")).await; + /// }); + /// ``` pub async fn report(&self, progress: u64, message: Option<&str>) { let notification = ProgressNotification::new( self.token.clone(), @@ -87,7 +127,20 @@ impl ProgressReporter { let _ = self.sender.send(notification).await; } - /// Report progress with percentage. + /// Converts a percentage into an absolute progress value (using the reporter's `total` when present) and emits that progress notification. + /// + /// # Examples + /// + /// ``` + /// # use tokio::sync::mpsc; + /// # use crate::mcp::progress::{ProgressManager}; + /// # #[tokio::test] + /// # async fn example_report_percent() { + /// let manager = ProgressManager::new(); + /// let reporter = manager.create_reporter(Some(200)); + /// reporter.report_percent(50, Some("Halfway")).await; + /// # } + /// ``` pub async fn report_percent(&self, percent: u64, message: Option<&str>) { let progress = if let Some(total) = self.total { (percent * total) / 100 @@ -97,7 +150,26 @@ impl ProgressReporter { self.report(progress, message).await; } - /// Complete the progress. + /// Report completion for this reporter by sending a notification with progress set to the reporter's total, if one is configured. + /// + /// If the reporter has no configured total, no notification is sent. + /// + /// # Parameters + /// + /// - `message`: Optional message to include with the completion notification. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::mpsc; + /// use crate::mcp::progress::{ProgressReporter, ProgressToken}; + /// + /// // Create a reporter with a total of 100 and send completion. + /// let rt = tokio::runtime::Runtime::new().unwrap(); + /// let (tx, _rx) = mpsc::channel(10); + /// let reporter = ProgressReporter::new(ProgressToken::Number(1), tx, Some(100)); + /// rt.block_on(reporter.complete(Some("finished"))); + /// ``` pub async fn complete(&self, message: Option<&str>) { if let Some(total) = self.total { self.report(total, message).await; @@ -113,7 +185,15 @@ pub struct ProgressManager { } impl ProgressManager { - /// Create a new progress manager. + /// Creates a new ProgressManager configured to emit progress notifications. + /// + /// # Examples + /// + /// ``` + /// let mgr = ProgressManager::new(); + /// // obtain a receiver to consume notifications + /// let _recv = mgr.receiver(); + /// ``` pub fn new() -> Self { let (sender, receiver) = mpsc::channel(100); Self { @@ -123,7 +203,26 @@ impl ProgressManager { } } - /// Create a new progress reporter with a generated token. + /// Creates a new ProgressReporter that uses a generated numeric token. + /// + /// The generated token is a sequential numeric identifier unique to this ProgressManager instance. + /// + /// # Parameters + /// + /// - `total`: Optional total number of work units for the operation; if provided, percentage-based reporting + /// will be computed against this value. + /// + /// # Returns + /// + /// A `ProgressReporter` bound to this manager's sender, using a newly generated numeric `ProgressToken`. + /// + /// # Examples + /// + /// ``` + /// let manager = ProgressManager::new(); + /// let reporter = manager.create_reporter(Some(100)); + /// // `reporter` can now be used to emit progress updates. + /// ``` pub fn create_reporter(&self, total: Option) -> ProgressReporter { let id = self .next_id @@ -132,7 +231,19 @@ impl ProgressManager { ProgressReporter::new(token, self.sender.clone(), total) } - /// Create a progress reporter with a specific token. + /// Creates a ProgressReporter bound to the given token and optional total. + /// + /// The returned reporter will send progress notifications tagged with `token` + /// using the manager's internal channel. + /// + /// # Examples + /// + /// ``` + /// use crate::mcp::progress::{ProgressManager, ProgressToken}; + /// + /// let manager = ProgressManager::new(); + /// let reporter = manager.create_reporter_with_token(ProgressToken::String("op".into()), Some(100)); + /// ``` pub fn create_reporter_with_token( &self, token: ProgressToken, @@ -141,13 +252,39 @@ impl ProgressManager { ProgressReporter::new(token, self.sender.clone(), total) } - /// Get the receiver for progress notifications. + /// Returns a clone of the shared receiver handle for progress notifications. + /// + /// The returned `Arc>>` can be cloned and used by consumers to lock and receive progress notifications. + /// + /// # Examples + /// + /// ``` + /// let manager = ProgressManager::new(); + /// let rx = manager.receiver(); + /// // `rx` is a clone of the manager's shared receiver handle + /// assert!(Arc::strong_count(&rx) >= 1); + /// ``` pub fn receiver(&self) -> Arc>> { self.receiver.clone() } } impl Default for ProgressManager { + /// Creates a ProgressManager initialized with its standard channel and token counter. + + /// + + /// # Examples + + /// + + /// ``` + + /// let mgr = crate::mcp::progress::ProgressManager::default(); + + /// let _recv = mgr.receiver(); + + /// ``` fn default() -> Self { Self::new() } @@ -263,4 +400,4 @@ mod tests { assert!(json.contains("\"total\":100")); assert!(json.contains("\"message\":\"Working...\"")); } -} +} \ No newline at end of file diff --git a/src/mcp/prompts.rs b/src/mcp/prompts.rs index f8e326c..b7975a8 100644 --- a/src/mcp/prompts.rs +++ b/src/mcp/prompts.rs @@ -71,14 +71,32 @@ pub struct PromptTemplate { } impl PromptRegistry { - /// Create a new registry with built-in prompts. + /// Creates a new registry populated with the built-in prompts. + /// + /// # Examples + /// + /// ``` + /// let registry = crate::mcp::prompts::PromptRegistry::new(); + /// assert!(!registry.list().is_empty()); + /// ``` pub fn new() -> Self { let mut registry = Self::default(); registry.register_builtin_prompts(); registry } - /// Register built-in prompts. + /// Populates the registry with the built-in prompt definitions used by the application. + /// + /// Registers three prompts — "code_review", "explain_code", and "write_tests" — each with their + /// argument metadata and template text (including conditional sections and variable placeholders). + /// + /// # Examples + /// + /// ``` + /// let registry = crate::mcp::prompts::PromptRegistry::new(); + /// let names: Vec<_> = registry.list().into_iter().map(|p| p.name).collect(); + /// assert!(names.contains(&"code_review".to_string())); + /// ``` fn register_builtin_prompts(&mut self) { // Code Review Prompt self.register( @@ -191,17 +209,58 @@ Include: ); } - /// Register a prompt. + /// Adds or updates a prompt and its template in the registry. + /// + /// The provided `prompt` is stored under its `name`; if a prompt with the same name + /// already exists it will be replaced along with its template. + /// + /// # Examples + /// + /// ``` + /// let mut registry = PromptRegistry::new(); + /// let prompt = Prompt { + /// name: "example".to_string(), + /// description: "An example prompt".to_string(), + /// arguments: vec![], + /// }; + /// let template = PromptTemplate { template: "Hello {{name}}".to_string() }; + /// registry.register(prompt, template); + /// assert!(registry.list().iter().any(|p| p.name == "example")); + /// ``` pub fn register(&mut self, prompt: Prompt, template: PromptTemplate) { self.prompts.insert(prompt.name.clone(), (prompt, template)); } - /// List all prompts. + /// Retrieve all registered prompts. + /// + /// Returns a vector containing a clone of each registered `Prompt`. The order of prompts is not guaranteed. + /// + /// # Examples + /// + /// ``` + /// let registry = PromptRegistry::new(); + /// let prompts = registry.list(); + /// assert!(prompts.iter().any(|p| p.name == "code_review")); + /// ``` pub fn list(&self) -> Vec { self.prompts.values().map(|(p, _)| p.clone()).collect() } - /// Get a prompt by name with arguments substituted. + /// Retrieve a registered prompt by name and render its template using the provided arguments. + /// + /// The template supports conditional blocks of the form `{{#if var}}...{{/if}}` (the block is included only when `var` is present and not empty) and simple `{{variable}}` substitutions. Any remaining unsubstituted placeholders are removed from the output. Returns `None` if no prompt with the given name exists. On success the result contains the prompt description and a single user-role message with the rendered text. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashMap; + /// + /// let registry = PromptRegistry::new(); + /// let mut args = HashMap::new(); + /// args.insert("code".to_string(), "fn main() {}".to_string()); + /// let res = registry.get("code_review", &args); + /// assert!(res.is_some()); + /// ``` pub fn get(&self, name: &str, arguments: &HashMap) -> Option { self.prompts.get(name).map(|(prompt, template)| { let mut text = template.template.clone(); @@ -369,4 +428,4 @@ mod tests { assert!(!text.contains("{{#if")); assert!(!text.contains("{{/if}}")); } -} +} \ No newline at end of file diff --git a/src/mcp/resources.rs b/src/mcp/resources.rs index b7363fd..692d5f6 100644 --- a/src/mcp/resources.rs +++ b/src/mcp/resources.rs @@ -59,7 +59,22 @@ pub struct ResourceRegistry { } impl ResourceRegistry { - /// Create a new resource registry. + /// Creates a new ResourceRegistry backed by the given workspace context. + /// + /// # Parameters + /// + /// - `context_service`: shared workspace context used to resolve the workspace root and related operations. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// # use crate::mcp::resources::ResourceRegistry; + /// # use crate::context::ContextService; + /// // Construct or obtain an Arc from your application. + /// let ctx: Arc = Arc::new(ContextService::default()); + /// let registry = ResourceRegistry::new(ctx); + /// ``` pub fn new(context_service: Arc) -> Self { Self { context_service, @@ -67,7 +82,24 @@ impl ResourceRegistry { } } - /// List available resources (files in workspace). + /// Lists workspace files as `Resource` entries with optional cursor-based pagination. + /// + /// The `cursor` parameter, if provided, is a resource name to start listing after; results include up to 100 entries. + /// + /// # Returns + /// + /// `ListResourcesResult` containing the discovered resources and an optional `next_cursor` string to continue pagination. + /// + /// # Examples + /// + /// ``` + /// # tokio_test::block_on(async { + /// // `registry` would be constructed with a real `ContextService` in production. + /// // let registry = ResourceRegistry::new(context_service); + /// // let result = registry.list(None).await.unwrap(); + /// // assert!(result.resources.len() <= 100); + /// # }); + /// ``` pub async fn list(&self, cursor: Option<&str>) -> Result { let workspace = self.context_service.workspace(); let files = self.discover_files(workspace, 100, cursor).await?; @@ -107,7 +139,34 @@ impl ResourceRegistry { }) } - /// Read a resource by URI. + /// Reads a resource identified by a `file://` URI from the workspace and returns its contents. + /// + /// # Arguments + /// + /// * `uri` - A `file://` URI pointing to a file located inside the workspace. + /// + /// # Returns + /// + /// A `ReadResourceResult` containing a single `ResourceContents` entry with the provided `uri`, the inferred `mime_type` (if any), and `text` set to the file's UTF-8 contents. + /// + /// # Errors + /// + /// Returns `Error::InvalidToolArguments` when: + /// - the URI does not start with `file://`, + /// - the workspace or target path cannot be canonicalized, + /// - the resolved path is outside the workspace, or + /// - the file cannot be read. + /// + /// # Examples + /// + /// ``` + /// # async fn example_usage(registry: &crate::mcp::resources::ResourceRegistry) -> anyhow::Result<()> { + /// let result = registry.read("file:///path/to/workspace/file.txt").await?; + /// assert_eq!(result.contents.len(), 1); + /// let content = &result.contents[0]; + /// assert_eq!(content.uri, "file:///path/to/workspace/file.txt"); + /// # Ok(()) } + /// ``` pub async fn read(&self, uri: &str) -> Result { // Parse file:// URI let path = if let Some(path) = uri.strip_prefix("file://") { @@ -151,7 +210,28 @@ impl ResourceRegistry { }) } - /// Subscribe to resource changes. + /// Registers a session to receive change notifications for the given resource URI. + /// + /// The session ID will be recorded in the registry's in-memory subscription map for the specified URI. + /// + /// # Parameters + /// + /// - `uri`: The resource URI to subscribe to (e.g., a `file://` URI). + /// - `session_id`: The identifier of the session to register for notifications. + /// + /// # Returns + /// + /// `Ok(())` on success. + /// + /// # Examples + /// + /// ```no_run + /// # use std::sync::Arc; + /// # use tokio::runtime::Runtime; + /// # async fn _example(registry: &crate::mcp::resources::ResourceRegistry) { + /// registry.subscribe("file:///path/to/file", "session-123").await.unwrap(); + /// # } + /// ``` pub async fn subscribe(&self, uri: &str, session_id: &str) -> Result<()> { let mut subs = self.subscriptions.write().await; subs.entry(uri.to_string()) @@ -160,7 +240,20 @@ impl ResourceRegistry { Ok(()) } - /// Unsubscribe from resource changes. + /// Remove a session's subscription for the given resource URI. + /// + /// # Examples + /// + /// ``` + /// # use std::sync::Arc; + /// # use tokio::runtime::Runtime; + /// # // Assume `registry` is an initialized `ResourceRegistry`. + /// # let rt = Runtime::new().unwrap(); + /// # rt.block_on(async { + /// let registry = /* ResourceRegistry instance */ unimplemented!(); + /// registry.unsubscribe("file:///path/to/file", "session-123").await.unwrap(); + /// # }); + /// ``` pub async fn unsubscribe(&self, uri: &str, session_id: &str) -> Result<()> { let mut subs = self.subscriptions.write().await; if let Some(sessions) = subs.get_mut(uri) { @@ -230,7 +323,19 @@ impl ResourceRegistry { Ok(files) } - /// Check if a file should be ignored. + /// Returns whether a file or directory name matches common ignore patterns used when discovering files. + /// + /// Matches directory names: "node_modules", "target", "dist", "build", "__pycache__", ".git", + /// and files whose names end with `.lock` or `.pyc`. + /// + /// # Examples + /// + /// ``` + /// assert!(should_ignore("node_modules")); + /// assert!(should_ignore("Cargo.lock")); + /// assert!(should_ignore("__pycache__")); + /// assert!(!should_ignore("src")); + /// ``` fn should_ignore(name: &str) -> bool { matches!( name, @@ -239,8 +344,19 @@ impl ResourceRegistry { || name.ends_with(".pyc") } - /// Convert a path to a proper file:// URI. - /// Handles Windows paths by converting backslashes and adding leading slash. + /// Convert a filesystem path to a file:// URI. + /// + /// On Windows this replaces backslashes with forward slashes and prefixes + /// absolute drive paths (e.g., `C:/path`) with `file:///`. On other platforms + /// the path is prefixed with `file://`. + /// + /// # Examples + /// + /// ``` + /// use std::path::Path; + /// let uri = path_to_file_uri(Path::new("/some/path")); + /// assert!(uri.starts_with("file://")); + /// ``` fn path_to_file_uri(path: &std::path::Path) -> String { let path_str = path.to_string_lossy(); @@ -262,7 +378,19 @@ impl ResourceRegistry { } } - /// Guess MIME type from file extension. + /// Infer a MIME type string for a file path based on its extension. + /// + /// Returns `Some` with a guessed MIME type for known extensions, `Some("text/plain")` for unknown extensions, + /// and `None` if the path has no extension or the extension is not valid UTF-8. + /// + /// # Examples + /// + /// ``` + /// use std::path::Path; + /// assert_eq!(guess_mime_type(Path::new("main.rs")), Some("text/x-rust".to_string())); + /// assert_eq!(guess_mime_type(Path::new("data.unknown")), Some("text/plain".to_string())); + /// assert_eq!(guess_mime_type(Path::new("no_extension")), None); + /// ``` fn guess_mime_type(path: &std::path::Path) -> Option { let ext = path.extension()?.to_str()?; let mime = match ext { @@ -396,4 +524,4 @@ mod tests { assert_eq!(parsed.contents.len(), 1); assert_eq!(parsed.contents[0].text, Some("code".to_string())); } -} +} \ No newline at end of file diff --git a/src/mcp/server.rs b/src/mcp/server.rs index 21cb44d..4cde8d5 100644 --- a/src/mcp/server.rs +++ b/src/mcp/server.rs @@ -31,6 +31,20 @@ pub enum LogLevel { } impl LogLevel { + /// Converts a case-insensitive string into the corresponding `LogLevel`, defaulting to `Info` for unknown values. + /// + /// # Returns + /// The matching `LogLevel` variant; `Info` if the input is not recognized. + /// + /// # Examples + /// + /// ``` + /// use crate::mcp::server::LogLevel; + /// + /// assert_eq!(LogLevel::from_str("debug"), LogLevel::Debug); + /// assert_eq!(LogLevel::from_str("Warn"), LogLevel::Warning); + /// assert_eq!(LogLevel::from_str("unknown-level"), LogLevel::Info); + /// ``` fn from_str(s: &str) -> Self { match s.to_lowercase().as_str() { "debug" => Self::Debug, @@ -45,6 +59,17 @@ impl LogLevel { } } + /// Get the lowercase string name for the log level. + /// + /// The returned string is a static, lowercase identifier corresponding to the variant + /// (for example, `"info"`, `"warning"`, or `"error"`). + /// + /// # Examples + /// + /// ``` + /// let lvl = LogLevel::Info; + /// assert_eq!(lvl.as_str(), "info"); + /// ``` fn as_str(&self) -> &'static str { match self { Self::Debug => "debug", @@ -77,7 +102,18 @@ pub struct McpServer { } impl McpServer { - /// Create a new MCP server. + /// Creates a new MCP server with default features. + /// + /// The returned server uses an empty prompt registry, no resource registry (resources disabled), + /// empty workspace roots, no active or cancelled requests, and the default log level and version. + /// + /// # Examples + /// + /// ``` + /// // create a handler appropriate for your setup + /// let handler = /* create or obtain an McpHandler instance */ ; + /// let _server = McpServer::new(handler, "my-server"); + /// ``` pub fn new(handler: McpHandler, name: impl Into) -> Self { Self { handler: Arc::new(handler), @@ -92,7 +128,20 @@ impl McpServer { } } - /// Create a new MCP server with all features. + /// Create a McpServer configured with prompts and an initialized resources registry. + /// + /// The returned server wraps the provided handler and prompt registry in Arcs, + /// constructs a ResourceRegistry from `context_service`, and initializes + /// empty workspace roots, active/cancelled request tracking, and the default log level. + /// + /// # Examples + /// + /// ```ignore + /// use std::sync::Arc; + /// + /// // Assume `handler`, `prompts`, and `context_service` are available. + /// let server = McpServer::with_features(handler, prompts, Arc::new(context_service), "my-server"); + /// ``` pub fn with_features( handler: McpHandler, prompts: PromptRegistry, @@ -112,38 +161,124 @@ impl McpServer { } } - /// Get the current log level. + /// Retrieve the server's current log level. + /// + /// # Returns + /// + /// `LogLevel` containing the server's active log level. + /// + /// # Examples + /// + /// ``` + /// # use futures::executor::block_on; + /// # // `server` must be a `McpServer` instance + /// # let server = todo!(); + /// let level = block_on(server.log_level()); + /// ``` pub async fn log_level(&self) -> LogLevel { *self.log_level.read().await } - /// Set the log level. + /// Update the server's current logging level. + /// + /// This changes the level that the server uses for subsequent log messages. + /// + /// # Examples + /// + /// ``` + /// # use crate::mcp::server::{McpServer, LogLevel}; + /// # async fn doc_example(server: &McpServer) { + /// server.set_log_level(LogLevel::Debug).await; + /// # } + /// ``` pub async fn set_log_level(&self, level: LogLevel) { *self.log_level.write().await = level; } - /// Get the client-provided workspace roots. + /// Retrieve the client-provided workspace roots. + /// + /// # Examples + /// + /// ```no_run + /// // Obtain an McpServer instance from your application context. + /// let server: McpServer = unimplemented!(); + /// + /// // Call the async method to get the current roots. + /// let roots = futures::executor::block_on(server.roots()); + /// assert!(roots.iter().all(|p| p.is_absolute())); + /// ``` pub async fn roots(&self) -> Vec { self.roots.read().await.clone() } - /// Check if a request has been explicitly cancelled. + /// Returns whether the given request ID has been explicitly cancelled. + + /// + + /// # Examples + + /// + + /// ``` + + /// // Assuming `server: McpServer` and `id: RequestId` are available: + + /// // let cancelled = server.is_cancelled(&id).await; + + /// ``` pub async fn is_cancelled(&self, id: &RequestId) -> bool { self.cancelled_requests.read().await.contains(id) } - /// Mark a request as cancelled. + /// Marks the given request ID as cancelled so the server will treat it as cancelled on subsequent checks. + /// + /// # Examples + /// + /// ``` + /// // Assuming `server` is an instance of `McpServer` and `req_id` is a `RequestId`: + /// // server.cancel_request(&req_id).await; + /// ``` pub async fn cancel_request(&self, id: &RequestId) { self.cancelled_requests.write().await.insert(id.clone()); } - /// Clean up a completed request from tracking sets. + /// Remove a request from the server's active and cancelled tracking sets. + /// + /// This removes `id` from both `active_requests` and `cancelled_requests`, ensuring + /// the server no longer treats the request as in-progress or cancelled. + /// + /// # Examples + /// + /// ```no_run + /// # use mcp::server::McpServer; + /// # use mcp::RequestId; + /// # async fn example(server: &McpServer, id: &RequestId) { + /// server.complete_request(id).await; + /// # } + /// ``` pub async fn complete_request(&self, id: &RequestId) { self.active_requests.write().await.remove(id); self.cancelled_requests.write().await.remove(id); } - /// Run the server with the given transport. + /// Run the server loop that processes incoming MCP messages on the provided transport. + /// + /// Starts the transport, receives messages until the transport ends or a send failure occurs, + /// dispatches requests and notifications to the server handlers, stops the transport, and returns + /// when the server has shut down. + /// + /// # Returns + /// + /// `Ok(())` on normal shutdown; an `Err` is returned if starting or stopping the transport fails. + /// + /// # Examples + /// + /// ``` + /// # use std::sync::Arc; + /// # async fn _example(server: Arc, transport: impl crate::transport::Transport) { + /// server.run(transport).await.unwrap(); + /// # } + /// ``` pub async fn run(&self, mut transport: T) -> Result<()> { info!("Starting MCP server: {} v{}", self.name, self.version); @@ -172,7 +307,21 @@ impl McpServer { Ok(()) } - /// Handle a JSON-RPC request. + /// Dispatches an incoming JSON-RPC request to the appropriate handler, tracks the request lifecycle for cancellation, and returns the corresponding JSON-RPC response. + /// + /// The request is registered as active while being processed; upon completion it is removed from active tracking. Known MCP methods are routed to their specific handlers; unknown methods produce a protocol error encoded in the response. + /// + /// # Returns + /// + /// `JsonRpcResponse` containing either a successful `result` value or an `error` describing the failure. + /// + /// # Examples + /// + /// ```no_run + /// // `server` and `request` are assumed to be initialized appropriately. + /// let resp = futures::executor::block_on(server.handle_request(request)); + /// assert_eq!(resp.jsonrpc, "2.0"); + /// ``` async fn handle_request(&self, req: JsonRpcRequest) -> JsonRpcResponse { debug!("Handling request: {} (id: {:?})", req.method, req.id); @@ -228,7 +377,29 @@ impl McpServer { } } - /// Handle a notification. + /// Process an incoming JSON-RPC notification and perform any side effects for known notification types. + /// + /// Known notifications handled: + /// - "notifications/initialized": logs client initialization. + /// - "notifications/cancelled": extracts a `requestId` from `params` and marks the request cancelled. + /// - "notifications/roots/listChanged": logs that client workspace roots changed. + /// Unknown notifications are ignored (logged at debug level). + /// + /// # Examples + /// + /// ```no_run + /// use serde_json::json; + /// + /// // Build a cancelled notification with a `requestId` param. + /// let notif = JsonRpcNotification { + /// jsonrpc: "2.0".into(), + /// method: "notifications/cancelled".into(), + /// params: Some(json!({ "requestId": "some-request-id" })), + /// }; + /// + /// // `server` is an instance of `McpServer`. Call will mark the request cancelled. + /// // server.handle_notification(notif).await; + /// ``` async fn handle_notification(&self, notif: JsonRpcNotification) { debug!("Handling notification: {}", notif.method); @@ -259,7 +430,20 @@ impl McpServer { } } - /// Handle initialize request. + /// Build and return the server's initialize result as JSON. + /// + /// If `params` includes client workspace roots with URIs beginning with `file://`, + /// those paths are added to the server's tracked roots. The returned JSON contains + /// the protocol version, server capabilities (including resources capability only + /// if resources support is enabled), and server info (name and version). + /// + /// # Examples + /// + /// ``` + /// // Call on a server instance: returns an `InitializeResult` serialized as JSON. + /// // let resp = server.handle_initialize(None).await.unwrap(); + /// // assert!(resp.get("protocol_version").is_some()); + /// ``` async fn handle_initialize(&self, params: Option) -> Result { // Extract roots from client if provided if let Some(ref params) = params { @@ -322,7 +506,29 @@ impl McpServer { Ok(serde_json::to_value(result)?) } - /// Handle call tool request. + /// Calls a named tool with the supplied parameters and returns the tool's result as JSON. + /// + /// Expects `params` to be a JSON-encoded `CallToolParams` object containing the tool `name` and `arguments`. + /// + /// # Returns + /// + /// The tool's execution result as a `serde_json::Value`. + /// + /// # Errors + /// + /// Returns `Error::InvalidToolArguments` if `params` is missing or cannot be deserialized into `CallToolParams`, + /// `Error::ToolNotFound` if no tool with the given name is registered, and propagates errors from the tool's + /// execution or JSON serialization. + /// + /// # Examples + /// + /// ``` + /// use serde_json::json; + /// + /// // Example params: { "name": "echo", "arguments": ["hello"] } + /// let params = Some(json!({ "name": "echo", "arguments": ["hello"] })); + /// // let result = server.handle_call_tool(params).await.unwrap(); + /// ``` async fn handle_call_tool(&self, params: Option) -> Result { let params: CallToolParams = params .ok_or_else(|| Error::InvalidToolArguments("Missing params".to_string())) @@ -339,7 +545,22 @@ impl McpServer { Ok(serde_json::to_value(result)?) } - /// Handle list prompts request. + /// List available prompts and return them as a JSON value. + /// + /// The returned JSON matches `ListPromptsResult` with the `prompts` field populated + /// and `next_cursor` set to `null`. + /// + /// # Examples + /// + /// ``` + /// # use crate::mcp::prompts::ListPromptsResult; + /// # tokio_test::block_on(async { + /// // assume `server` is a constructed `McpServer` + /// let json = server.handle_list_prompts().await.unwrap(); + /// let res: ListPromptsResult = serde_json::from_value(json).unwrap(); + /// assert!(res.next_cursor.is_none()); + /// # }); + /// ``` async fn handle_list_prompts(&self) -> Result { use crate::mcp::prompts::ListPromptsResult; @@ -351,7 +572,26 @@ impl McpServer { Ok(serde_json::to_value(result)?) } - /// Handle get prompt request. + /// Fetches a prompt by name with optional arguments and returns it as JSON. + /// + /// Expects `params` to be a JSON object with a required `name` string and an optional + /// `arguments` object mapping strings to strings. Returns the prompt result serialized + /// to a `serde_json::Value`. + /// + /// Errors: + /// - Returns `Error::InvalidToolArguments` if `params` is missing or cannot be deserialized. + /// - Returns `Error::McpProtocol` if no prompt with the given name exists. + /// + /// # Examples + /// + /// ``` + /// # use serde_json::json; + /// # async fn _example(server: &crate::mcp::server::McpServer) { + /// let params = json!({ "name": "welcome", "arguments": { "user": "Alex" } }); + /// let res = server.handle_get_prompt(Some(params)).await.unwrap(); + /// // `res` is a serde_json::Value containing the prompt result + /// # } + /// ``` async fn handle_get_prompt(&self, params: Option) -> Result { #[derive(serde::Deserialize)] struct GetPromptParams { @@ -374,7 +614,26 @@ impl McpServer { Ok(serde_json::to_value(result)?) } - /// Handle list resources request. + /// Lists available resources using an optional pagination cursor. + /// + /// If the server was built without resource support this returns an MCP protocol + /// error indicating resources are not enabled. When resources are enabled, the + /// optional `params` JSON may contain a `"cursor"` string used for paging; the + /// function returns the serialized listing result from the resource registry. + /// + /// # Errors + /// + /// Returns `Error::McpProtocol("Resources not enabled")` if resources are not + /// configured for the server, or propagates errors from the resource registry + /// or JSON serialization. + /// + /// # Examples + /// + /// ``` + /// // Construct the optional params JSON with a cursor: + /// let params = serde_json::json!({ "cursor": "page-2" }); + /// // Call: server.handle_list_resources(Some(params)).await + /// ``` async fn handle_list_resources(&self, params: Option) -> Result { let resources = self .resources @@ -394,7 +653,30 @@ impl McpServer { Ok(serde_json::to_value(result)?) } - /// Handle read resource request. + /// Read a resource identified by a URI and return its serialized content as JSON. + /// + /// Returns an error if resources are not enabled, if required parameters are missing or malformed, + /// or if the underlying resource read operation fails. + /// + /// # Examples + /// + /// ``` + /// # use serde_json::json; + /// # use std::sync::Arc; + /// # async fn _example(server: &crate::mcp::server::McpServer) { + /// let params = json!({ "uri": "file:///path/to/resource" }); + /// let result = server.handle_read_resource(Some(params)).await; + /// match result { + /// Ok(value) => { + /// // `value` is the JSON-serialized content returned by the resource registry. + /// println!("{}", value); + /// } + /// Err(e) => { + /// eprintln!("read failed: {:?}", e); + /// } + /// } + /// # } + /// ``` async fn handle_read_resource(&self, params: Option) -> Result { let resources = self .resources @@ -416,7 +698,25 @@ impl McpServer { Ok(serde_json::to_value(result)?) } - /// Handle subscribe to resource. + /// Subscribe the default session to a resource identified by URI. + /// + /// Returns an error if resources are not enabled for this server or if the required `params` are + /// missing or cannot be deserialized. + /// + /// The request causes the server to call the configured ResourceRegistry's `subscribe` method for + /// the provided URI using a placeholder session id ("default") and, on success, returns an empty + /// JSON object. + /// + /// # Examples + /// + /// ```no_run + /// # use serde_json::json; + /// # async fn example(server: &crate::mcp::McpServer) -> Result<(), Box> { + /// let params = json!({ "uri": "file:///path/to/resource" }); + /// let res = server.handle_subscribe_resource(Some(params)).await?; + /// assert_eq!(res, json!({})); + /// # Ok(()) } + /// ``` async fn handle_subscribe_resource(&self, params: Option) -> Result { let resources = self .resources @@ -461,7 +761,33 @@ impl McpServer { Ok(serde_json::json!({})) } - /// Handle completion request. + /// Provide completion suggestions for a completion request. + /// + /// Expects `params` to deserialize to `{ ref: { type, uri?, name? }, argument: { name, value } }`. + /// For argument names "path", "file", or "uri" it returns filesystem/resource path completions; + /// for argument name "prompt" when `ref.type == "ref/prompt"` it returns prompt-name completions. + /// The response is a JSON object with a `completion` field containing `values` (an array of strings) + /// and `hasMore` (a boolean). + /// + /// # Examples + /// + /// ```no_run + /// use serde_json::json; + /// + /// // Example request params for completing prompt names starting with "ins" + /// let params = json!({ + /// "ref": { "type": "ref/prompt" }, + /// "argument": { "name": "prompt", "value": "ins" } + /// }); + /// + /// // Expected shape of the response: + /// let expected = json!({ + /// "completion": { + /// "values": ["install", "instance"], // example values + /// "hasMore": false + /// } + /// }); + /// ``` async fn handle_completion(&self, params: Option) -> Result { #[derive(serde::Deserialize)] struct CompletionParams { @@ -517,7 +843,26 @@ impl McpServer { })) } - /// Complete file paths. + /// Generates file-path completion candidates that start with the given prefix. + /// + /// The returned completions are sourced from the optional resource registry (if enabled) + /// and from files/directories under client-provided workspace roots. Results are + /// deduplicated and limited to at most 20 entries. + /// + /// # Returns + /// + /// A vector of completion strings that begin with `prefix`, up to 20 items. + /// + /// # Examples + /// + /// ``` + /// // `server` is an instance of `McpServer`. + /// // This example assumes an async context (e.g., inside an async test). + /// # async fn example(server: &crate::mcp::server::McpServer) { + /// let completions = server.complete_file_path("src/").await; + /// // completions contains candidates like "src/main.rs", "src/lib.rs", ... + /// # } + /// ``` async fn complete_file_path(&self, prefix: &str) -> Vec { let roots = self.roots.read().await; let mut completions = Vec::new(); @@ -559,7 +904,32 @@ impl McpServer { completions.into_iter().take(20).collect() } - /// Handle logging/setLevel request. + /// Set the server's log level from RPC parameters. + /// + /// Expects `params` to be a JSON object `{ "level": "" }`. Parses the `level` string, + /// updates the server's log level, logs the change, and returns an empty JSON object on success. + /// If `params` is `None`, returns an MCP protocol error indicating the missing parameter. + /// Unknown or unrecognized level strings map to the default level (Info). + /// + /// # Parameters + /// + /// - `params`: Optional JSON `Value` containing a `level` string specifying the desired log level. + /// + /// # Returns + /// + /// An empty JSON object `{}` on success. + /// + /// # Examples + /// + /// ``` + /// # async fn docs_example(server: &McpServer) { + /// let res = server + /// .handle_set_log_level(Some(serde_json::json!({ "level": "debug" }))) + /// .await + /// .unwrap(); + /// assert_eq!(res, serde_json::json!({})); + /// # } + /// ``` async fn handle_set_log_level(&self, params: Option) -> Result { #[derive(serde::Deserialize)] struct SetLevelParams { @@ -610,4 +980,4 @@ mod tests { fn test_log_level_default() { assert_eq!(LogLevel::default(), LogLevel::Info); } -} +} \ No newline at end of file diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 20a6d74..f2ee572 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -23,7 +23,24 @@ use std::sync::Arc; use crate::mcp::handler::McpHandler; use crate::service::{ContextService, MemoryService, PlanningService}; -/// Register all tools with the handler. +/// Registers the built-in MCP tools with the given handler using the provided services. +/// +/// The function registers a fixed set of tools organized by category (retrieval, index, +/// memory, planning, review, navigation, and workspace), constructing each tool with the +/// appropriate service(s) supplied. +/// +/// # Examples +/// +/// ``` +/// use std::sync::Arc; +/// +/// let mut handler = McpHandler::new(); +/// let ctx = Arc::new(ContextService::default()); +/// let mem = Arc::new(MemoryService::default()); +/// let plan = Arc::new(PlanningService::default()); +/// +/// register_all_tools(&mut handler, ctx, mem, plan); +/// ``` pub fn register_all_tools( handler: &mut McpHandler, context_service: Arc, @@ -102,4 +119,4 @@ pub fn register_all_tools( handler.register(workspace::WorkspaceStatsTool::new(context_service.clone())); handler.register(workspace::GitStatusTool::new(context_service.clone())); handler.register(workspace::ExtractSymbolsTool::new(context_service)); -} +} \ No newline at end of file diff --git a/src/tools/navigation.rs b/src/tools/navigation.rs index dacf215..11a0038 100644 --- a/src/tools/navigation.rs +++ b/src/tools/navigation.rs @@ -19,6 +19,18 @@ pub struct FindReferencesTool { } impl FindReferencesTool { + /// Creates a new instance of the tool that shares the provided context service. + /// + /// The `service` is held by the tool and used to access workspace state and perform + /// file search, definition lookup, or diff operations depending on the tool. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// let service = Arc::new(ContextService::new()); + /// let tool = FindReferencesTool::new(service.clone()); + /// ``` pub fn new(service: Arc) -> Self { Self { service } } @@ -26,6 +38,23 @@ impl FindReferencesTool { #[async_trait] impl ToolHandler for FindReferencesTool { + /// Returns the tool descriptor for the "find_references" tool, describing its name, + /// purpose, and expected input schema. + /// + /// The returned Tool has: + /// - name: "find_references" + /// - description: brief explanation of the tool's purpose (searches for symbol usages) + /// - input_schema: JSON schema requiring `symbol` and optionally accepting `file_pattern` + /// and `max_results` (default: 50). + /// + /// # Examples + /// + /// ``` + /// // Construct the tool descriptor and verify its name. + /// let svc = Arc::new(ContextService::new()); // pseudo-code: supply a real service in use + /// let tool = FindReferencesTool::new(svc).definition(); + /// assert_eq!(tool.name, "find_references"); + /// ``` fn definition(&self) -> Tool { Tool { name: "find_references".to_string(), @@ -51,6 +80,27 @@ impl ToolHandler for FindReferencesTool { } } + /// Finds occurrences of a symbol across the workspace and returns a Markdown-formatted summary of matches. + /// + /// If any references are found, the result contains a Markdown document with a header and a bullet list + /// of file paths, line numbers, and line context for each occurrence. If no references are found, the + /// result contains a success message stating that no references were discovered for the requested symbol. + /// + /// # Returns + /// + /// A `ToolResult` containing either the Markdown list of references or a success message indicating no references. + /// + /// # Examples + /// + /// ``` + /// # use std::collections::HashMap; + /// # use serde_json::json; + /// # use futures::executor::block_on; + /// # // assuming `tool` is an instance of the tool in a test setup + /// let mut args = HashMap::new(); + /// args.insert("symbol".to_string(), json!("my_symbol")); + /// // block_on(tool.execute(args)) // -> ToolResult with Markdown or "No references found..." + /// ``` async fn execute(&self, args: HashMap) -> Result { let symbol = get_string_arg(&args, "symbol")?; let file_pattern = args.get("file_pattern").and_then(|v| v.as_str()); @@ -94,6 +144,18 @@ pub struct GoToDefinitionTool { } impl GoToDefinitionTool { + /// Creates a new instance of the tool that shares the provided context service. + /// + /// The `service` is held by the tool and used to access workspace state and perform + /// file search, definition lookup, or diff operations depending on the tool. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// let service = Arc::new(ContextService::new()); + /// let tool = FindReferencesTool::new(service.clone()); + /// ``` pub fn new(service: Arc) -> Self { Self { service } } @@ -101,6 +163,18 @@ impl GoToDefinitionTool { #[async_trait] impl ToolHandler for GoToDefinitionTool { + /// Creates a Tool descriptor for the "go_to_definition" tool used to locate a symbol's definition. + /// + /// The returned `Tool` includes the tool name, a brief description, and an input JSON schema + /// that requires a `symbol` and optionally accepts a `language` hint (e.g., "rust", "python"). + /// + /// # Examples + /// + /// ```no_run + /// // Obtain the descriptor from a GoToDefinitionTool instance: + /// let tool = GoToDefinitionTool::new(std::sync::Arc::new(context_service)).definition(); + /// assert_eq!(tool.name, "go_to_definition"); + /// ``` fn definition(&self) -> Tool { Tool { name: "go_to_definition".to_string(), @@ -122,6 +196,31 @@ impl ToolHandler for GoToDefinitionTool { } } + /// Finds definitions for the provided symbol in the workspace and returns a Markdown document + /// describing each match with file path, line number, and a fenced code snippet tagged with the detected language. + /// + /// The `args` map must contain the key `"symbol"` with the symbol name to search for. It may also + /// include an optional `"language"` string to hint which language to prefer when locating definitions. + /// + /// # Returns + /// + /// A `ToolResult` containing a Markdown-formatted document listing each definition found. If no + /// definitions are found, the result contains a plain message stating that no definition was found. + /// + /// # Examples + /// + /// ```no_run + /// use std::collections::HashMap; + /// use serde_json::json; + /// + /// // `tool` is assumed to be an instance implementing this `execute` method. + /// let mut args = HashMap::new(); + /// args.insert("symbol".to_string(), json!("my_function")); + /// // Optionally: args.insert("language".to_string(), json!("rust")); + /// + /// // let result = tool.execute(args).await.unwrap(); + /// // println!("{}", result); + /// ``` async fn execute(&self, args: HashMap) -> Result { let symbol = get_string_arg(&args, "symbol")?; let language = args.get("language").and_then(|v| v.as_str()); @@ -154,6 +253,18 @@ pub struct DiffFilesTool { } impl DiffFilesTool { + /// Creates a new instance of the tool that shares the provided context service. + /// + /// The `service` is held by the tool and used to access workspace state and perform + /// file search, definition lookup, or diff operations depending on the tool. + /// + /// # Examples + /// + /// ``` + /// use std::sync::Arc; + /// let service = Arc::new(ContextService::new()); + /// let tool = FindReferencesTool::new(service.clone()); + /// ``` pub fn new(service: Arc) -> Self { Self { service } } @@ -161,6 +272,14 @@ impl DiffFilesTool { #[async_trait] impl ToolHandler for DiffFilesTool { + /// Provides the Tool descriptor for the "diff_files" tool which compares two files and produces a unified diff. + /// + /// The descriptor includes the tool name, a short description, and an input JSON schema that requires `file1` and `file2` + /// and accepts an optional `context_lines` integer to control the number of surrounding context lines (default: 3). + /// + /// # Returns + /// + /// A `Tool` value describing the "diff_files" tool, its description, and its input schema. fn definition(&self) -> Tool { Tool { name: "diff_files".to_string(), @@ -188,6 +307,26 @@ impl ToolHandler for DiffFilesTool { } } + /// Compute a unified-style diff for two files in the workspace and return it as a tool result. + /// + /// If both files are readable and identical, the result contains the message "Files are identical.". + /// If they differ, the result contains a markdown-formatted diff wrapped in ```diff fences. + /// If either file cannot be read, the result is an error ToolResult describing the read failure. + /// + /// # Examples + /// + /// ```no_run + /// # use std::collections::HashMap; + /// # use serde_json::json; + /// # async fn example(tool: &crate::tools::navigation::DiffFilesTool) { + /// let mut args = HashMap::new(); + /// args.insert("file1".to_string(), json!("Cargo.toml")); + /// args.insert("file2".to_string(), json!("Cargo.lock")); + /// // optional: args.insert("context_lines".to_string(), json!(5)); + /// let result = tool.execute(args).await.unwrap(); + /// // Inspect `result` to see the diff or an error message. + /// # } + /// ``` async fn execute(&self, args: HashMap) -> Result { let file1 = get_string_arg(&args, "file1")?; let file2 = get_string_arg(&args, "file2")?; @@ -235,7 +374,30 @@ struct Definition { language: String, } -/// Find symbol references in files. +/// Search the workspace for occurrences of a symbol and collect matching references. +/// +/// Searches files under `workspace`, optionally filtering files by `file_pattern`, +/// and returns up to `max_results` matches as `Reference` entries containing the +/// relative file path, 1-based line number, and the matching line as context. +/// +/// # Parameters +/// +/// - `file_pattern`: optional pattern to restrict searched files (supports suffix like `"*.rs"` or substring matching). +/// - `max_results`: maximum number of references to return; search stops once this limit is reached. +/// +/// # Returns +/// +/// A `Vec` containing one entry per found occurrence, in discovery order. +/// +/// # Examples +/// +/// ``` +/// # use std::path::Path; +/// # use tokio_test::block_on; +/// // Search the current directory for the string "main", returning at most 5 matches. +/// let refs = block_on(async { crate::tools::navigation::find_symbol_in_files(Path::new("."), "main", None, 5).await }); +/// assert!(refs.len() <= 5); +/// ``` async fn find_symbol_in_files( workspace: &Path, symbol: &str, @@ -314,7 +476,22 @@ async fn find_symbol_in_files( references } -/// Find symbol definition. +/// Searches the workspace for likely definitions of `symbol` and returns any matches found. +/// +/// If `language` is provided, the search is limited to files whose detected language matches the hint +/// (for example `"rust"`, `"python"`, `"typescript"`). Each returned `Definition` contains the +/// relative file path, a 1-based line number, a short context snippet (up to a few lines), and the +/// detected language for the file. +/// +/// # Examples +/// +/// ``` +/// use std::path::Path; +/// // Run the async function in a simple executor for the example. +/// let defs = futures::executor::block_on(find_definition(Path::new("path/to/workspace"), "my_symbol", None)); +/// // `defs` is a Vec; check if any definitions were found. +/// assert!(defs.is_empty() || defs.iter().all(|d| d.context.len() > 0)); +/// ``` async fn find_definition( workspace: &Path, symbol: &str, @@ -392,7 +569,35 @@ async fn find_definition( definitions } -/// Get definition patterns for a symbol. +/// Build a list of textual patterns commonly used to identify symbol definitions. + +/// + +/// The `symbol` is inserted into language-specific declaration snippets. The optional + +/// `language` hint restricts patterns to that language when possible; otherwise a generic + +/// set of patterns for several common languages is returned. + +/// + +/// # Examples + +/// + +/// ``` + +/// let pats = get_definition_patterns("my_fn", Some("rust")); + +/// assert!(pats.iter().any(|p| p == "fn my_fn(")); + +/// + +/// let generic = get_definition_patterns("Thing", None); + +/// assert!(generic.iter().any(|p| p.contains("class Thing") || p.contains("struct Thing"))); + +/// ``` fn get_definition_patterns(symbol: &str, language: Option<&str>) -> Vec { let mut patterns = Vec::new(); @@ -434,7 +639,17 @@ fn get_definition_patterns(symbol: &str, language: Option<&str>) -> Vec patterns } -/// Get language from file extension. +/// Map a file extension to a canonical language identifier. +/// +/// Recognizes common source file extensions and returns a short language name; unknown extensions return `"text"`. +/// +/// # Examples +/// +/// ``` +/// assert_eq!(get_language("rs"), "rust"); +/// assert_eq!(get_language("tsx"), "typescript"); +/// assert_eq!(get_language("unknown"), "text"); +/// ``` fn get_language(ext: &str) -> &'static str { match ext { "rs" => "rust", @@ -450,7 +665,18 @@ fn get_language(ext: &str) -> &'static str { } } -/// Simple pattern matching. +/// Checks whether a filename matches a simple pattern. +/// +/// Patterns starting with `"*."` are treated as extension matches (e.g., `"*.rs"` +/// matches `"foo.rs"`). All other patterns are matched by substring containment. +/// +/// # Examples +/// +/// ``` +/// assert!(matches_pattern("src/lib.rs", "*.rs")); +/// assert!(matches_pattern("README.md", "README")); +/// assert!(!matches_pattern("src/main.c", "*.rs")); +/// ``` fn matches_pattern(name: &str, pattern: &str) -> bool { if let Some(ext) = pattern.strip_prefix("*.") { name.ends_with(&format!(".{}", ext)) @@ -459,7 +685,25 @@ fn matches_pattern(name: &str, pattern: &str) -> bool { } } -/// Generate a simple unified diff. +/// Produces a unified-diff-like string describing differences between two file contents. +/// +/// The output starts with unified diff headers for `name1` and `name2` and contains one or more +/// hunks with context lines, removals marked with `-` and additions with `+`. If the contents +/// are identical, an empty string is returned. +/// +/// `context` controls how many unchanged lines around a change are included in each hunk. +/// +/// # Examples +/// +/// ``` +/// let a = "a\nb\nc\n"; +/// let b = "a\nB\nc\n"; +/// let diff = generate_diff("old.txt", "new.txt", a, b, 1); +/// assert!(diff.contains("--- old.txt")); +/// assert!(diff.contains("+++ new.txt")); +/// assert!(diff.contains("-b")); +/// assert!(diff.contains("+B")); +/// ``` fn generate_diff( name1: &str, name2: &str, @@ -638,4 +882,4 @@ mod tests { assert_eq!(definition.file, "src/lib.rs"); assert_eq!(definition.language, "rust"); } -} +} \ No newline at end of file diff --git a/src/tools/workspace.rs b/src/tools/workspace.rs index 5aa1220..2fc6494 100644 --- a/src/tools/workspace.rs +++ b/src/tools/workspace.rs @@ -19,6 +19,16 @@ pub struct WorkspaceStatsTool { } impl WorkspaceStatsTool { + /// Create a new WorkspaceStatsTool that uses the given ContextService. + /// + /// # Examples + /// + /// ```no_run + /// use std::sync::Arc; + /// // `service` should be an initialized `ContextService` from the application. + /// let service: Arc = Arc::new(/* ... */); + /// let tool = WorkspaceStatsTool::new(service); + /// ``` pub fn new(service: Arc) -> Self { Self { service } } @@ -26,6 +36,17 @@ impl WorkspaceStatsTool { #[async_trait] impl ToolHandler for WorkspaceStatsTool { + /// Returns the tool descriptor for the `workspace_stats` tool. + /// + /// The descriptor includes the tool's name, a short description of what it provides, + /// and the JSON input schema (optionally accepts `include_hidden: bool`). + /// + /// # Examples + /// + /// ``` + /// let tool = WorkspaceStatsTool::new(service).definition(); + /// assert_eq!(tool.name, "workspace_stats"); + /// ``` fn definition(&self) -> Tool { Tool { name: "workspace_stats".to_string(), @@ -43,6 +64,32 @@ impl ToolHandler for WorkspaceStatsTool { } } + /// Execute the workspace statistics tool with the given arguments. + /// + /// The `args` map may include an optional `"include_hidden"` boolean; when `true` hidden files and + /// directories are included in the statistics. On success this returns a `ToolResult` containing a + /// pretty-printed JSON string of workspace statistics (total files, total lines, per-language + /// breakdown, and directory count). On failure this returns an error `ToolResult` with a + /// descriptive message. + /// + /// # Parameters + /// + /// - `args`: A map of input arguments; recognizes the optional `"include_hidden"` boolean. + /// + /// # Examples + /// + /// ``` + /// use std::collections::HashMap; + /// use serde_json::json; + /// + /// // prepare args to include hidden files + /// let mut args = HashMap::new(); + /// args.insert("include_hidden".to_string(), json!(true)); + /// + /// // assume `tool` is an initialized `WorkspaceStatsTool` + /// // let result = tool.execute(args).await.unwrap(); + /// // println!("{}", result); + /// ``` async fn execute(&self, args: HashMap) -> Result { let include_hidden = args .get("include_hidden") @@ -73,6 +120,28 @@ struct LanguageStats { lines: usize, } +/// Collects aggregated statistics for the workspace rooted at `root`. +/// +/// Scans files and directories under `root` to compute total files, total lines, +/// a per-language breakdown (files and lines), and the number of directories encountered. +/// When `include_hidden` is `true`, hidden files and directories (those starting with a dot) +/// are included in the scan; otherwise they are skipped. +/// +/// # Examples +/// +/// ```no_run +/// # async fn example() -> anyhow::Result<()> { +/// use std::path::Path; +/// let stats = collect_workspace_stats(Path::new("."), false).await?; +/// // stats contains totals and per-language breakdowns +/// assert!(stats.total_files >= 0); +/// # Ok(()) } +/// ``` +/// +/// # Returns +/// +/// A `WorkspaceStats` value containing totals for files and lines, a language map with +/// per-language file/line counts, and the number of directories scanned. async fn collect_workspace_stats(root: &Path, include_hidden: bool) -> Result { let mut stats = WorkspaceStats { total_files: 0, @@ -85,6 +154,32 @@ async fn collect_workspace_stats(root: &Path, include_hidden: bool) -> Result