Skip to content

Commit 5f06c0e

Browse files
committed
feat(task): add task support (SEP-1686)
Signed-off-by: jokemanfire <hu.dingyang@zte.com.cn>
1 parent f20ed20 commit 5f06c0e

File tree

26 files changed

+1050
-45
lines changed

26 files changed

+1050
-45
lines changed

crates/rmcp-macros/src/lib.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ mod common;
55
mod prompt;
66
mod prompt_handler;
77
mod prompt_router;
8+
mod task_handler;
89
mod tool;
910
mod tool_handler;
1011
mod tool_router;
@@ -263,3 +264,17 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> TokenStream {
263264
.unwrap_or_else(|err| err.to_compile_error())
264265
.into()
265266
}
267+
268+
/// # task_handler
269+
///
270+
/// Generates basic task-handling methods (`enqueue_task` and `list_tasks`) for a server handler
271+
/// using a shared [`OperationProcessor`]. The default processor expression assumes a
272+
/// `self.processor` field holding an `Arc<Mutex<OperationProcessor>>`, but it can be customized
273+
/// via `#[task_handler(processor = ...)]`. Because the macro captures `self` inside spawned
274+
/// futures, the handler type must implement [`Clone`].
275+
#[proc_macro_attribute]
276+
pub fn task_handler(attr: TokenStream, input: TokenStream) -> TokenStream {
277+
task_handler::task_handler(attr.into(), input.into())
278+
.unwrap_or_else(|err| err.to_compile_error())
279+
.into()
280+
}
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
use darling::{FromMeta, ast::NestedMeta};
2+
use proc_macro2::TokenStream;
3+
use quote::{ToTokens, quote};
4+
use syn::{Expr, ImplItem, ItemImpl};
5+
6+
#[derive(FromMeta)]
7+
#[darling(default)]
8+
struct TaskHandlerAttribute {
9+
processor: Expr,
10+
}
11+
12+
impl Default for TaskHandlerAttribute {
13+
fn default() -> Self {
14+
Self {
15+
processor: syn::parse2(quote! { self.processor }).expect("default processor expr"),
16+
}
17+
}
18+
}
19+
20+
pub fn task_handler(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
21+
let attr_args = NestedMeta::parse_meta_list(attr)?;
22+
let TaskHandlerAttribute { processor } = TaskHandlerAttribute::from_list(&attr_args)?;
23+
let mut item_impl = syn::parse2::<ItemImpl>(input.clone())?;
24+
25+
let has_method = |name: &str, item_impl: &ItemImpl| -> bool {
26+
item_impl.items.iter().any(|item| match item {
27+
ImplItem::Fn(func) => func.sig.ident == name,
28+
_ => false,
29+
})
30+
};
31+
32+
if !has_method("list_tasks", &item_impl) {
33+
let list_fn = quote! {
34+
async fn list_tasks(
35+
&self,
36+
_request: Option<rmcp::model::PaginatedRequestParam>,
37+
_: rmcp::service::RequestContext<rmcp::RoleServer>,
38+
) -> Result<rmcp::model::ListTasksResult, McpError> {
39+
let running_ids = (#processor).lock().await.list_running();
40+
let total = running_ids.len() as u64;
41+
let tasks = running_ids
42+
.into_iter()
43+
.map(|task_id| {
44+
let timestamp = rmcp::task_manager::current_timestamp();
45+
rmcp::model::Task {
46+
task_id,
47+
status: rmcp::model::TaskStatus::Working,
48+
status_message: None,
49+
created_at: timestamp.clone(),
50+
last_updated_at: Some(timestamp),
51+
ttl: None,
52+
poll_interval: None,
53+
}
54+
})
55+
.collect::<Vec<_>>();
56+
57+
Ok(rmcp::model::ListTasksResult {
58+
tasks,
59+
next_cursor: None,
60+
total: Some(total),
61+
})
62+
}
63+
};
64+
item_impl.items.push(syn::parse2::<ImplItem>(list_fn)?);
65+
}
66+
67+
if !has_method("enqueue_task", &item_impl) {
68+
let enqueue_fn = quote! {
69+
async fn enqueue_task(
70+
&self,
71+
request: rmcp::model::CallToolRequestParam,
72+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
73+
) -> Result<rmcp::model::CreateTaskResult, McpError> {
74+
use rmcp::task_manager::{
75+
current_timestamp, OperationDescriptor, OperationMessage, OperationResultTransport,
76+
ToolCallTaskResult,
77+
};
78+
let task_id = context.id.to_string();
79+
let operation_name = request.name.to_string();
80+
let future_request = request.clone();
81+
let future_context = context.clone();
82+
let server = self.clone();
83+
84+
let descriptor = OperationDescriptor::new(task_id.clone(), operation_name)
85+
.with_context(context)
86+
.with_client_request(rmcp::model::ClientRequest::CallToolRequest(
87+
rmcp::model::Request::new(request),
88+
));
89+
90+
let task_result_id = task_id.clone();
91+
let future = Box::pin(async move {
92+
let result = server.call_tool(future_request, future_context).await;
93+
Ok(
94+
Box::new(ToolCallTaskResult::new(task_result_id, result))
95+
as Box<dyn OperationResultTransport>,
96+
)
97+
});
98+
99+
(#processor)
100+
.lock()
101+
.await
102+
.submit_operation(OperationMessage::new(descriptor, future))
103+
.map_err(|err| rmcp::ErrorData::internal_error(
104+
format!("failed to enqueue task: {err}"),
105+
None,
106+
))?;
107+
108+
let timestamp = current_timestamp();
109+
let task = rmcp::model::Task {
110+
task_id,
111+
status: rmcp::model::TaskStatus::Working,
112+
status_message: Some("Task accepted".to_string()),
113+
created_at: timestamp.clone(),
114+
last_updated_at: Some(timestamp),
115+
ttl: None,
116+
poll_interval: None,
117+
};
118+
119+
Ok(rmcp::model::CreateTaskResult { task })
120+
}
121+
};
122+
item_impl.items.push(syn::parse2::<ImplItem>(enqueue_fn)?);
123+
}
124+
125+
if !has_method("get_task_info", &item_impl) {
126+
let get_info_fn = quote! {
127+
async fn get_task_info(
128+
&self,
129+
request: rmcp::model::GetTaskInfoParam,
130+
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
131+
) -> Result<rmcp::model::GetTaskInfoResult, McpError> {
132+
use rmcp::task_manager::current_timestamp;
133+
let task_id = request.task_id.clone();
134+
let mut processor = (#processor).lock().await;
135+
processor.collect_completed_results();
136+
137+
// Check completed results first
138+
let completed = processor.peek_completed().iter().rev().find(|r| r.descriptor.operation_id == task_id);
139+
if let Some(completed_result) = completed {
140+
// Determine Finished vs Failed
141+
let status = match &completed_result.result {
142+
Ok(boxed) => {
143+
if let Some(tool) = boxed.as_any().downcast_ref::<rmcp::task_manager::ToolCallTaskResult>() {
144+
match &tool.result {
145+
Ok(_) => rmcp::model::TaskStatus::Completed,
146+
Err(_) => rmcp::model::TaskStatus::Failed,
147+
}
148+
} else {
149+
rmcp::model::TaskStatus::Completed
150+
}
151+
}
152+
Err(_) => rmcp::model::TaskStatus::Failed,
153+
};
154+
let timestamp = current_timestamp();
155+
let task = rmcp::model::Task {
156+
task_id,
157+
status,
158+
status_message: None,
159+
created_at: timestamp.clone(),
160+
last_updated_at: Some(timestamp),
161+
ttl: completed_result.descriptor.ttl,
162+
poll_interval: None,
163+
};
164+
return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) });
165+
}
166+
167+
// If not completed, check running
168+
let running = processor.list_running();
169+
if running.into_iter().any(|id| id == task_id) {
170+
let timestamp = current_timestamp();
171+
let task = rmcp::model::Task {
172+
task_id,
173+
status: rmcp::model::TaskStatus::Working,
174+
status_message: None,
175+
created_at: timestamp.clone(),
176+
last_updated_at: Some(timestamp),
177+
ttl: None,
178+
poll_interval: None,
179+
};
180+
return Ok(rmcp::model::GetTaskInfoResult { task: Some(task) });
181+
}
182+
183+
Ok(rmcp::model::GetTaskInfoResult { task: None })
184+
}
185+
};
186+
item_impl.items.push(syn::parse2::<ImplItem>(get_info_fn)?);
187+
}
188+
189+
if !has_method("get_task_result", &item_impl) {
190+
let get_result_fn = quote! {
191+
async fn get_task_result(
192+
&self,
193+
request: rmcp::model::GetTaskResultParam,
194+
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
195+
) -> Result<rmcp::model::TaskResult, McpError> {
196+
use std::time::Duration;
197+
let task_id = request.task_id.clone();
198+
199+
loop {
200+
// Scope the lock so we can await outside if needed
201+
{
202+
let mut processor = (#processor).lock().await;
203+
processor.collect_completed_results();
204+
205+
if let Some(task_result) = processor.take_completed_result(&task_id) {
206+
match task_result.result {
207+
Ok(boxed) => {
208+
if let Some(tool) = boxed.as_any().downcast_ref::<rmcp::task_manager::ToolCallTaskResult>() {
209+
match &tool.result {
210+
Ok(call_tool) => {
211+
let value = ::serde_json::to_value(call_tool).unwrap_or(::serde_json::Value::Null);
212+
return Ok(rmcp::model::TaskResult {
213+
content_type: "application/json".to_string(),
214+
value,
215+
summary: None,
216+
});
217+
}
218+
Err(err) => return Err(McpError::internal_error(
219+
format!("task failed: {}", err),
220+
None,
221+
)),
222+
}
223+
} else {
224+
return Err(McpError::internal_error("unsupported task result transport", None));
225+
}
226+
}
227+
Err(err) => return Err(McpError::internal_error(
228+
format!("task execution error: {}", err),
229+
None,
230+
)),
231+
}
232+
}
233+
234+
// Not completed yet: if not running, return not found
235+
let running = processor.list_running();
236+
if !running.iter().any(|id| id == &task_id) {
237+
return Err(McpError::resource_not_found(format!("task not found: {}", task_id), None));
238+
}
239+
}
240+
241+
tokio::time::sleep(Duration::from_millis(100)).await;
242+
}
243+
}
244+
};
245+
item_impl
246+
.items
247+
.push(syn::parse2::<ImplItem>(get_result_fn)?);
248+
}
249+
250+
if !has_method("cancel_task", &item_impl) {
251+
let cancel_fn = quote! {
252+
async fn cancel_task(
253+
&self,
254+
request: rmcp::model::CancelTaskParam,
255+
_context: rmcp::service::RequestContext<rmcp::RoleServer>,
256+
) -> Result<(), McpError> {
257+
let task_id = request.task_id;
258+
let mut processor = (#processor).lock().await;
259+
processor.collect_completed_results();
260+
261+
if processor.cancel_task(&task_id) {
262+
return Ok(());
263+
}
264+
265+
// If already completed, signal it's not cancellable
266+
let exists_completed = processor.peek_completed().iter().any(|r| r.descriptor.operation_id == task_id);
267+
if exists_completed {
268+
return Err(McpError::invalid_request(format!("task already completed: {}", task_id), None));
269+
}
270+
271+
Err(McpError::resource_not_found(format!("task not found: {}", task_id), None))
272+
}
273+
};
274+
item_impl.items.push(syn::parse2::<ImplItem>(cancel_fn)?);
275+
}
276+
277+
Ok(item_impl.into_token_stream())
278+
}

crates/rmcp/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,8 @@ path = "tests/test_progress_subscriber.rs"
197197
name = "test_elicitation"
198198
required-features = ["elicitation", "client", "server"]
199199
path = "tests/test_elicitation.rs"
200+
201+
[[test]]
202+
name = "test_task"
203+
required-features = ["server", "client", "macros"]
204+
path = "tests/test_task.rs"

crates/rmcp/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ pub enum RmcpError {
4141
error: Box<dyn std::error::Error + Send + Sync>,
4242
},
4343
// and cancellation shouldn't be an error?
44+
45+
// TODO: add more error variants as needed
46+
#[error("Task error: {0}")]
47+
TaskError(String),
4448
}
4549

4650
impl RmcpError {

0 commit comments

Comments
 (0)