From 36634448b21b7c14a33e351ac554b80c46843818 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 31 Mar 2025 10:50:15 +0800 Subject: [PATCH] fix(notification): fix wrongly error report in notification --- crates/rmcp/Cargo.toml | 9 ++- crates/rmcp/src/handler/server.rs | 2 +- crates/rmcp/src/service.rs | 20 +++-- crates/rmcp/tests/test_notification.rs | 101 +++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 11 deletions(-) create mode 100644 crates/rmcp/tests/test_notification.rs diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index ed426a0f..9ae90b05 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -95,5 +95,10 @@ path = "tests/test_with_python.rs" [[test]] name = "test_with_js" -required-features = ["server", "transport-sse-server"] -path = "tests/test_with_js.rs" \ No newline at end of file +required-features = ["server", "client", "transport-sse-server", "transport-child-process"] +path = "tests/test_with_js.rs" + +[[test]] +name = "test_notification" +required-features = ["server", "client"] +path = "tests/test_notification.rs" diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 1fa90dd8..9d1a0b35 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -103,7 +103,7 @@ impl Service for H { } #[allow(unused_variables)] -pub trait ServerHandler: Sized + Clone + Send + Sync + 'static { +pub trait ServerHandler: Sized + Send + Sync + 'static { fn ping( &self, context: RequestContext, diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index e99ef587..9ddb37e7 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -342,10 +342,12 @@ impl Peer { self.tx .send(PeerSinkMessage::Notification(notification, responder)) .await - .map_err(|_m| ServiceError::Transport(std::io::Error::other("disconnected")))?; - receiver - .await - .map_err(|_e| ServiceError::Transport(std::io::Error::other("disconnected")))? + .map_err(|_m| { + ServiceError::Transport(std::io::Error::other("disconnected: receiver dropped")) + })?; + receiver.await.map_err(|_e| { + ServiceError::Transport(std::io::Error::other("disconnected: responder dropped")) + })? } pub async fn send_request(&self, request: R::Req) -> Result { self.send_cancellable_request(request, PeerRequestOptions::no_options()) @@ -578,10 +580,12 @@ where let send_result = sink .send(Message::Notification(notification).into_json_rpc_message()) .await; - if let Err(e) = send_result { - let _ = - responder.send(Err(ServiceError::Transport(std::io::Error::other(e)))); - } + let response = if let Err(e) = send_result { + Err(ServiceError::Transport(std::io::Error::other(e))) + } else { + Ok(()) + }; + let _ = responder.send(response); if let Some(param) = cancellation_param { if let Some(responder) = local_responder_pool.remove(¶m.request_id) { tracing::info!(id = %param.request_id, reason = param.reason, "cancelled"); diff --git a/crates/rmcp/tests/test_notification.rs b/crates/rmcp/tests/test_notification.rs new file mode 100644 index 00000000..4d4c0f6e --- /dev/null +++ b/crates/rmcp/tests/test_notification.rs @@ -0,0 +1,101 @@ +use std::sync::Arc; + +use rmcp::{ + ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt, + model::{ + ResourceUpdatedNotificationParam, ServerCapabilities, ServerInfo, SubscribeRequestParam, + }, +}; +use tokio::sync::Notify; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +pub struct Server {} + +impl ServerHandler for Server { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_resources() + .enable_resources_subscribe() + .enable_resources_list_changed() + .build(), + ..Default::default() + } + } + + async fn subscribe( + &self, + request: rmcp::model::SubscribeRequestParam, + context: rmcp::service::RequestContext, + ) -> Result<(), rmcp::Error> { + let uri = request.uri; + let peer = context.peer; + + tokio::spawn(async move { + let span = tracing::info_span!("subscribe", uri = %uri); + let _enter = span.enter(); + + if let Err(e) = peer + .notify_resource_updated(ResourceUpdatedNotificationParam { uri: uri.clone() }) + .await + { + panic!("Failed to send notification: {}", e); + } + }); + + Ok(()) + } +} + +pub struct Client { + receive_signal: Arc, + peer: Option>, +} + +impl ClientHandler for Client { + async fn on_resource_updated(&self, params: rmcp::model::ResourceUpdatedNotificationParam) { + let uri = params.uri; + tracing::info!("Resource updated: {}", uri); + self.receive_signal.notify_one(); + } + + fn set_peer(&mut self, peer: Peer) { + self.peer.replace(peer); + } + + fn get_peer(&self) -> Option> { + self.peer.clone() + } +} + +#[tokio::test] +async fn test_server_notification() -> anyhow::Result<()> { + let _ = tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .try_init(); + let (server_transport, client_transport) = tokio::io::duplex(4096); + tokio::spawn(async move { + let server = Server {}.serve(server_transport).await?; + server.waiting().await?; + anyhow::Ok(()) + }); + let receive_signal = Arc::new(Notify::new()); + let client = Client { + peer: Default::default(), + receive_signal: receive_signal.clone(), + } + .serve(client_transport) + .await?; + client + .subscribe(SubscribeRequestParam { + uri: "test://test-resource".to_owned(), + }) + .await?; + receive_signal.notified().await; + client.cancel().await?; + Ok(()) +}