From 5066035aecd734ef87bce4e58f28d6816ce0c599 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Tue, 11 Nov 2025 10:56:44 +0800 Subject: [PATCH] feat(task): add task support (SEP-1686) Signed-off-by: jokemanfire --- crates/rmcp-macros/README.md | 65 +++- crates/rmcp-macros/src/lib.rs | 15 + crates/rmcp-macros/src/task_handler.rs | 278 +++++++++++++++++ crates/rmcp/Cargo.toml | 5 + crates/rmcp/README.md | 15 +- crates/rmcp/src/error.rs | 4 + crates/rmcp/src/handler/server.rs | 77 ++++- crates/rmcp/src/handler/server/tool.rs | 8 +- crates/rmcp/src/lib.rs | 2 + crates/rmcp/src/model.rs | 73 ++++- crates/rmcp/src/model/capabilities.rs | 47 ++- crates/rmcp/src/model/meta.rs | 4 + crates/rmcp/src/model/task.rs | 79 +++++ crates/rmcp/src/task_manager.rs | 294 ++++++++++++++++++ .../src/transport/streamable_http_client.rs | 1 + .../client_json_rpc_message_schema.json | 181 +++++++++++ ...lient_json_rpc_message_schema_current.json | 181 +++++++++++ .../server_json_rpc_message_schema.json | 220 +++++++++++++ ...erver_json_rpc_message_schema_current.json | 220 +++++++++++++ crates/rmcp/tests/test_progress_subscriber.rs | 1 + crates/rmcp/tests/test_task.rs | 77 +++++ crates/rmcp/tests/test_tool_macros.rs | 2 + examples/clients/src/collection.rs | 1 + examples/clients/src/everything_stdio.rs | 2 + examples/clients/src/git_stdio.rs | 1 + examples/clients/src/progress_client.rs | 2 + examples/clients/src/sampling_stdio.rs | 1 + examples/clients/src/streamable_http.rs | 1 + examples/rig-integration/src/mcp_adaptor.rs | 1 + examples/servers/src/common/counter.rs | 118 +++++-- examples/simple-chat-client/src/tool.rs | 1 + examples/transport/src/named-pipe.rs | 1 + examples/transport/src/unix_socket.rs | 1 + 33 files changed, 1922 insertions(+), 57 deletions(-) create mode 100644 crates/rmcp-macros/src/task_handler.rs create mode 100644 crates/rmcp/src/model/task.rs create mode 100644 crates/rmcp/src/task_manager.rs create mode 100644 crates/rmcp/tests/test_task.rs diff --git a/crates/rmcp-macros/README.md b/crates/rmcp-macros/README.md index f009329b..62ea2f5e 100644 --- a/crates/rmcp-macros/README.md +++ b/crates/rmcp-macros/README.md @@ -6,7 +6,10 @@ This library primarily provides the following macros: -- `#[tool]`: Used to mark functions as RMCP tools, automatically generating necessary metadata and invocation mechanisms +- `#[tool]`: Mark an async/sync function as an RMCP tool and generate metadata + schema glue +- `#[tool_router]`: Collect all `#[tool]` functions in an impl block into a router value +- `#[tool_handler]`: Implement the `call_tool` and `list_tools` entry points by delegating to a router expression +- `#[task_handler]`: Wire up the task lifecycle (list/enqueue/get/cancel) on top of an `OperationProcessor` ## Usage @@ -16,7 +19,7 @@ This macro is used to mark a function as a tool handler. This will generate a function that return the attribute of this tool, with type `rmcp::model::Tool`. -#### Usage +#### Tool attributes | field | type | usage | | :- | :- | :- | @@ -25,7 +28,7 @@ This will generate a function that return the attribute of this tool, with type | `input_schema` | `Expr` | A JSON Schema object defining the expected parameters for the tool. If not provide, if will use the json schema of its argument with type `Parameters` | | `annotations` | `ToolAnnotationsAttribute` | Additional tool information. Defaults to `None`. | -#### Example +#### Tool example ```rust #[tool(name = "my_tool", description = "This is my tool", annotations(title = "我的工具", read_only_hint = true))] @@ -42,14 +45,14 @@ It creates a function that returns a `ToolRouter` instance. In most case, you need to add a field for handler to store the router information and initialize it when creating handler, or store it with a static variable. -#### Usage +#### Router attributes | field | type | usage | | :- | :- | :- | | `router` | `Ident` | The name of the router function to be generated. Defaults to `tool_router`. | | `vis` | `Visibility` | The visibility of the generated router function. Defaults to empty. | -#### Example +#### Router example ```rust #[tool_router] @@ -104,13 +107,14 @@ impl MyToolHandler { This macro will generate the handler for `tool_call` and `list_tools` methods in the implementation block, by using an existing `ToolRouter` instance. -#### Usage +#### Handler attributes | field | type | usage | | :- | :- | :- | | `router` | `Expr` | The expression to access the `ToolRouter` instance. Defaults to `self.tool_router`. | -#### Example +#### Handler example + ```rust #[tool_handler] impl ServerHandler for MyToolHandler { @@ -119,6 +123,7 @@ impl ServerHandler for MyToolHandler { ``` or using a custom router expression: + ```rust #[tool_handler(router = self.get_router().await)] impl ServerHandler for MyToolHandler { @@ -126,8 +131,10 @@ impl ServerHandler for MyToolHandler { } ``` -#### Explained +#### Handler expansion + This macro will be expended to something like this: + ```rust impl ServerHandler for MyToolHandler { async fn call_tool( @@ -150,6 +157,46 @@ impl ServerHandler for MyToolHandler { } ``` +### task_handler + +This macro wires the task lifecycle endpoints (`list_tasks`, `enqueue_task`, `get_task`, `cancel_task`) to an implementation of `OperationProcessor`. It keeps the handler lean by delegating scheduling, status tracking, and cancellation semantics to the processor. + +#### Task handler attributes + +| field | type | usage | +| :- | :- | :- | +| `processor` | `Expr` | Expression that yields an `Arc` (or compatible trait object). Defaults to `self.processor.clone()`. | + +#### Task handler example + +```rust +#[derive(Clone)] +pub struct TaskHandler { + processor: Arc + Send + Sync>, +} + +#[task_handler(processor = self.processor.clone())] +impl ServerHandler for TaskHandler {} +``` + +#### Task handler expansion + +At expansion time the macro implements the task-specific handler methods by forwarding to the processor expression, roughly equivalent to: + +```rust +impl ServerHandler for TaskHandler { + async fn list_tasks(&self, request: TaskListRequest, ctx: RequestContext) -> Result { + self.processor.list_tasks(request, ctx).await + } + + async fn enqueue_task(&self, request: TaskEnqueueRequest, ctx: RequestContext) -> Result { + self.processor.enqueue_task(request, ctx).await + } + + // get_task and cancel_task are generated in the same manner. +} +``` + ## Advanced Features @@ -159,4 +206,4 @@ impl ServerHandler for MyToolHandler { ## License -Please refer to the LICENSE file in the project root directory. +Please refer to the LICENSE file in the project root directory. diff --git a/crates/rmcp-macros/src/lib.rs b/crates/rmcp-macros/src/lib.rs index 6bb06827..ea79f465 100644 --- a/crates/rmcp-macros/src/lib.rs +++ b/crates/rmcp-macros/src/lib.rs @@ -5,6 +5,7 @@ mod common; mod prompt; mod prompt_handler; mod prompt_router; +mod task_handler; mod tool; mod tool_handler; mod tool_router; @@ -263,3 +264,17 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> TokenStream { .unwrap_or_else(|err| err.to_compile_error()) .into() } + +/// # task_handler +/// +/// Generates basic task-handling methods (`enqueue_task` and `list_tasks`) for a server handler +/// using a shared \[`OperationProcessor`\]. The default processor expression assumes a +/// `self.processor` field holding an `Arc>`, but it can be customized +/// via `#[task_handler(processor = ...)]`. Because the macro captures `self` inside spawned +/// futures, the handler type must implement [`Clone`]. +#[proc_macro_attribute] +pub fn task_handler(attr: TokenStream, input: TokenStream) -> TokenStream { + task_handler::task_handler(attr.into(), input.into()) + .unwrap_or_else(|err| err.to_compile_error()) + .into() +} diff --git a/crates/rmcp-macros/src/task_handler.rs b/crates/rmcp-macros/src/task_handler.rs new file mode 100644 index 00000000..f94cf130 --- /dev/null +++ b/crates/rmcp-macros/src/task_handler.rs @@ -0,0 +1,278 @@ +use darling::{FromMeta, ast::NestedMeta}; +use proc_macro2::TokenStream; +use quote::{ToTokens, quote}; +use syn::{Expr, ImplItem, ItemImpl}; + +#[derive(FromMeta)] +#[darling(default)] +struct TaskHandlerAttribute { + processor: Expr, +} + +impl Default for TaskHandlerAttribute { + fn default() -> Self { + Self { + processor: syn::parse2(quote! { self.processor }).expect("default processor expr"), + } + } +} + +pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result { + let attr_args = NestedMeta::parse_meta_list(attr)?; + let TaskHandlerAttribute { processor } = TaskHandlerAttribute::from_list(&attr_args)?; + let mut item_impl = syn::parse2::(input.clone())?; + + let has_method = |name: &str, item_impl: &ItemImpl| -> bool { + item_impl.items.iter().any(|item| match item { + ImplItem::Fn(func) => func.sig.ident == name, + _ => false, + }) + }; + + if !has_method("list_tasks", &item_impl) { + let list_fn = quote! { + async fn list_tasks( + &self, + _request: Option, + _: rmcp::service::RequestContext, + ) -> Result { + let running_ids = (#processor).lock().await.list_running(); + let total = running_ids.len() as u64; + let tasks = running_ids + .into_iter() + .map(|task_id| { + let timestamp = rmcp::task_manager::current_timestamp(); + rmcp::model::Task { + task_id, + status: rmcp::model::TaskStatus::Working, + status_message: None, + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: None, + poll_interval: None, + } + }) + .collect::>(); + + Ok(rmcp::model::ListTasksResult { + tasks, + next_cursor: None, + total: Some(total), + }) + } + }; + item_impl.items.push(syn::parse2::(list_fn)?); + } + + if !has_method("enqueue_task", &item_impl) { + let enqueue_fn = quote! { + async fn enqueue_task( + &self, + request: rmcp::model::CallToolRequestParam, + context: rmcp::service::RequestContext, + ) -> Result { + use rmcp::task_manager::{ + current_timestamp, OperationDescriptor, OperationMessage, OperationResultTransport, + ToolCallTaskResult, + }; + let task_id = context.id.to_string(); + let operation_name = request.name.to_string(); + let future_request = request.clone(); + let future_context = context.clone(); + let server = self.clone(); + + let descriptor = OperationDescriptor::new(task_id.clone(), operation_name) + .with_context(context) + .with_client_request(rmcp::model::ClientRequest::CallToolRequest( + rmcp::model::Request::new(request), + )); + + let task_result_id = task_id.clone(); + let future = Box::pin(async move { + let result = server.call_tool(future_request, future_context).await; + Ok( + Box::new(ToolCallTaskResult::new(task_result_id, result)) + as Box, + ) + }); + + (#processor) + .lock() + .await + .submit_operation(OperationMessage::new(descriptor, future)) + .map_err(|err| rmcp::ErrorData::internal_error( + format!("failed to enqueue task: {err}"), + None, + ))?; + + let timestamp = current_timestamp(); + let task = rmcp::model::Task { + task_id, + status: rmcp::model::TaskStatus::Working, + status_message: Some("Task accepted".to_string()), + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: None, + poll_interval: None, + }; + + Ok(rmcp::model::CreateTaskResult { task }) + } + }; + item_impl.items.push(syn::parse2::(enqueue_fn)?); + } + + if !has_method("get_task_info", &item_impl) { + let get_info_fn = quote! { + async fn get_task_info( + &self, + request: rmcp::model::GetTaskInfoParam, + _context: rmcp::service::RequestContext, + ) -> Result { + use rmcp::task_manager::current_timestamp; + let task_id = request.task_id.clone(); + let mut processor = (#processor).lock().await; + processor.collect_completed_results(); + + // Check completed results first + let completed = processor.peek_completed().iter().rev().find(|r| r.descriptor.operation_id == task_id); + if let Some(completed_result) = completed { + // Determine Finished vs Failed + let status = match &completed_result.result { + Ok(boxed) => { + if let Some(tool) = boxed.as_any().downcast_ref::() { + match &tool.result { + Ok(_) => rmcp::model::TaskStatus::Completed, + Err(_) => rmcp::model::TaskStatus::Failed, + } + } else { + rmcp::model::TaskStatus::Completed + } + } + Err(_) => rmcp::model::TaskStatus::Failed, + }; + let timestamp = current_timestamp(); + let task = rmcp::model::Task { + task_id, + status, + status_message: None, + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: completed_result.descriptor.ttl, + poll_interval: None, + }; + return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) }); + } + + // If not completed, check running + let running = processor.list_running(); + if running.into_iter().any(|id| id == task_id) { + let timestamp = current_timestamp(); + let task = rmcp::model::Task { + task_id, + status: rmcp::model::TaskStatus::Working, + status_message: None, + created_at: timestamp.clone(), + last_updated_at: Some(timestamp), + ttl: None, + poll_interval: None, + }; + return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) }); + } + + Ok(rmcp::model::GetTaskInfoResult { task: None }) + } + }; + item_impl.items.push(syn::parse2::(get_info_fn)?); + } + + if !has_method("get_task_result", &item_impl) { + let get_result_fn = quote! { + async fn get_task_result( + &self, + request: rmcp::model::GetTaskResultParam, + _context: rmcp::service::RequestContext, + ) -> Result { + use std::time::Duration; + let task_id = request.task_id.clone(); + + loop { + // Scope the lock so we can await outside if needed + { + let mut processor = (#processor).lock().await; + processor.collect_completed_results(); + + if let Some(task_result) = processor.take_completed_result(&task_id) { + match task_result.result { + Ok(boxed) => { + if let Some(tool) = boxed.as_any().downcast_ref::() { + match &tool.result { + Ok(call_tool) => { + let value = ::serde_json::to_value(call_tool).unwrap_or(::serde_json::Value::Null); + return Ok(rmcp::model::TaskResult { + content_type: "application/json".to_string(), + value, + summary: None, + }); + } + Err(err) => return Err(McpError::internal_error( + format!("task failed: {}", err), + None, + )), + } + } else { + return Err(McpError::internal_error("unsupported task result transport", None)); + } + } + Err(err) => return Err(McpError::internal_error( + format!("task execution error: {}", err), + None, + )), + } + } + + // Not completed yet: if not running, return not found + let running = processor.list_running(); + if !running.iter().any(|id| id == &task_id) { + return Err(McpError::resource_not_found(format!("task not found: {}", task_id), None)); + } + } + + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + }; + item_impl + .items + .push(syn::parse2::(get_result_fn)?); + } + + if !has_method("cancel_task", &item_impl) { + let cancel_fn = quote! { + async fn cancel_task( + &self, + request: rmcp::model::CancelTaskParam, + _context: rmcp::service::RequestContext, + ) -> Result<(), McpError> { + let task_id = request.task_id; + let mut processor = (#processor).lock().await; + processor.collect_completed_results(); + + if processor.cancel_task(&task_id) { + return Ok(()); + } + + // If already completed, signal it's not cancellable + let exists_completed = processor.peek_completed().iter().any(|r| r.descriptor.operation_id == task_id); + if exists_completed { + return Err(McpError::invalid_request(format!("task already completed: {}", task_id), None)); + } + + Err(McpError::resource_not_found(format!("task not found: {}", task_id), None)) + } + }; + item_impl.items.push(syn::parse2::(cancel_fn)?); + } + + Ok(item_impl.into_token_stream()) +} diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 4703c667..2b63f66f 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs" name = "test_elicitation" required-features = ["elicitation", "client", "server"] path = "tests/test_elicitation.rs" + +[[test]] +name = "test_task" +required-features = ["server", "client", "macros"] +path = "tests/test_task.rs" \ No newline at end of file diff --git a/crates/rmcp/README.md b/crates/rmcp/README.md index 8f021d33..e28df000 100644 --- a/crates/rmcp/README.md +++ b/crates/rmcp/README.md @@ -2,9 +2,6 @@ `rmcp` is the official Rust implementation of the Model Context Protocol (MCP), a protocol designed for AI assistants to communicate with other services. This library can be used to build both servers that expose capabilities to AI assistants and clients that interact with such servers. -wait for the first release. - @@ -81,6 +78,17 @@ async fn main() -> Result<(), Box> { } ``` +## Tasks + +RMCP implements the task lifecycle from SEP-1686 so long-running or asynchronous tool calls can be queued and polled safely. + +- **Create:** set the `task` field on `CallToolRequestParam` to ask the server to enqueue the tool call. The response is a `CreateTaskResult` that includes the generated `task.task_id`. +- **Inspect:** use `tasks/get` (`GetTaskInfoRequest`) to retrieve metadata such as status, timestamps, TTL, and poll interval. +- **Await results:** call `tasks/result` (`GetTaskResultRequest`) to block until the task completes and receive either the final `CallToolResult` payload or a protocol error. +- **Cancel:** call `tasks/cancel` (`CancelTaskRequest`) to request termination of a running task. + +To expose task support, enable the `tasks` capability when building `ServerCapabilities`. The `#[task_handler]` macro and `OperationProcessor` utility provide reference implementations for enqueuing, tracking, and collecting task results. + ### Client Implementation Creating a client to interact with a server: @@ -117,6 +125,7 @@ async fn main() -> Result<(), Box> { .call_tool(CallToolRequestParam { name: "increment".into(), arguments: None, + task: None, }) .await?; println!("Result: {result:#?}"); diff --git a/crates/rmcp/src/error.rs b/crates/rmcp/src/error.rs index e0da2b3d..f51a7158 100644 --- a/crates/rmcp/src/error.rs +++ b/crates/rmcp/src/error.rs @@ -41,6 +41,10 @@ pub enum RmcpError { error: Box, }, // and cancellation shouldn't be an error? + + // TODO: add more error variants as needed + #[error("Task error: {0}")] + TaskError(String), } impl RmcpError { diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index b16aeddc..f10cfa7c 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -11,6 +11,7 @@ pub mod router; pub mod tool; pub mod tool_name_validation; pub mod wrapper; + impl Service for H { async fn handle_request( &self, @@ -61,10 +62,18 @@ impl Service for H { .unsubscribe(request.params, context) .await .map(ServerResult::empty), - ClientRequest::CallToolRequest(request) => self - .call_tool(request.params, context) - .await - .map(ServerResult::CallToolResult), + ClientRequest::CallToolRequest(request) => { + if request.params.task.is_some() { + tracing::info!("Enqueueing task for tool call: {}", request.params.name); + self.enqueue_task(request.params, context.clone()) + .await + .map(ServerResult::CreateTaskResult) + } else { + self.call_tool(request.params, context) + .await + .map(ServerResult::CallToolResult) + } + } ClientRequest::ListToolsRequest(request) => self .list_tools(request.params, context) .await @@ -73,6 +82,22 @@ impl Service for H { .on_custom_request(request, context) .await .map(ServerResult::CustomResult), + ClientRequest::ListTasksRequest(request) => self + .list_tasks(request.params, context) + .await + .map(ServerResult::ListTasksResult), + ClientRequest::GetTaskInfoRequest(request) => self + .get_task_info(request.params, context) + .await + .map(ServerResult::GetTaskInfoResult), + ClientRequest::GetTaskResultRequest(request) => self + .get_task_result(request.params, context) + .await + .map(ServerResult::TaskResult), + ClientRequest::CancelTaskRequest(request) => self + .cancel_task(request.params, context) + .await + .map(ServerResult::empty), } } @@ -108,6 +133,16 @@ impl Service for H { #[allow(unused_variables)] pub trait ServerHandler: Sized + Send + Sync + 'static { + fn enqueue_task( + &self, + _request: CallToolRequestParam, + _context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::internal_error( + "Task processing not implemented".to_string(), + None, + ))) + } fn ping( &self, context: RequestContext, @@ -257,4 +292,38 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { fn get_info(&self) -> ServerInfo { ServerInfo::default() } + + fn list_tasks( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + + fn get_task_info( + &self, + request: GetTaskInfoParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + std::future::ready(Err(McpError::method_not_found::())) + } + + fn get_task_result( + &self, + request: GetTaskResultParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + let _ = (request, context); + std::future::ready(Err(McpError::method_not_found::())) + } + + fn cancel_task( + &self, + request: CancelTaskParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + let _ = (request, context); + std::future::ready(Err(McpError::method_not_found::())) + } } diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index 21e7e1a2..16435e42 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -33,12 +33,17 @@ pub struct ToolCallContext<'s, S> { pub service: &'s S, pub name: Cow<'static, str>, pub arguments: Option, + pub task: Option, } impl<'s, S> ToolCallContext<'s, S> { pub fn new( service: &'s S, - CallToolRequestParam { name, arguments }: CallToolRequestParam, + CallToolRequestParam { + name, + arguments, + task, + }: CallToolRequestParam, request_context: RequestContext, ) -> Self { Self { @@ -46,6 +51,7 @@ impl<'s, S> ToolCallContext<'s, S> { service, name, arguments, + task, } } pub fn name(&self) -> &str { diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9f81eabe..3ab7c5d9 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -125,6 +125,7 @@ //! .call_tool(CallToolRequestParam { //! name: "git_status".into(), //! arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), +//! task: None, //! }) //! .await?; //! println!("Tool result: {tool_result:#?}"); @@ -162,6 +163,7 @@ pub use service::{RoleClient, serve_client}; pub use service::{RoleServer, serve_server}; pub mod handler; +pub mod task_manager; pub mod transport; // re-export diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index 8837bdf1..36cd8fb9 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -8,6 +8,7 @@ mod meta; mod prompt; mod resource; mod serde_impl; +mod task; mod tool; pub use annotated::*; pub use capabilities::*; @@ -19,6 +20,7 @@ pub use prompt::*; pub use resource::*; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Value; +pub use task::*; pub use tool::*; /// A JSON object type alias for convenient handling of JSON data. @@ -1705,6 +1707,8 @@ pub struct CallToolRequestParam { /// Arguments to pass to the tool (must match the tool's input schema) #[serde(skip_serializing_if = "Option::is_none")] pub arguments: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, } /// Request to call a specific tool @@ -1743,6 +1747,61 @@ pub struct GetPromptResult { pub messages: Vec, } +// ============================================================================= +// TASK MANAGEMENT +// ============================================================================= + +const_string!(GetTaskInfoMethod = "tasks/get"); +pub type GetTaskInfoRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskInfoParam { + pub task_id: String, +} + +const_string!(ListTasksMethod = "tasks/list"); +pub type ListTasksRequest = RequestOptionalParam; + +const_string!(GetTaskResultMethod = "tasks/result"); +pub type GetTaskResultRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskResultParam { + pub task_id: String, +} + +const_string!(CancelTaskMethod = "tasks/cancel"); +pub type CancelTaskRequest = Request; + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CancelTaskParam { + pub task_id: String, +} +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct GetTaskInfoResult { + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct ListTasksResult { + pub tasks: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, +} + // ============================================================================= // MESSAGE TYPE UNIONS // ============================================================================= @@ -1810,7 +1869,11 @@ ts_union!( | UnsubscribeRequest | CallToolRequest | ListToolsRequest - | CustomRequest; + | CustomRequest + | GetTaskInfoRequest + | ListTasksRequest + | GetTaskResultRequest + | CancelTaskRequest; ); impl ClientRequest { @@ -1830,6 +1893,10 @@ impl ClientRequest { ClientRequest::CallToolRequest(r) => r.method.as_str(), ClientRequest::ListToolsRequest(r) => r.method.as_str(), ClientRequest::CustomRequest(r) => r.method.as_str(), + ClientRequest::GetTaskInfoRequest(r) => r.method.as_str(), + ClientRequest::ListTasksRequest(r) => r.method.as_str(), + ClientRequest::GetTaskResultRequest(r) => r.method.as_str(), + ClientRequest::CancelTaskRequest(r) => r.method.as_str(), } } } @@ -1895,6 +1962,10 @@ ts_union!( | CreateElicitationResult | EmptyResult | CustomResult + | CreateTaskResult + | ListTasksResult + | GetTaskInfoResult + | TaskResult ; ); diff --git a/crates/rmcp/src/model/capabilities.rs b/crates/rmcp/src/model/capabilities.rs index cbe1a6ea..1740b3ee 100644 --- a/crates/rmcp/src/model/capabilities.rs +++ b/crates/rmcp/src/model/capabilities.rs @@ -40,6 +40,25 @@ pub struct RootsCapabilities { pub list_changed: Option, } +/// Task capability negotiation for SEP-1686. +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TasksCapability { + /// Map of request category (e.g. "tools.call") to a boolean indicating support. + #[serde(skip_serializing_if = "Option::is_none")] + pub requests: Option, + /// Whether the receiver supports `tasks/list`. + #[serde(skip_serializing_if = "Option::is_none")] + pub list: Option, + /// Whether the receiver supports `tasks/cancel`. + #[serde(skip_serializing_if = "Option::is_none")] + pub cancel: Option, +} + +/// A convenience alias for describing per-request task support. +pub type TaskRequestMap = BTreeMap; + /// Capability for handling elicitation requests from servers. /// /// Elicitation allows servers to request interactive input from users during tool execution. @@ -78,6 +97,8 @@ pub struct ClientCapabilities { /// Capability to handle elicitation requests from servers for interactive user input #[serde(skip_serializing_if = "Option::is_none")] pub elicitation: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tasks: Option, } /// @@ -109,6 +130,8 @@ pub struct ServerCapabilities { pub resources: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tasks: Option, } macro_rules! builder { @@ -223,12 +246,13 @@ builder! { completions: JsonObject, prompts: PromptsCapability, resources: ResourcesCapability, - tools: ToolsCapability + tools: ToolsCapability, + tasks: TasksCapability } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_tool_list_changed(mut self) -> Self { if let Some(c) = self.tools.as_mut() { @@ -238,8 +262,8 @@ impl } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_prompts_list_changed(mut self) -> Self { if let Some(c) = self.prompts.as_mut() { @@ -249,8 +273,8 @@ impl } } -impl - ServerCapabilitiesBuilder> +impl + ServerCapabilitiesBuilder> { pub fn enable_resources_list_changed(mut self) -> Self { if let Some(c) = self.resources.as_mut() { @@ -273,11 +297,12 @@ builder! { roots: RootsCapabilities, sampling: JsonObject, elicitation: ElicitationCapability, + tasks: TasksCapability, } } -impl - ClientCapabilitiesBuilder> +impl + ClientCapabilitiesBuilder> { pub fn enable_roots_list_changed(mut self) -> Self { if let Some(c) = self.roots.as_mut() { @@ -288,8 +313,8 @@ impl } #[cfg(feature = "elicitation")] -impl - ClientCapabilitiesBuilder> +impl + ClientCapabilitiesBuilder> { /// Enable JSON Schema validation for elicitation responses. /// When enabled, the client will validate user input against the requested_schema diff --git a/crates/rmcp/src/model/meta.rs b/crates/rmcp/src/model/meta.rs index e93ebf19..acda3900 100644 --- a/crates/rmcp/src/model/meta.rs +++ b/crates/rmcp/src/model/meta.rs @@ -107,6 +107,10 @@ variant_extension! { CallToolRequest ListToolsRequest CustomRequest + GetTaskInfoRequest + ListTasksRequest + GetTaskResultRequest + CancelTaskRequest } } diff --git a/crates/rmcp/src/model/task.rs b/crates/rmcp/src/model/task.rs new file mode 100644 index 00000000..8cb0ee58 --- /dev/null +++ b/crates/rmcp/src/model/task.rs @@ -0,0 +1,79 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Canonical task lifecycle status as defined by SEP-1686. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub enum TaskStatus { + /// The receiver accepted the request and is currently working on it. + #[default] + Working, + /// The receiver requires additional input before work can continue. + InputRequired, + /// The underlying operation completed successfully and the result is ready. + Completed, + /// The underlying operation failed and will not continue. + Failed, + /// The task was cancelled and will not continue processing. + Cancelled, +} + +/// Final result for a succeeded task (returned from `tasks/result`). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskResult { + /// MIME type or custom content-type identifier. + pub content_type: String, + /// The actual result payload, matching the underlying request's schema. + pub value: Value, + /// Optional short summary for UI surfaces. + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +/// Primary Task object that surfaces metadata during the task lifecycle. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct Task { + /// Unique task identifier generated by the receiver. + pub task_id: String, + /// Current lifecycle status (see [`TaskStatus`]). + pub status: TaskStatus, + /// Optional human-readable status message for UI surfaces. + #[serde(skip_serializing_if = "Option::is_none")] + pub status_message: Option, + /// ISO-8601 creation timestamp. + pub created_at: String, + /// ISO-8601 timestamp for the most recent status change. + #[serde(skip_serializing_if = "Option::is_none")] + pub last_updated_at: Option, + /// Retention window in milliseconds that the receiver agreed to honor. + #[serde(skip_serializing_if = "Option::is_none")] + pub ttl: Option, + /// Suggested polling interval (milliseconds). + #[serde(skip_serializing_if = "Option::is_none")] + pub poll_interval: Option, +} + +/// Wrapper returned by task-augmented requests (CreateTaskResult in SEP-1686). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct CreateTaskResult { + pub task: Task, +} + +/// Paginated list of tasks +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))] +pub struct TaskList { + pub tasks: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total: Option, +} diff --git a/crates/rmcp/src/task_manager.rs b/crates/rmcp/src/task_manager.rs new file mode 100644 index 00000000..d8768902 --- /dev/null +++ b/crates/rmcp/src/task_manager.rs @@ -0,0 +1,294 @@ +use std::{any::Any, collections::HashMap, pin::Pin}; + +use futures::Future; +use tokio::{ + sync::mpsc, + time::{Duration, timeout}, +}; + +use crate::{ + RoleServer, + error::{ErrorData as McpError, RmcpError as Error}, + model::{CallToolResult, ClientRequest}, + service::RequestContext, +}; + +/// Boxed future that represents an asynchronous operation managed by the processor. +pub type OperationFuture = + Pin, Error>> + Send>>; + +/// Describes metadata associated with an enqueued task. +#[derive(Debug, Clone)] +pub struct OperationDescriptor { + pub operation_id: String, + pub name: String, + pub client_request: Option, + pub context: Option>, + pub ttl: Option, +} + +impl OperationDescriptor { + pub fn new(operation_id: impl Into, name: impl Into) -> Self { + Self { + operation_id: operation_id.into(), + name: name.into(), + client_request: None, + context: None, + ttl: None, + } + } + + pub fn with_client_request(mut self, request: ClientRequest) -> Self { + self.client_request = Some(request); + self + } + + pub fn with_context(mut self, context: RequestContext) -> Self { + self.context = Some(context); + self + } + + pub fn with_ttl(mut self, ttl: u64) -> Self { + self.ttl = Some(ttl); + self + } +} + +/// Operation message describing a unit of asynchronous work. +pub struct OperationMessage { + pub descriptor: OperationDescriptor, + pub future: OperationFuture, +} + +impl OperationMessage { + pub fn new(descriptor: OperationDescriptor, future: OperationFuture) -> Self { + Self { descriptor, future } + } +} + +/// Trait for operation result transport +pub trait OperationResultTransport: Send + Sync + 'static { + fn operation_id(&self) -> &String; + fn as_any(&self) -> &dyn std::any::Any; +} + +// ===== Operation Processor ===== +pub const DEFAULT_TASK_TIMEOUT_SECS: u64 = 300; // 5 minutes +/// Operation processor that coordinates extractors and handlers +pub struct OperationProcessor { + /// Currently running tasks keyed by id + running_tasks: HashMap, + /// Completed results waiting to be collected + completed_results: Vec, + task_result_receiver: Option>, + task_result_sender: mpsc::UnboundedSender, +} + +struct RunningTask { + task_handle: tokio::task::JoinHandle<()>, + started_at: std::time::Instant, + timeout: Option, + descriptor: OperationDescriptor, +} + +pub struct TaskResult { + pub descriptor: OperationDescriptor, + pub result: Result, Error>, +} + +/// Helper to generate an ISO 8601 timestamp for task metadata. +pub fn current_timestamp() -> String { + chrono::Utc::now().to_rfc3339() +} + +/// Result transport for tool calls executed as tasks. +pub struct ToolCallTaskResult { + id: String, + pub result: Result, +} + +impl ToolCallTaskResult { + pub fn new(id: impl Into, result: Result) -> Self { + Self { + id: id.into(), + result, + } + } +} + +impl OperationResultTransport for ToolCallTaskResult { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +impl Default for OperationProcessor { + fn default() -> Self { + Self::new() + } +} + +impl OperationProcessor { + pub fn new() -> Self { + let (task_result_sender, task_result_receiver) = mpsc::unbounded_channel(); + Self { + running_tasks: HashMap::new(), + completed_results: Vec::new(), + task_result_receiver: Some(task_result_receiver), + task_result_sender, + } + } + + /// Submit an operation for asynchronous execution. + #[allow(clippy::result_large_err)] + pub fn submit_operation(&mut self, message: OperationMessage) -> Result<(), Error> { + if self + .running_tasks + .contains_key(&message.descriptor.operation_id) + { + return Err(Error::TaskError(format!( + "Operation with id {} is already running", + message.descriptor.operation_id + ))); + } + self.spawn_async_task(message); + Ok(()) + } + + fn spawn_async_task(&mut self, message: OperationMessage) { + let OperationMessage { descriptor, future } = message; + let task_id = descriptor.operation_id.clone(); + let timeout_secs = descriptor.ttl.or(Some(DEFAULT_TASK_TIMEOUT_SECS)); + let sender = self.task_result_sender.clone(); + let descriptor_for_result = descriptor.clone(); + + let timed_future = async move { + if let Some(secs) = timeout_secs { + match timeout(Duration::from_secs(secs), future).await { + Ok(result) => result, + Err(_) => Err(Error::TaskError("Operation timed out".to_string())), + } + } else { + future.await + } + }; + + let handle = tokio::spawn(async move { + let result = timed_future.await; + let task_result = TaskResult { + descriptor: descriptor_for_result, + result, + }; + let _ = sender.send(task_result); + }); + let running_task = RunningTask { + task_handle: handle, + started_at: std::time::Instant::now(), + timeout: timeout_secs, + descriptor, + }; + self.running_tasks.insert(task_id, running_task); + } + + /// Collect completed results from running tasks and remove them from the running tasks map. + pub fn collect_completed_results(&mut self) -> Vec { + if let Some(receiver) = &mut self.task_result_receiver { + while let Ok(result) = receiver.try_recv() { + self.running_tasks.remove(&result.descriptor.operation_id); + self.completed_results.push(result); + } + } + std::mem::take(&mut self.completed_results) + } + + /// Check for tasks that have exceeded their timeout and handle them appropriately. + pub fn check_timeouts(&mut self) { + let now = std::time::Instant::now(); + let mut timed_out_tasks = Vec::new(); + + for (task_id, task) in &self.running_tasks { + if let Some(timeout_duration) = task.timeout { + if now.duration_since(task.started_at).as_secs() > timeout_duration { + task.task_handle.abort(); + timed_out_tasks.push(task_id.clone()); + } + } + } + + for task_id in timed_out_tasks { + if let Some(task) = self.running_tasks.remove(&task_id) { + let timeout_result = TaskResult { + descriptor: task.descriptor, + result: Err(Error::TaskError("Operation timed out".to_string())), + }; + self.completed_results.push(timeout_result); + } + } + } + + /// Get the number of running tasks. + pub fn running_task_count(&self) -> usize { + self.running_tasks.len() + } + + /// Cancel all running tasks. + pub fn cancel_all_tasks(&mut self) { + for (_, task) in self.running_tasks.drain() { + task.task_handle.abort(); + } + self.completed_results.clear(); + } + /// List running task ids. + pub fn list_running(&self) -> Vec { + self.running_tasks.keys().cloned().collect() + } + + /// Note: collectors should call collect_completed_results; this provides a snapshot of queued results. + pub fn peek_completed(&self) -> &[TaskResult] { + &self.completed_results + } + + /// Fetch the metadata for a running or recently completed task. + pub fn task_descriptor(&self, task_id: &str) -> Option<&OperationDescriptor> { + if let Some(task) = self.running_tasks.get(task_id) { + return Some(&task.descriptor); + } + self.completed_results + .iter() + .rev() + .find(|result| result.descriptor.operation_id == task_id) + .map(|result| &result.descriptor) + } + + /// Attempt to cancel a running task. + pub fn cancel_task(&mut self, task_id: &str) -> bool { + if let Some(task) = self.running_tasks.remove(task_id) { + task.task_handle.abort(); + // Insert a cancelled result so callers can observe the terminal state. + let cancel_result = TaskResult { + descriptor: task.descriptor, + result: Err(Error::TaskError("Operation cancelled".to_string())), + }; + self.completed_results.push(cancel_result); + return true; + } + false + } + + /// Retrieve a completed task result if available. + pub fn take_completed_result(&mut self, task_id: &str) -> Option { + if let Some(position) = self + .completed_results + .iter() + .position(|result| result.descriptor.operation_id == task_id) + { + Some(self.completed_results.remove(position)) + } else { + None + } + } +} diff --git a/crates/rmcp/src/transport/streamable_http_client.rs b/crates/rmcp/src/transport/streamable_http_client.rs index 4db461a4..61f5074b 100644 --- a/crates/rmcp/src/transport/streamable_http_client.rs +++ b/crates/rmcp/src/transport/streamable_http_client.rs @@ -384,6 +384,7 @@ impl Worker for StreamableHttpClientWorker { "process initialized notification response", ))?; let _ = initialized_notification.responder.send(Ok(())); + #[allow(clippy::large_enum_variant)] enum Event { ClientMessage(WorkerSendRequest), ServerMessage(ServerJsonRpcMessage), diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json index 4474dc82..5ae242a3 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema.json @@ -202,12 +202,35 @@ "name": { "description": "The name of the tool to call", "type": "string" + }, + "task": { + "type": [ + "object", + "null" + ], + "additionalProperties": true } }, "required": [ "name" ] }, + "CancelTaskMethod": { + "type": "string", + "format": "const", + "const": "tasks/cancel" + }, + "CancelTaskParam": { + "type": "object", + "properties": { + "taskId": { + "type": "string" + } + }, + "required": [ + "taskId" + ] + }, "CancelledNotificationMethod": { "type": "string", "format": "const", @@ -272,6 +295,16 @@ "null" ], "additionalProperties": true + }, + "tasks": { + "anyOf": [ + { + "$ref": "#/definitions/TasksCapability" + }, + { + "type": "null" + } + ] } } }, @@ -520,6 +553,38 @@ "name" ] }, + "GetTaskInfoMethod": { + "type": "string", + "format": "const", + "const": "tasks/get" + }, + "GetTaskInfoParam": { + "type": "object", + "properties": { + "taskId": { + "type": "string" + } + }, + "required": [ + "taskId" + ] + }, + "GetTaskResultMethod": { + "type": "string", + "format": "const", + "const": "tasks/result" + }, + "GetTaskResultParam": { + "type": "object", + "properties": { + "taskId": { + "type": "string" + } + }, + "required": [ + "taskId" + ] + }, "Icon": { "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", "type": "object", @@ -730,6 +795,18 @@ }, { "$ref": "#/definitions/CustomRequest" + }, + { + "$ref": "#/definitions/Request9" + }, + { + "$ref": "#/definitions/RequestOptionalParam5" + }, + { + "$ref": "#/definitions/Request10" + }, + { + "$ref": "#/definitions/Request11" } ], "required": [ @@ -790,6 +867,11 @@ "roots" ] }, + "ListTasksMethod": { + "type": "string", + "format": "const", + "const": "tasks/list" + }, "ListToolsRequestMethod": { "type": "string", "format": "const", @@ -1168,6 +1250,38 @@ "params" ] }, + "Request10": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/GetTaskResultMethod" + }, + "params": { + "$ref": "#/definitions/GetTaskResultParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request11": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CancelTaskMethod" + }, + "params": { + "$ref": "#/definitions/CancelTaskParam" + } + }, + "required": [ + "method", + "params" + ] + }, "Request2": { "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", "type": "object", @@ -1280,6 +1394,22 @@ "params" ] }, + "Request9": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/GetTaskInfoMethod" + }, + "params": { + "$ref": "#/definitions/GetTaskInfoParam" + } + }, + "required": [ + "method", + "params" + ] + }, "RequestNoParam": { "type": "object", "properties": { @@ -1375,6 +1505,27 @@ "method" ] }, + "RequestOptionalParam5": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListTasksMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, "ResourceContents": { "anyOf": [ { @@ -1534,6 +1685,36 @@ "uri" ] }, + "TasksCapability": { + "description": "Task capability negotiation for SEP-1686.", + "type": "object", + "properties": { + "cancel": { + "description": "Whether the receiver supports `tasks/cancel`.", + "type": [ + "boolean", + "null" + ] + }, + "list": { + "description": "Whether the receiver supports `tasks/list`.", + "type": [ + "boolean", + "null" + ] + }, + "requests": { + "description": "Map of request category (e.g. \"tools.call\") to a boolean indicating support.", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "boolean" + } + } + } + }, "UnsubscribeRequestMethod": { "type": "string", "format": "const", diff --git a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json index 4474dc82..5ae242a3 100644 --- a/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json +++ b/crates/rmcp/tests/test_message_schema/client_json_rpc_message_schema_current.json @@ -202,12 +202,35 @@ "name": { "description": "The name of the tool to call", "type": "string" + }, + "task": { + "type": [ + "object", + "null" + ], + "additionalProperties": true } }, "required": [ "name" ] }, + "CancelTaskMethod": { + "type": "string", + "format": "const", + "const": "tasks/cancel" + }, + "CancelTaskParam": { + "type": "object", + "properties": { + "taskId": { + "type": "string" + } + }, + "required": [ + "taskId" + ] + }, "CancelledNotificationMethod": { "type": "string", "format": "const", @@ -272,6 +295,16 @@ "null" ], "additionalProperties": true + }, + "tasks": { + "anyOf": [ + { + "$ref": "#/definitions/TasksCapability" + }, + { + "type": "null" + } + ] } } }, @@ -520,6 +553,38 @@ "name" ] }, + "GetTaskInfoMethod": { + "type": "string", + "format": "const", + "const": "tasks/get" + }, + "GetTaskInfoParam": { + "type": "object", + "properties": { + "taskId": { + "type": "string" + } + }, + "required": [ + "taskId" + ] + }, + "GetTaskResultMethod": { + "type": "string", + "format": "const", + "const": "tasks/result" + }, + "GetTaskResultParam": { + "type": "object", + "properties": { + "taskId": { + "type": "string" + } + }, + "required": [ + "taskId" + ] + }, "Icon": { "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", "type": "object", @@ -730,6 +795,18 @@ }, { "$ref": "#/definitions/CustomRequest" + }, + { + "$ref": "#/definitions/Request9" + }, + { + "$ref": "#/definitions/RequestOptionalParam5" + }, + { + "$ref": "#/definitions/Request10" + }, + { + "$ref": "#/definitions/Request11" } ], "required": [ @@ -790,6 +867,11 @@ "roots" ] }, + "ListTasksMethod": { + "type": "string", + "format": "const", + "const": "tasks/list" + }, "ListToolsRequestMethod": { "type": "string", "format": "const", @@ -1168,6 +1250,38 @@ "params" ] }, + "Request10": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/GetTaskResultMethod" + }, + "params": { + "$ref": "#/definitions/GetTaskResultParam" + } + }, + "required": [ + "method", + "params" + ] + }, + "Request11": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/CancelTaskMethod" + }, + "params": { + "$ref": "#/definitions/CancelTaskParam" + } + }, + "required": [ + "method", + "params" + ] + }, "Request2": { "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", "type": "object", @@ -1280,6 +1394,22 @@ "params" ] }, + "Request9": { + "description": "Represents a JSON-RPC request with method, parameters, and extensions.\n\nThis is the core structure for all MCP requests, containing:\n- `method`: The name of the method being called\n- `params`: The parameters for the method\n- `extensions`: Additional context data (similar to HTTP headers)", + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/GetTaskInfoMethod" + }, + "params": { + "$ref": "#/definitions/GetTaskInfoParam" + } + }, + "required": [ + "method", + "params" + ] + }, "RequestNoParam": { "type": "object", "properties": { @@ -1375,6 +1505,27 @@ "method" ] }, + "RequestOptionalParam5": { + "type": "object", + "properties": { + "method": { + "$ref": "#/definitions/ListTasksMethod" + }, + "params": { + "anyOf": [ + { + "$ref": "#/definitions/PaginatedRequestParam" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "method" + ] + }, "ResourceContents": { "anyOf": [ { @@ -1534,6 +1685,36 @@ "uri" ] }, + "TasksCapability": { + "description": "Task capability negotiation for SEP-1686.", + "type": "object", + "properties": { + "cancel": { + "description": "Whether the receiver supports `tasks/cancel`.", + "type": [ + "boolean", + "null" + ] + }, + "list": { + "description": "Whether the receiver supports `tasks/list`.", + "type": [ + "boolean", + "null" + ] + }, + "requests": { + "description": "Map of request category (e.g. \"tools.call\") to a boolean indicating support.", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "boolean" + } + } + } + }, "UnsubscribeRequestMethod": { "type": "string", "format": "const", diff --git a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json index 00b5d11d..6889a172 100644 --- a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json +++ b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema.json @@ -603,6 +603,18 @@ "maxTokens" ] }, + "CreateTaskResult": { + "description": "Wrapper returned by task-augmented requests (CreateTaskResult in SEP-1686).", + "type": "object", + "properties": { + "task": { + "$ref": "#/definitions/Task" + } + }, + "required": [ + "task" + ] + }, "CustomNotification": { "description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", "type": "object", @@ -810,6 +822,21 @@ "messages" ] }, + "GetTaskInfoResult": { + "type": "object", + "properties": { + "task": { + "anyOf": [ + { + "$ref": "#/definitions/Task" + }, + { + "type": "null" + } + ] + } + } + }, "Icon": { "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", "type": "object", @@ -1168,6 +1195,34 @@ "format": "const", "const": "roots/list" }, + "ListTasksResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "tasks": { + "type": "array", + "items": { + "$ref": "#/definitions/Task" + } + }, + "total": { + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + } + }, + "required": [ + "tasks" + ] + }, "ListToolsResult": { "type": "object", "properties": { @@ -2255,6 +2310,16 @@ } ] }, + "tasks": { + "anyOf": [ + { + "$ref": "#/definitions/TasksCapability" + }, + { + "type": "null" + } + ] + }, "tools": { "anyOf": [ { @@ -2304,6 +2369,18 @@ }, { "$ref": "#/definitions/CustomResult" + }, + { + "$ref": "#/definitions/CreateTaskResult" + }, + { + "$ref": "#/definitions/ListTasksResult" + }, + { + "$ref": "#/definitions/GetTaskInfoResult" + }, + { + "$ref": "#/definitions/TaskResult" } ] }, @@ -2397,6 +2474,149 @@ "format": "const", "const": "string" }, + "Task": { + "description": "Primary Task object that surfaces metadata during the task lifecycle.", + "type": "object", + "properties": { + "createdAt": { + "description": "ISO-8601 creation timestamp.", + "type": "string" + }, + "lastUpdatedAt": { + "description": "ISO-8601 timestamp for the most recent status change.", + "type": [ + "string", + "null" + ] + }, + "pollInterval": { + "description": "Suggested polling interval (milliseconds).", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + }, + "status": { + "description": "Current lifecycle status (see [`TaskStatus`]).", + "allOf": [ + { + "$ref": "#/definitions/TaskStatus" + } + ] + }, + "statusMessage": { + "description": "Optional human-readable status message for UI surfaces.", + "type": [ + "string", + "null" + ] + }, + "taskId": { + "description": "Unique task identifier generated by the receiver.", + "type": "string" + }, + "ttl": { + "description": "Retention window in milliseconds that the receiver agreed to honor.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + } + }, + "required": [ + "taskId", + "status", + "createdAt" + ] + }, + "TaskResult": { + "description": "Final result for a succeeded task (returned from `tasks/result`).", + "type": "object", + "properties": { + "contentType": { + "description": "MIME type or custom content-type identifier.", + "type": "string" + }, + "summary": { + "description": "Optional short summary for UI surfaces.", + "type": [ + "string", + "null" + ] + }, + "value": { + "description": "The actual result payload, matching the underlying request's schema." + } + }, + "required": [ + "contentType", + "value" + ] + }, + "TaskStatus": { + "description": "Canonical task lifecycle status as defined by SEP-1686.", + "oneOf": [ + { + "description": "The receiver accepted the request and is currently working on it.", + "type": "string", + "const": "working" + }, + { + "description": "The receiver requires additional input before work can continue.", + "type": "string", + "const": "input_required" + }, + { + "description": "The underlying operation completed successfully and the result is ready.", + "type": "string", + "const": "completed" + }, + { + "description": "The underlying operation failed and will not continue.", + "type": "string", + "const": "failed" + }, + { + "description": "The task was cancelled and will not continue processing.", + "type": "string", + "const": "cancelled" + } + ] + }, + "TasksCapability": { + "description": "Task capability negotiation for SEP-1686.", + "type": "object", + "properties": { + "cancel": { + "description": "Whether the receiver supports `tasks/cancel`.", + "type": [ + "boolean", + "null" + ] + }, + "list": { + "description": "Whether the receiver supports `tasks/list`.", + "type": [ + "boolean", + "null" + ] + }, + "requests": { + "description": "Map of request category (e.g. \"tools.call\") to a boolean indicating support.", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "boolean" + } + } + } + }, "Tool": { "description": "A tool that can be used by a model.", "type": "object", diff --git a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json index 00b5d11d..6889a172 100644 --- a/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json +++ b/crates/rmcp/tests/test_message_schema/server_json_rpc_message_schema_current.json @@ -603,6 +603,18 @@ "maxTokens" ] }, + "CreateTaskResult": { + "description": "Wrapper returned by task-augmented requests (CreateTaskResult in SEP-1686).", + "type": "object", + "properties": { + "task": { + "$ref": "#/definitions/Task" + } + }, + "required": [ + "task" + ] + }, "CustomNotification": { "description": "A catch-all notification either side can use to send custom messages to its peer.\n\nThis preserves the raw `method` name and `params` payload so handlers can\ndeserialize them into domain-specific types.", "type": "object", @@ -810,6 +822,21 @@ "messages" ] }, + "GetTaskInfoResult": { + "type": "object", + "properties": { + "task": { + "anyOf": [ + { + "$ref": "#/definitions/Task" + }, + { + "type": "null" + } + ] + } + } + }, "Icon": { "description": "A URL pointing to an icon resource or a base64-encoded data URI.\n\nClients that support rendering icons MUST support at least the following MIME types:\n- image/png - PNG images (safe, universal compatibility)\n- image/jpeg (and image/jpg) - JPEG images (safe, universal compatibility)\n\nClients that support rendering icons SHOULD also support:\n- image/svg+xml - SVG images (scalable but requires security precautions)\n- image/webp - WebP images (modern, efficient format)", "type": "object", @@ -1168,6 +1195,34 @@ "format": "const", "const": "roots/list" }, + "ListTasksResult": { + "type": "object", + "properties": { + "nextCursor": { + "type": [ + "string", + "null" + ] + }, + "tasks": { + "type": "array", + "items": { + "$ref": "#/definitions/Task" + } + }, + "total": { + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + } + }, + "required": [ + "tasks" + ] + }, "ListToolsResult": { "type": "object", "properties": { @@ -2255,6 +2310,16 @@ } ] }, + "tasks": { + "anyOf": [ + { + "$ref": "#/definitions/TasksCapability" + }, + { + "type": "null" + } + ] + }, "tools": { "anyOf": [ { @@ -2304,6 +2369,18 @@ }, { "$ref": "#/definitions/CustomResult" + }, + { + "$ref": "#/definitions/CreateTaskResult" + }, + { + "$ref": "#/definitions/ListTasksResult" + }, + { + "$ref": "#/definitions/GetTaskInfoResult" + }, + { + "$ref": "#/definitions/TaskResult" } ] }, @@ -2397,6 +2474,149 @@ "format": "const", "const": "string" }, + "Task": { + "description": "Primary Task object that surfaces metadata during the task lifecycle.", + "type": "object", + "properties": { + "createdAt": { + "description": "ISO-8601 creation timestamp.", + "type": "string" + }, + "lastUpdatedAt": { + "description": "ISO-8601 timestamp for the most recent status change.", + "type": [ + "string", + "null" + ] + }, + "pollInterval": { + "description": "Suggested polling interval (milliseconds).", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + }, + "status": { + "description": "Current lifecycle status (see [`TaskStatus`]).", + "allOf": [ + { + "$ref": "#/definitions/TaskStatus" + } + ] + }, + "statusMessage": { + "description": "Optional human-readable status message for UI surfaces.", + "type": [ + "string", + "null" + ] + }, + "taskId": { + "description": "Unique task identifier generated by the receiver.", + "type": "string" + }, + "ttl": { + "description": "Retention window in milliseconds that the receiver agreed to honor.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + } + }, + "required": [ + "taskId", + "status", + "createdAt" + ] + }, + "TaskResult": { + "description": "Final result for a succeeded task (returned from `tasks/result`).", + "type": "object", + "properties": { + "contentType": { + "description": "MIME type or custom content-type identifier.", + "type": "string" + }, + "summary": { + "description": "Optional short summary for UI surfaces.", + "type": [ + "string", + "null" + ] + }, + "value": { + "description": "The actual result payload, matching the underlying request's schema." + } + }, + "required": [ + "contentType", + "value" + ] + }, + "TaskStatus": { + "description": "Canonical task lifecycle status as defined by SEP-1686.", + "oneOf": [ + { + "description": "The receiver accepted the request and is currently working on it.", + "type": "string", + "const": "working" + }, + { + "description": "The receiver requires additional input before work can continue.", + "type": "string", + "const": "input_required" + }, + { + "description": "The underlying operation completed successfully and the result is ready.", + "type": "string", + "const": "completed" + }, + { + "description": "The underlying operation failed and will not continue.", + "type": "string", + "const": "failed" + }, + { + "description": "The task was cancelled and will not continue processing.", + "type": "string", + "const": "cancelled" + } + ] + }, + "TasksCapability": { + "description": "Task capability negotiation for SEP-1686.", + "type": "object", + "properties": { + "cancel": { + "description": "Whether the receiver supports `tasks/cancel`.", + "type": [ + "boolean", + "null" + ] + }, + "list": { + "description": "Whether the receiver supports `tasks/list`.", + "type": [ + "boolean", + "null" + ] + }, + "requests": { + "description": "Map of request category (e.g. \"tools.call\") to a boolean indicating support.", + "type": [ + "object", + "null" + ], + "additionalProperties": { + "type": "boolean" + } + } + } + }, "Tool": { "description": "A tool that can be used by a model.", "type": "object", diff --git a/crates/rmcp/tests/test_progress_subscriber.rs b/crates/rmcp/tests/test_progress_subscriber.rs index 531b1692..b5d185ab 100644 --- a/crates/rmcp/tests/test_progress_subscriber.rs +++ b/crates/rmcp/tests/test_progress_subscriber.rs @@ -110,6 +110,7 @@ async fn test_progress_subscriber() -> anyhow::Result<()> { ClientRequest::CallToolRequest(Request::new(CallToolRequestParam { name: "some_progress".into(), arguments: None, + task: None, })), PeerRequestOptions::no_options(), ) diff --git a/crates/rmcp/tests/test_task.rs b/crates/rmcp/tests/test_task.rs new file mode 100644 index 00000000..31fc9a9b --- /dev/null +++ b/crates/rmcp/tests/test_task.rs @@ -0,0 +1,77 @@ +use std::{any::Any, time::Duration}; + +use rmcp::task_manager::{ + OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport, +}; + +struct DummyTransport { + id: String, + value: u32, +} + +impl OperationResultTransport for DummyTransport { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +#[tokio::test] +async fn executes_enqueued_future() { + let mut processor = OperationProcessor::new(); + let descriptor = OperationDescriptor::new("op1", "dummy"); + let future = Box::pin(async { + tokio::time::sleep(Duration::from_millis(10)).await; + Ok(Box::new(DummyTransport { + id: "op1".to_string(), + value: 42, + }) as Box) + }); + + processor + .submit_operation(OperationMessage::new(descriptor, future)) + .expect("submit operation"); + + tokio::time::sleep(Duration::from_millis(30)).await; + let results = processor.collect_completed_results(); + assert_eq!(results.len(), 1); + let payload = results[0] + .result + .as_ref() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(payload.value, 42); +} + +#[tokio::test] +async fn rejects_duplicate_operation_ids() { + let mut processor = OperationProcessor::new(); + let descriptor = OperationDescriptor::new("dup", "dummy"); + let future = Box::pin(async { + Ok(Box::new(DummyTransport { + id: "dup".to_string(), + value: 1, + }) as Box) + }); + processor + .submit_operation(OperationMessage::new(descriptor, future)) + .expect("first submit"); + + let descriptor_dup = OperationDescriptor::new("dup", "dummy"); + let future_dup = Box::pin(async { + Ok(Box::new(DummyTransport { + id: "dup".to_string(), + value: 2, + }) as Box) + }); + + let err = processor + .submit_operation(OperationMessage::new(descriptor_dup, future_dup)) + .expect_err("duplicate should fail"); + assert!(format!("{err}").contains("already running")); +} diff --git a/crates/rmcp/tests/test_tool_macros.rs b/crates/rmcp/tests/test_tool_macros.rs index db5242b3..763c4f43 100644 --- a/crates/rmcp/tests/test_tool_macros.rs +++ b/crates/rmcp/tests/test_tool_macros.rs @@ -320,6 +320,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), + task: None, }) .await?; @@ -348,6 +349,7 @@ async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { .unwrap() .clone(), ), + task: None, }) .await?; diff --git a/examples/clients/src/collection.rs b/examples/clients/src/collection.rs index 67969ae4..c714da54 100644 --- a/examples/clients/src/collection.rs +++ b/examples/clients/src/collection.rs @@ -49,6 +49,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), + task: None, }) .await?; } diff --git a/examples/clients/src/everything_stdio.rs b/examples/clients/src/everything_stdio.rs index 107adc07..f1cbcae5 100644 --- a/examples/clients/src/everything_stdio.rs +++ b/examples/clients/src/everything_stdio.rs @@ -40,6 +40,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "echo".into(), arguments: Some(object!({ "message": "hi from rmcp" })), + task: None, }) .await?; tracing::info!("Tool result for echo: {tool_result:#?}"); @@ -49,6 +50,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "longRunningOperation".into(), arguments: Some(object!({ "duration": 3, "steps": 1 })), + task: None, }) .await?; tracing::info!("Tool result for longRunningOperation: {tool_result:#?}"); diff --git a/examples/clients/src/git_stdio.rs b/examples/clients/src/git_stdio.rs index d1298b36..7b516f38 100644 --- a/examples/clients/src/git_stdio.rs +++ b/examples/clients/src/git_stdio.rs @@ -42,6 +42,7 @@ async fn main() -> Result<(), RmcpError> { .call_tool(CallToolRequestParam { name: "git_status".into(), arguments: serde_json::json!({ "repo_path": "." }).as_object().cloned(), + task: None, }) .await?; tracing::info!("Tool result: {tool_result:#?}"); diff --git a/examples/clients/src/progress_client.rs b/examples/clients/src/progress_client.rs index ddf18b2f..c795ce22 100644 --- a/examples/clients/src/progress_client.rs +++ b/examples/clients/src/progress_client.rs @@ -184,6 +184,7 @@ async fn test_stdio_transport(records: u32) -> Result<()> { .call_tool(CallToolRequestParam { name: "stream_processor".into(), arguments: None, + task: None, }) .await?; @@ -238,6 +239,7 @@ async fn test_http_transport(http_url: &str, records: u32) -> Result<()> { .call_tool(CallToolRequestParam { name: "stream_processor".into(), arguments: None, + task: None, }) .await?; diff --git a/examples/clients/src/sampling_stdio.rs b/examples/clients/src/sampling_stdio.rs index 8f5aba22..b30a3c26 100644 --- a/examples/clients/src/sampling_stdio.rs +++ b/examples/clients/src/sampling_stdio.rs @@ -106,6 +106,7 @@ async fn main() -> Result<()> { arguments: Some(object!({ "question": "Hello world" })), + task: None, }) .await { diff --git a/examples/clients/src/streamable_http.rs b/examples/clients/src/streamable_http.rs index 2f1f1598..cd4b73c4 100644 --- a/examples/clients/src/streamable_http.rs +++ b/examples/clients/src/streamable_http.rs @@ -44,6 +44,7 @@ async fn main() -> Result<()> { .call_tool(CallToolRequestParam { name: "increment".into(), arguments: serde_json::json!({}).as_object().cloned(), + task: None, }) .await?; tracing::info!("Tool result: {tool_result:#?}"); diff --git a/examples/rig-integration/src/mcp_adaptor.rs b/examples/rig-integration/src/mcp_adaptor.rs index 483c6e02..286e58d5 100644 --- a/examples/rig-integration/src/mcp_adaptor.rs +++ b/examples/rig-integration/src/mcp_adaptor.rs @@ -47,6 +47,7 @@ impl RigTool for McpToolAdaptor { name: self.tool.name.clone(), arguments: serde_json::from_str(&args) .map_err(rig::tool::ToolError::JsonError)?, + task: None, }) .await .inspect(|result| tracing::info!(?result)) diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs index dc2472bb..ac271cba 100644 --- a/examples/servers/src/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,6 +1,7 @@ #![allow(dead_code)] -use std::sync::Arc; +use std::{any::Any, sync::Arc}; +use chrono::Utc; use rmcp::{ ErrorData as McpError, RoleServer, ServerHandler, handler::server::{ @@ -10,10 +11,30 @@ use rmcp::{ model::*, prompt, prompt_handler, prompt_router, schemars, service::RequestContext, + task_handler, + task_manager::{ + OperationDescriptor, OperationMessage, OperationProcessor, OperationResultTransport, + }, tool, tool_handler, tool_router, }; use serde_json::json; use tokio::sync::Mutex; +use tracing::info; + +struct ToolCallOperationResult { + id: String, + result: Result, +} + +impl OperationResultTransport for ToolCallOperationResult { + fn operation_id(&self) -> &String { + &self.id + } + + fn as_any(&self) -> &dyn Any { + self + } +} #[derive(Debug, serde::Deserialize, schemars::JsonSchema)] pub struct StructRequest { @@ -41,6 +62,7 @@ pub struct Counter { counter: Arc>, tool_router: ToolRouter, prompt_router: PromptRouter, + processor: Arc>, } #[tool_router] @@ -51,6 +73,7 @@ impl Counter { counter: Arc::new(Mutex::new(0)), tool_router: Self::tool_router(), prompt_router: Self::prompt_router(), + processor: Arc::new(Mutex::new(OperationProcessor::new())), } } @@ -84,6 +107,14 @@ impl Counter { )])) } + #[tool(description = "Long running task example")] + async fn long_task(&self) -> Result { + tokio::time::sleep(std::time::Duration::from_secs(10)).await; + Ok(CallToolResult::success(vec![Content::text( + "Long task completed", + )])) + } + #[tool(description = "Say hello to the client")] fn say_hello(&self) -> Result { Ok(CallToolResult::success(vec![Content::text("hello")])) @@ -166,6 +197,7 @@ impl Counter { #[tool_handler(meta = Meta(rmcp::object!({"tool_meta_key": "tool_meta_value"})))] #[prompt_handler(meta = Meta(rmcp::object!({"router_meta_key": "router_meta_value"})))] +#[task_handler] impl ServerHandler for Counter { fn get_info(&self) -> ServerInfo { ServerInfo { @@ -250,8 +282,16 @@ impl ServerHandler for Counter { #[cfg(test)] mod tests { + use rmcp::{ClientHandler, ServiceExt}; + use tokio::time::Duration; + use super::*; + #[derive(Default, Clone)] + struct TestClient; + + impl ClientHandler for TestClient {} + #[tokio::test] async fn test_prompt_attributes_generated() { // Verify that the prompt macros generate the expected attributes @@ -289,34 +329,56 @@ mod tests { } #[tokio::test] - async fn test_example_prompt_execution() { + async fn test_client_enqueues_long_task() -> anyhow::Result<()> { let counter = Counter::new(); - let context = rmcp::handler::server::prompt::PromptContext::new( - &counter, - "example_prompt".to_string(), - Some({ - let mut map = serde_json::Map::new(); - map.insert( - "message".to_string(), - serde_json::Value::String("Test message".to_string()), - ); - map - }), - RequestContext { - meta: Default::default(), - ct: tokio_util::sync::CancellationToken::new(), - id: rmcp::model::NumberOrString::String("test-1".to_string()), - peer: Default::default(), - extensions: Default::default(), - }, + let processor = counter.processor.clone(); + let client = TestClient::default(); + + let (server_transport, client_transport) = tokio::io::duplex(4096); + let server_handle = tokio::spawn(async move { + let service = counter.serve(server_transport).await?; + service.waiting().await?; + anyhow::Ok(()) + }); + + let client_service = client.serve(client_transport).await?; + let mut task_meta = serde_json::Map::new(); + task_meta.insert( + "source".into(), + serde_json::Value::String("integration-test".into()), ); - - let router = Counter::prompt_router(); - let result = router.get_prompt(context).await; - assert!(result.is_ok()); - - let prompt_result = result.unwrap(); - assert_eq!(prompt_result.messages.len(), 1); - assert_eq!(prompt_result.messages[0].role, PromptMessageRole::User); + let params = CallToolRequestParam { + name: "long_task".into(), + arguments: None, + task: Some(task_meta), + }; + let response = client_service + .send_request(ClientRequest::CallToolRequest(Request::new(params.clone()))) + .await?; + + let ServerResult::CreateTaskResult(info) = response else { + panic!("expected task creation result, got {response:?}"); + }; + let task = info.task; + + assert_eq!(task.status, TaskStatus::Working); + // task list should show the task + let tasks = client_service + .send_request(ClientRequest::ListTasksRequest( + RequestOptionalParam::default(), + )) + .await + .unwrap(); + let ServerResult::ListTasksResult(listed) = tasks else { + panic!("expected list tasks result, got {tasks:?}"); + }; + assert_eq!(listed.tasks[0].task_id, task.task_id); + tokio::time::sleep(Duration::from_millis(50)).await; + let running = processor.lock().await.running_task_count(); + assert_eq!(running, 1); + + client_service.cancel().await?; + let _ = server_handle.await; + Ok(()) } } diff --git a/examples/simple-chat-client/src/tool.rs b/examples/simple-chat-client/src/tool.rs index 771b4e9e..174f4274 100644 --- a/examples/simple-chat-client/src/tool.rs +++ b/examples/simple-chat-client/src/tool.rs @@ -62,6 +62,7 @@ impl Tool for McpToolAdapter { .call_tool(CallToolRequestParam { name: self.tool.name.clone(), arguments, + task: None, }) .await?; diff --git a/examples/transport/src/named-pipe.rs b/examples/transport/src/named-pipe.rs index c472fad6..b070d02b 100644 --- a/examples/transport/src/named-pipe.rs +++ b/examples/transport/src/named-pipe.rs @@ -54,6 +54,7 @@ async fn main() -> anyhow::Result<()> { "a": 10, "b": 20 })), + task: None, }) .await?; diff --git a/examples/transport/src/unix_socket.rs b/examples/transport/src/unix_socket.rs index feeb2b87..0d91dfee 100644 --- a/examples/transport/src/unix_socket.rs +++ b/examples/transport/src/unix_socket.rs @@ -52,6 +52,7 @@ async fn main() -> anyhow::Result<()> { "a": 10, "b": 20 })), + task: None, }) .await?;