diff --git a/crates/rmcp/src/transport/sse.rs b/crates/rmcp/src/transport/sse.rs index 6854e370..201f138b 100644 --- a/crates/rmcp/src/transport/sse.rs +++ b/crates/rmcp/src/transport/sse.rs @@ -14,13 +14,13 @@ const MIME_TYPE: &str = "text/event-stream"; const HEADER_LAST_EVENT_ID: &str = "Last-Event-ID"; #[derive(Error, Debug)] -pub enum SseTransportError { +pub enum SseTransportError { #[error("SSE error: {0}")] Sse(#[from] SseError), #[error("IO error: {0}")] Io(#[from] std::io::Error), - #[error("Reqwest error: {0}")] - Reqwest(#[from] reqwest::Error), + #[error("Transport error: {0}")] + Transport(E), #[error("unexpected end of stream")] UnexpectedEndOfStream, #[error("Url error: {0}")] @@ -29,14 +29,14 @@ pub enum SseTransportError { UnexpectedContentType(Option), } -enum SseTransportState { +type SseStreamFuture = + BoxFuture<'static, Result>, SseTransportError>>; + +enum SseTransportState { Connected(BoxStream<'static, Result>), Retrying { times: usize, - fut: BoxFuture< - 'static, - Result>, SseTransportError>, - >, + fut: SseStreamFuture, }, Fatal { reason: String, @@ -60,67 +60,159 @@ impl Default for SseTransportRetryConfig { } } -/// # Transport for client sse -/// -/// Call [`SseTransport::start`] to create a new transport from url. -/// -/// Call [`SseTransport::start_with_client`] to create a new transport with a customized reqwest client. -pub struct SseTransport { +impl From for SseTransportError { + fn from(e: reqwest::Error) -> Self { + SseTransportError::Transport(e) + } +} + +pub trait SseClient: Clone + Send + Sync + 'static { + fn connect(&self, last_event_id: Option) -> SseStreamFuture; + + fn post( + &self, + endpoint: &str, + message: ClientJsonRpcMessage, + ) -> BoxFuture<'static, Result<(), SseTransportError>>; +} + +pub struct RetryConfig { + pub max_times: Option, + pub min_duration: Duration, +} + +#[derive(Clone)] +pub struct ReqwestSseClient { http_client: HttpClient, - state: SseTransportState, - post_url: Arc, - sse_url: Arc, - last_event_id: Option, - recommended_retry_duration_ms: Option, - #[allow(clippy::type_complexity)] - request_queue: VecDeque>>, - pub retry_config: SseTransportRetryConfig, + sse_url: Url, } +impl ReqwestSseClient { + pub fn new(url: U) -> Result> + where + U: IntoUrl, + { + let url = url.into_url()?; + Ok(Self { + http_client: HttpClient::default(), + sse_url: url, + }) + } -impl SseTransport { - pub async fn start_with_timeout(url: U, timeout: Duration) -> Result + pub async fn new_with_timeout( + url: U, + timeout: Duration, + ) -> Result> where U: IntoUrl, { let mut client = HttpClient::builder(); client = client.timeout(timeout); let client = client.build()?; - Self::start_with_client(url, client).await + let url = url.into_url()?; + Ok(Self { + http_client: client, + sse_url: url, + }) } - pub async fn start(url: U) -> Result + pub async fn new_with_client( + url: U, + client: HttpClient, + ) -> Result> where U: IntoUrl, { - Self::start_with_client(url, HttpClient::default()).await + let url = url.into_url()?; + Ok(Self { + http_client: client, + sse_url: url, + }) + } +} + +impl SseClient for ReqwestSseClient { + fn connect(&self, last_event_id: Option) -> SseStreamFuture { + let client = self.http_client.clone(); + let sse_url = self.sse_url.as_ref().to_string(); + let last_event_id = last_event_id.clone(); + let fut = async move { + let mut request_builder = client.get(&sse_url).header(ACCEPT, MIME_TYPE); + if let Some(last_event_id) = last_event_id { + request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); + } + let response = request_builder.send().await?; + let response = response.error_for_status()?; + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(ct) => { + if !ct.as_bytes().starts_with(MIME_TYPE.as_bytes()) { + return Err(SseTransportError::UnexpectedContentType(Some(ct.clone()))); + } + } + None => { + return Err(SseTransportError::UnexpectedContentType(None)); + } + } + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(event_stream) + }; + fut.boxed() } - /// Start with a reqwest client, this would be helpful when you want to customize the client behavior like default headers or tls stuff. - pub async fn start_with_client(url: U, client: HttpClient) -> Result + fn post( + &self, + session_id: &str, + message: ClientJsonRpcMessage, + ) -> BoxFuture<'static, Result<(), SseTransportError>> { + let client = self.http_client.clone(); + let sse_url = self.sse_url.clone(); + let session_id = session_id.to_string(); + Box::pin(async move { + let uri = sse_url.join(&session_id).map_err(SseTransportError::from)?; + let request_builder = client.post(uri.as_ref()).json(&message); + request_builder + .send() + .await + .and_then(|resp| resp.error_for_status()) + .map_err(SseTransportError::from) + .map(drop) + }) + } +} + +/// # Transport for client sse +/// +/// Call [`SseTransport::start`] to create a new transport from url. +/// +/// Call [`SseTransport::start_with_client`] to create a new transport with a customized reqwest client. +pub struct SseTransport, E: std::error::Error + Send + Sync + 'static> { + client: Arc, + state: SseTransportState, + last_event_id: Option, + recommended_retry_duration_ms: Option, + session_id: String, + #[allow(clippy::type_complexity)] + request_queue: VecDeque>>>, + pub retry_config: SseTransportRetryConfig, +} + +impl SseTransport { + pub async fn start( + url: U, + ) -> Result, SseTransportError> where U: IntoUrl, { - let url = url.into_url()?; - let response = client - .get(url.clone()) - .header(ACCEPT, MIME_TYPE) - .send() - .await?; - let response = response.error_for_status()?; - match response.headers().get(reqwest::header::CONTENT_TYPE) { - Some(ct) => { - if !ct.as_bytes().starts_with(MIME_TYPE.as_bytes()) { - return Err(SseTransportError::UnexpectedContentType(Some(ct.clone()))); - } - } - None => { - return Err(SseTransportError::UnexpectedContentType(None)); - } - } - let mut event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + let client = ReqwestSseClient::new(url)?; + SseTransport::start_with_client(client).await + } +} + +impl, E: std::error::Error + Send + Sync + 'static> SseTransport { + pub async fn start_with_client(client: C) -> Result> { + let mut event_stream = client.connect(None).await?; let mut last_event_id = None; let mut retry = None; - let post_url = loop { + let session_id = loop { let next_event = event_stream .next() .await @@ -135,62 +227,34 @@ impl SseTransport { break next_event.data.unwrap_or_default(); } }; - tracing::info!("will post event on {post_url}"); - let post_url = url.join(&post_url)?; Ok(SseTransport { - http_client: client, + client: Arc::new(client), state: SseTransportState::Connected(Box::pin(event_stream)), - post_url: Arc::from(post_url), last_event_id, recommended_retry_duration_ms: retry, - sse_url: Arc::from(url), + session_id, request_queue: Default::default(), retry_config: Default::default(), }) } - fn retry_connection( - &self, - ) -> BoxFuture<'static, Result>, SseTransportError>> - { + fn retry_connection(&self) -> SseStreamFuture { let retry_duration = { let recommended_retry_duration = self .recommended_retry_duration_ms .map(Duration::from_millis); let config_retry_duration = self.retry_config.min_duration; recommended_retry_duration - .map(|d| d.max(config_retry_duration)) + .map(|d: Duration| d.max(config_retry_duration)) .unwrap_or(config_retry_duration) }; - let client = self.http_client.clone(); - let sse_url = self.sse_url.as_ref().clone(); + std::thread::sleep(retry_duration); let last_event_id = self.last_event_id.clone(); - let fut = async move { - tokio::time::sleep(retry_duration).await; - let mut request_builder = client.get(sse_url).header(ACCEPT, MIME_TYPE); - if let Some(last_event_id) = last_event_id { - request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); - } - let response = request_builder.send().await?; - let response = response.error_for_status()?; - match response.headers().get(reqwest::header::CONTENT_TYPE) { - Some(ct) => { - if ct.as_bytes() != MIME_TYPE.as_bytes() { - return Err(SseTransportError::UnexpectedContentType(Some(ct.clone()))); - } - } - None => { - return Err(SseTransportError::UnexpectedContentType(None)); - } - } - let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); - Ok(event_stream) - }; - fut.boxed() + self.client.connect(last_event_id) } } -impl Stream for SseTransport { +impl, E: std::error::Error + Send + Sync + 'static> Stream for SseTransport { type Item = ServerJsonRpcMessage; fn poll_next( @@ -198,17 +262,16 @@ impl Stream for SseTransport { cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { let retry_config = self.retry_config; - let state = &mut self.state; - match state { + match &mut self.state { SseTransportState::Connected(event_stream) => { let event = std::task::ready!(event_stream.poll_next_unpin(cx)); match event { Some(Ok(event)) => { if let Some(retry) = event.retry { - self.recommended_retry_duration_ms = Some(retry); + self.as_mut().recommended_retry_duration_ms = Some(retry); } if let Some(id) = event.id { - self.last_event_id = Some(id); + self.as_mut().last_event_id = Some(id); } if let Some(data) = event.data { match serde_json::from_str(&data) { @@ -226,7 +289,7 @@ impl Stream for SseTransport { Some(Err(e)) => { tracing::error!(error = %e, "sse event stream encounter an error"); let fut = self.retry_connection(); - self.state = SseTransportState::Retrying { times: 1, fut }; + self.as_mut().state = SseTransportState::Retrying { times: 1, fut }; self.poll_next(cx) } None => std::task::Poll::Ready(None), @@ -236,14 +299,14 @@ impl Stream for SseTransport { let retry_result = std::task::ready!(fut.poll_unpin(cx)); match retry_result { Ok(stream) => { - self.state = SseTransportState::Connected(stream); + self.as_mut().state = SseTransportState::Connected(stream); self.poll_next(cx) } Err(e) => { tracing::warn!(error = %e, "retrying failed"); if let Some(max_retry_times) = retry_config.max_times { if *times >= max_retry_times { - self.state = SseTransportState::Fatal { + self.as_mut().state = SseTransportState::Fatal { reason: format!("retrying failed after {} times: {}", times, e), }; return self.poll_next(cx); @@ -251,7 +314,7 @@ impl Stream for SseTransport { } let times = *times + 1; let fut = self.retry_connection(); - self.state = SseTransportState::Retrying { times, fut }; + self.as_mut().state = SseTransportState::Retrying { times, fut }; self.poll_next(cx) } } @@ -264,17 +327,20 @@ impl Stream for SseTransport { } } -impl Sink for SseTransport { - type Error = SseTransportError; +impl, E: std::error::Error + Send + Sync + 'static> Sink + for SseTransport +{ + type Error = SseTransportError; fn poll_ready( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { const QUEUE_SIZE: usize = 16; - if self.request_queue.len() >= QUEUE_SIZE { + if self.as_mut().request_queue.len() >= QUEUE_SIZE { std::task::ready!( - self.request_queue + self.as_mut() + .request_queue .front_mut() .expect("queue is not empty") .poll_unpin(cx) @@ -288,21 +354,16 @@ impl Sink for SseTransport { mut self: std::pin::Pin<&mut Self>, item: ClientJsonRpcMessage, ) -> Result<(), Self::Error> { - let client = self.http_client.clone(); - let uri = self.post_url.clone(); - let (tx, rx) = tokio::sync::oneshot::channel(); - let request_builder = client.post(uri.as_ref().clone()).json(&item); + let client = self.client.clone(); + let session_id = self.session_id.clone(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let session_id = session_id.clone(); tokio::spawn(async move { - let result = request_builder - .send() - .await - .and_then(|resp| resp.error_for_status()) - .map_err(SseTransportError::from) - .map(drop); + let result = { client.post(&session_id, item).await }; let _ = tx.send(result); }); - self.as_mut().request_queue.push_back(rx); + self.request_queue.push_back(rx); Ok(()) }