diff --git a/Cargo.lock b/Cargo.lock index 6f72f84..42a990e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -147,6 +147,7 @@ dependencies = [ "clap", "configfs-tsm", "dcap-qvl", + "futures-util", "hex", "http", "http-body-util", @@ -171,6 +172,7 @@ dependencies = [ "time", "tokio", "tokio-rustls", + "tokio-tungstenite", "tower-http", "tracing", "tracing-subscriber", @@ -2624,6 +2626,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.9" @@ -2954,6 +2967,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.17" @@ -3167,6 +3192,23 @@ dependencies = [ "target-lexicon", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror 2.0.17", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -3235,6 +3277,12 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index bf3d930..ce79175 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,8 @@ time = "0.3.44" once_cell = "1.21.3" axum = "0.8.6" tower-http = { version = "0.6.7", features = ["fs"] } +tokio-tungstenite = { version = "0.28.0", optional = true } +futures-util = { version = "0.3.31", optional = true } [dev-dependencies] rcgen = "0.14.5" @@ -54,5 +56,10 @@ tempfile = "3.23.0" tdx-quote = { version = "0.0.5", features = ["mock"] } [features] -default = ["azure"] +default = ["azure", "ws"] + +# Adds support for Microsoft Azure attestation generation and verification azure = ["tss-esapi", "az-tdx-vtpm"] + +# Adds websocket support +ws = ["tokio-tungstenite", "futures-util"] diff --git a/src/attested_tls.rs b/src/attested_tls.rs index c366c8b..f9ff343 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -500,3 +500,46 @@ fn server_name_from_host( ServerName::try_from(host_part.to_string()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::{generate_certificate_chain, generate_tls_config}; + use tokio::net::TcpListener; + + #[tokio::test] + async fn server_attestation() { + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let server_addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + let (tcp_stream, _) = listener.accept().await.unwrap(); + let (_stream, _measurements, _attestation_type) = + server.handle_connection(tcp_stream).await.unwrap(); + }); + + let client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let (_stream, _measurements, _attestation_type) = + client.connect_tcp(&server_addr.to_string()).await.unwrap(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 327f6ea..77d4754 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,9 @@ pub mod attested_tls; pub mod file_server; pub mod health_check; +#[cfg(feature = "azure")] +pub mod websockets; + pub use attestation::AttestationGenerator; use bytes::Bytes; diff --git a/src/websockets.rs b/src/websockets.rs new file mode 100644 index 0000000..3c9e4dc --- /dev/null +++ b/src/websockets.rs @@ -0,0 +1,177 @@ +use std::{net::SocketAddr, sync::Arc}; +use thiserror::Error; +use tokio::net::{TcpListener, ToSocketAddrs}; +use tokio_tungstenite::{tungstenite::protocol::WebSocketConfig, WebSocketStream}; + +use crate::{ + attestation::{measurements::MultiMeasurements, AttestationType}, + attested_tls::{AttestedTlsClient, AttestedTlsError, AttestedTlsServer}, +}; + +/// Websocket message type re-exported for convenience +pub use tokio_tungstenite::tungstenite::protocol::Message; + +/// An attested Websocket server +pub struct AttestedWsServer { + /// The underlying attested TLS server + pub inner: AttestedTlsServer, + /// Optional websocket configuration + pub websocket_config: Option, + listener: Arc, +} + +impl AttestedWsServer { + pub async fn new( + addr: impl ToSocketAddrs, + inner: AttestedTlsServer, + websocket_config: Option, + ) -> Result { + let listener = TcpListener::bind(addr).await?; + + Ok(Self { + listener: listener.into(), + inner, + websocket_config, + }) + } + + /// Accept a Websocket connection + pub async fn accept( + &self, + ) -> Result< + ( + WebSocketStream>, + Option, + AttestationType, + ), + AttestedWsError, + > { + let (tcp_stream, _addr) = self.listener.accept().await?; + + let (stream, measurements, attestation_type) = + self.inner.handle_connection(tcp_stream).await?; + Ok(( + tokio_tungstenite::accept_async_with_config(stream, self.websocket_config).await?, + measurements, + attestation_type, + )) + } + + /// Helper to get the socket address of the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() + } +} + +/// An attested Websocket client +pub struct AttestedWsClient { + /// The underlying attested TLS client + pub inner: AttestedTlsClient, + /// Optional websocket configuration + pub websocket_config: Option, +} + +impl AttestedWsClient { + /// Make a Websocket connection + pub async fn connect( + &self, + server: &str, + ) -> Result< + ( + WebSocketStream>, + Option, + AttestationType, + ), + AttestedWsError, + > { + let (stream, measurements, attestation_type) = self.inner.connect_tcp(server).await?; + let (ws_connection, _response) = tokio_tungstenite::client_async_with_config( + format!("wss://{server}"), + stream, + self.websocket_config, + ) + .await?; + + Ok((ws_connection, measurements, attestation_type)) + } +} + +impl From for AttestedWsClient { + fn from(inner: AttestedTlsClient) -> Self { + Self { + inner, + websocket_config: None, + } + } +} + +#[derive(Error, Debug)] +pub enum AttestedWsError { + #[error("Attested TLS: {0}")] + Rustls(#[from] AttestedTlsError), + #[error("Websockets: {0}")] + Tungstenite(#[from] tokio_tungstenite::tungstenite::Error), + #[error("IO: {0}")] + Io(#[from] std::io::Error), +} + +#[cfg(test)] +mod tests { + use futures_util::{sink::SinkExt, StreamExt}; + use tokio_tungstenite::tungstenite::protocol::Message; + + use super::*; + use crate::{ + attestation::{AttestationGenerator, AttestationType, AttestationVerifier}, + test_helpers::{generate_certificate_chain, generate_tls_config}, + }; + + #[tokio::test] + async fn server_attestation_websocket() { + let (cert_chain, private_key) = generate_certificate_chain("127.0.0.1".parse().unwrap()); + let (server_config, client_config) = generate_tls_config(cert_chain.clone(), private_key); + + let server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + AttestationGenerator::new_not_dummy(AttestationType::DcapTdx).unwrap(), + AttestationVerifier::expect_none(), + ) + .await + .unwrap(); + + let ws_server = AttestedWsServer::new("127.0.0.1:0", server, None) + .await + .unwrap(); + + let server_addr = ws_server.local_addr().unwrap(); + + tokio::spawn(async move { + let (mut ws_connection, _measurements, _attestation_type) = + ws_server.accept().await.unwrap(); + + ws_connection + .send(Message::Text("foo".into())) + .await + .unwrap(); + }); + + let client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + AttestationVerifier::mock(), + None, + ) + .await + .unwrap(); + + let ws_client: AttestedWsClient = client.into(); + + let (mut ws_connection, _measurements, _attestation_type) = + ws_client.connect(&server_addr.to_string()).await.unwrap(); + + let message = ws_connection.next().await.unwrap().unwrap(); + + assert_eq!(message.to_text().unwrap(), "foo"); + } +}