From 3255d1c66abea416b7dc0c494f85fe65b517a9e7 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 12 Jan 2026 11:51:27 +0100 Subject: [PATCH 1/5] Simple attested websocket server/client --- Cargo.lock | 48 +++++++++++++++++ Cargo.toml | 2 + src/attested_tls.rs | 40 ++++++++++++++ src/lib.rs | 1 + src/websockets.rs | 126 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 217 insertions(+) create mode 100644 src/websockets.rs diff --git a/Cargo.lock b/Cargo.lock index 6f72f84..42a990e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -147,6 +147,7 @@ dependencies = [ "clap", "configfs-tsm", "dcap-qvl", + "futures-util", "hex", "http", "http-body-util", @@ -171,6 +172,7 @@ dependencies = [ "time", "tokio", "tokio-rustls", + "tokio-tungstenite", "tower-http", "tracing", "tracing-subscriber", @@ -2624,6 +2626,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -2954,6 +2967,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.17" @@ -3167,6 +3192,23 @@ dependencies = [ "target-lexicon", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror 2.0.17", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -3235,6 +3277,12 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index bf3d930..7b50d01 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,8 @@ time = "0.3.44" once_cell = "1.21.3" axum = "0.8.6" tower-http = { version = "0.6.7", features = ["fs"] } +tokio-tungstenite = "0.28.0" +futures-util = "0.3.31" [dev-dependencies] rcgen = "0.14.5" diff --git a/src/attested_tls.rs b/src/attested_tls.rs index 6460371..a23ef3a 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -501,3 +501,43 @@ fn server_name_from_host( ServerName::try_from(host_part.to_string()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::{generate_certificate_chain, generate_tls_config}; + + #[tokio::test] + async fn server_attestation() { + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let server_addr = server.local_addr().unwrap(); + + tokio::spawn(async move { + let (_stream, _measurements, _attestation_type) = server.accept().await.unwrap(); + }); + + let client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let (_stream, _measurements, _attestation_type) = + client.connect(&server_addr.to_string()).await.unwrap(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 36a738c..f7c67a4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ pub mod attested_get; pub mod attested_tls; pub mod file_server; pub mod health_check; +pub mod websockets; pub use attestation::AttestationGenerator; diff --git a/src/websockets.rs b/src/websockets.rs new file mode 100644 index 0000000..57c9fe2 --- /dev/null +++ b/src/websockets.rs @@ -0,0 +1,126 @@ +use thiserror::Error; +use tokio_tungstenite::WebSocketStream; + +use crate::{ + attestation::{measurements::MultiMeasurements, AttestationType}, + attested_tls::{AttestedTlsClient, AttestedTlsError, AttestedTlsServer}, +}; + +/// Websocket message type re-exported for convenience +pub use tokio_tungstenite::tungstenite::protocol::Message; + +// TODO allow setting ws config +pub struct AttestedWsServer { + inner: AttestedTlsServer, +} + +impl AttestedWsServer { + pub async fn accept( + &self, + ) -> Result< + ( + WebSocketStream>, + Option, + AttestationType, + ), + AttestedWsError, + > { + let (stream, measurements, attestation_type) = self.inner.accept().await?; + Ok(( + tokio_tungstenite::accept_async(stream).await.unwrap(), + measurements, + attestation_type, + )) + } +} + +pub struct AttestedWsClient { + inner: AttestedTlsClient, +} + +impl AttestedWsClient { + pub async fn connect( + &self, + server: &str, + ) -> Result< + ( + WebSocketStream>, + Option, + AttestationType, + ), + AttestedWsError, + > { + let (stream, measurements, attestation_type) = self.inner.connect(server).await?; + let (ws_connection, _response) = + tokio_tungstenite::client_async(format!("wss://{server}"), stream) + .await + .unwrap(); + Ok((ws_connection, measurements, attestation_type)) + } +} + +#[derive(Error, Debug)] +pub enum AttestedWsError { + #[error("Attested TLS: {0}")] + Rustls(#[from] AttestedTlsError), +} + +#[cfg(test)] +mod tests { + use futures_util::{sink::SinkExt, StreamExt}; + use tokio_tungstenite::tungstenite::protocol::Message; + + use super::*; + use crate::{ + attestation::{AttestationGenerator, AttestationType, AttestationVerifier}, + test_helpers::{generate_certificate_chain, generate_tls_config}, + }; + + #[tokio::test] + async fn server_attestation_websocket() { + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + "127.0.0.1:0", + AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let server_addr = server.local_addr().unwrap(); + + let ws_server = AttestedWsServer { inner: server }; + + tokio::spawn(async move { + let (mut ws_connection, _measurements, _attestation_type) = + ws_server.accept().await.unwrap(); + + ws_connection + .send(Message::Text("foo".into())) + .await + .unwrap(); + }); + + let client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let ws_client = AttestedWsClient { inner: client }; + + let (mut ws_connection, _measurements, _attestation_type) = + ws_client.connect(&server_addr.to_string()).await.unwrap(); + + let message = ws_connection.next().await.unwrap().unwrap(); + + assert_eq!(message.to_text().unwrap(), "foo"); + } +} From 607effa45a46746181934c8231d604a5e871e6f2 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 13 Jan 2026 13:34:31 +0100 Subject: [PATCH 2/5] Error handling --- src/websockets.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/websockets.rs b/src/websockets.rs index 57c9fe2..21b51b3 100644 --- a/src/websockets.rs +++ b/src/websockets.rs @@ -27,7 +27,7 @@ impl AttestedWsServer { > { let (stream, measurements, attestation_type) = self.inner.accept().await?; Ok(( - tokio_tungstenite::accept_async(stream).await.unwrap(), + tokio_tungstenite::accept_async(stream).await?, measurements, attestation_type, )) @@ -52,9 +52,8 @@ impl AttestedWsClient { > { let (stream, measurements, attestation_type) = self.inner.connect(server).await?; let (ws_connection, _response) = - tokio_tungstenite::client_async(format!("wss://{server}"), stream) - .await - .unwrap(); + tokio_tungstenite::client_async(format!("wss://{server}"), stream).await?; + Ok((ws_connection, measurements, attestation_type)) } } @@ -63,6 +62,8 @@ impl AttestedWsClient { pub enum AttestedWsError { #[error("Attested TLS: {0}")] Rustls(#[from] AttestedTlsError), + #[error("Websockets: {0}")] + Tungstenite(#[from] tokio_tungstenite::tungstenite::Error), } #[cfg(test)] From a9c7f4a67d6aebb38e855257861ab4a9a32ca4d3 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 13 Jan 2026 13:38:28 +0100 Subject: [PATCH 3/5] Gate behind feature flag --- Cargo.toml | 11 ++++++++--- src/lib.rs | 2 ++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7b50d01..ce79175 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,8 +47,8 @@ time = "0.3.44" once_cell = "1.21.3" axum = "0.8.6" tower-http = { version = "0.6.7", features = ["fs"] } -tokio-tungstenite = "0.28.0" -futures-util = "0.3.31" +tokio-tungstenite = { version = "0.28.0", optional = true } +futures-util = { version = "0.3.31", optional = true } [dev-dependencies] rcgen = "0.14.5" @@ -56,5 +56,10 @@ tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } [features] -default = ["azure"] +default = ["azure", "ws"] + +# Adds support for Microsoft Azure attestation generation and verification azure = ["tss-esapi", "az-tdx-vtpm"] + +# Adds websocket support +ws = ["tokio-tungstenite", "futures-util"] diff --git a/src/lib.rs b/src/lib.rs index f7c67a4..08d7301 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,8 @@ pub mod attested_get; pub mod attested_tls; pub mod file_server; pub mod health_check; + +#[cfg(feature = "azure")] pub mod websockets; pub use attestation::AttestationGenerator; From 9d2fabf5d9102de42cbefa2d70e2846f6e18f4cb Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 13 Jan 2026 13:51:59 +0100 Subject: [PATCH 4/5] Tidy, allow config to be passed in --- src/websockets.rs | 49 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/websockets.rs b/src/websockets.rs index 21b51b3..26a376d 100644 --- a/src/websockets.rs +++ b/src/websockets.rs @@ -1,5 +1,5 @@ use thiserror::Error; -use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream}; use crate::{ attestation::{measurements::MultiMeasurements, AttestationType}, @@ -9,12 +9,16 @@ use crate::{ /// Websocket message type re-exported for convenience pub use tokio_tungstenite::tungstenite::protocol::Message; -// TODO allow setting ws config +/// An attested Websocket server pub struct AttestedWsServer { - inner: AttestedTlsServer, + /// The underlying attested TLS server + pub inner: AttestedTlsServer, + /// Optional websocket configuration + pub websocket_config: Option, } impl AttestedWsServer { + /// Accept a Websocket connection pub async fn accept( &self, ) -> Result< @@ -27,18 +31,32 @@ impl AttestedWsServer { > { let (stream, measurements, attestation_type) = self.inner.accept().await?; Ok(( - tokio_tungstenite::accept_async(stream).await?, + tokio_tungstenite::accept_async_with_config(stream, self.websocket_config).await?, measurements, attestation_type, )) } } +impl From for AttestedWsServer { + fn from(inner: AttestedTlsServer) -> Self { + Self { + inner, + websocket_config: None, + } + } +} + +/// An attested Websocket client pub struct AttestedWsClient { - inner: AttestedTlsClient, + /// The underlying attested TLS client + pub inner: AttestedTlsClient, + /// Optional websocket configuration + pub websocket_config: Option, } impl AttestedWsClient { + /// Make a Websocket connection pub async fn connect( &self, server: &str, @@ -51,13 +69,26 @@ impl AttestedWsClient { AttestedWsError, > { let (stream, measurements, attestation_type) = self.inner.connect(server).await?; - let (ws_connection, _response) = - tokio_tungstenite::client_async(format!("wss://{server}"), stream).await?; + let (ws_connection, _response) = tokio_tungstenite::client_async_with_config( + format!("wss://{server}"), + stream, + self.websocket_config, + ) + .await?; Ok((ws_connection, measurements, attestation_type)) } } +impl From for AttestedWsClient { + fn from(inner: AttestedTlsClient) -> Self { + Self { + inner, + websocket_config: None, + } + } +} + #[derive(Error, Debug)] pub enum AttestedWsError { #[error("Attested TLS: {0}")] @@ -94,7 +125,7 @@ mod tests { let server_addr = server.local_addr().unwrap(); - let ws_server = AttestedWsServer { inner: server }; + let ws_server: AttestedWsServer = server.into(); tokio::spawn(async move { let (mut ws_connection, _measurements, _attestation_type) = @@ -115,7 +146,7 @@ mod tests { .await .unwrap(); - let ws_client = AttestedWsClient { inner: client }; + let ws_client: AttestedWsClient = client.into(); let (mut ws_connection, _measurements, _attestation_type) = ws_client.connect(&server_addr.to_string()).await.unwrap(); From 54be39f5b2d8d006f92542f5ce9509eb2f4df3e9 Mon Sep 17 00:00:00 2001 From: peg Date: Tue, 13 Jan 2026 14:11:46 +0100 Subject: [PATCH 5/5] Update following merging main --- src/attested_tls.rs | 11 +++++++---- src/websockets.rs | 43 +++++++++++++++++++++++++++++++------------ 2 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/attested_tls.rs b/src/attested_tls.rs index d7de196..f9ff343 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -505,6 +505,7 @@ fn server_name_from_host( mod tests { use super::*; use crate::test_helpers::{generate_certificate_chain, generate_tls_config}; + use tokio::net::TcpListener; #[tokio::test] async fn server_attestation() { @@ -514,17 +515,19 @@ mod tests { let server = AttestedTlsServer::new_with_tls_config( cert_chain, server_config, - "127.0.0.1:0", AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), AttestationVerifier::expect_none(), ) .await .unwrap(); - let server_addr = server.local_addr().unwrap(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); tokio::spawn(async move { - let (_stream, _measurements, _attestation_type) = server.accept().await.unwrap(); + let (tcp_stream, _) = listener.accept().await.unwrap(); + let (_stream, _measurements, _attestation_type) = + server.handle_connection(tcp_stream).await.unwrap(); }); let client = AttestedTlsClient::new_with_tls_config( @@ -537,6 +540,6 @@ mod tests { .unwrap(); let (_stream, _measurements, _attestation_type) = - client.connect(&server_addr.to_string()).await.unwrap(); + client.connect_tcp(&server_addr.to_string()).await.unwrap(); } } diff --git a/src/websockets.rs b/src/websockets.rs index 26a376d..3c9e4dc 100644 --- a/src/websockets.rs +++ b/src/websockets.rs @@ -1,4 +1,6 @@ +use std::{net::SocketAddr, sync::Arc}; use thiserror::Error; +use tokio::net::{TcpListener, ToSocketAddrs}; use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream}; use crate::{ @@ -15,9 +17,24 @@ pub struct AttestedWsServer { pub inner: AttestedTlsServer, /// Optional websocket configuration pub websocket_config: Option, + listener: Arc, } impl AttestedWsServer { + pub async fn new( + addr: impl ToSocketAddrs, + inner: AttestedTlsServer, + websocket_config: Option, + ) -> Result { + let listener = TcpListener::bind(addr).await?; + + Ok(Self { + listener: listener.into(), + inner, + websocket_config, + }) + } + /// Accept a Websocket connection pub async fn accept( &self, @@ -29,21 +46,20 @@ impl AttestedWsServer { ), AttestedWsError, > { - let (stream, measurements, attestation_type) = self.inner.accept().await?; + let (tcp_stream, _addr) = self.listener.accept().await?; + + let (stream, measurements, attestation_type) = + self.inner.handle_connection(tcp_stream).await?; Ok(( tokio_tungstenite::accept_async_with_config(stream, self.websocket_config).await?, measurements, attestation_type, )) } -} -impl From for AttestedWsServer { - fn from(inner: AttestedTlsServer) -> Self { - Self { - inner, - websocket_config: None, - } + /// Helper to get the socket address of the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() } } @@ -68,7 +84,7 @@ impl AttestedWsClient { ), AttestedWsError, > { - let (stream, measurements, attestation_type) = self.inner.connect(server).await?; + let (stream, measurements, attestation_type) = self.inner.connect_tcp(server).await?; let (ws_connection, _response) = tokio_tungstenite::client_async_with_config( format!("wss://{server}"), stream, @@ -95,6 +111,8 @@ pub enum AttestedWsError { Rustls(#[from] AttestedTlsError), #[error("Websockets: {0}")] Tungstenite(#[from] tokio_tungstenite::tungstenite::Error), + #[error("IO: {0}")] + Io(#[from] std::io::Error), } #[cfg(test)] @@ -116,16 +134,17 @@ mod tests { let server = AttestedTlsServer::new_with_tls_config( cert_chain, server_config, - "127.0.0.1:0", AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), AttestationVerifier::expect_none(), ) .await .unwrap(); - let server_addr = server.local_addr().unwrap(); + let ws_server = AttestedWsServer::new("127.0.0.1:0", server, None) + .await + .unwrap(); - let ws_server: AttestedWsServer = server.into(); + let server_addr = ws_server.local_addr().unwrap(); tokio::spawn(async move { let (mut ws_connection, _measurements, _attestation_type) =