From c3f486507c7051c5a4466b3d9e4371dee0240651 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 12 Jan 2026 09:43:12 +0100 Subject: [PATCH 1/2] Make transport agnostic --- src/attested_tls.rs | 91 ++++++++++++++++++++++----------------------- src/lib.rs | 22 ++++++----- 2 files changed, 57 insertions(+), 56 deletions(-) diff --git a/src/attested_tls.rs b/src/attested_tls.rs index 6460371..c366c8b 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -12,9 +12,8 @@ use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; use x509_parser::parse_x509_certificate; use std::num::TryFromIntError; -use std::{net::SocketAddr, sync::Arc}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use std::sync::Arc; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; use tokio_rustls::rustls::RootCertStore; use tokio_rustls::{ @@ -38,16 +37,14 @@ pub struct TlsCertAndKey { pub key: PrivateKeyDer<'static>, } -/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address +/// A TLS server which makes an attestation exchange following the TLS handshake #[derive(Clone)] pub struct AttestedTlsServer { - /// The underlying TCP listener - pub listener: Arc, /// Quote generation type to use (including none) attestation_generator: AttestationGenerator, /// Verifier for remote attestation (including none) attestation_verifier: AttestationVerifier, - /// The certificate chain + /// The TLS certificate chain cert_chain: Vec>, /// For accepting TLS connections acceptor: TlsAcceptor, @@ -56,7 +53,6 @@ pub struct AttestedTlsServer { impl AttestedTlsServer { pub async fn new( cert_and_key: TlsCertAndKey, - local: impl ToSocketAddrs, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, client_auth: bool, @@ -83,7 +79,6 @@ impl AttestedTlsServer { Self::new_with_tls_config( cert_and_key.cert_chain, server_config.into(), - local, attestation_generator, attestation_verifier, ) @@ -96,15 +91,12 @@ impl AttestedTlsServer { pub(crate) async fn new_with_tls_config( cert_chain: Vec>, server_config: Arc, - local: impl ToSocketAddrs, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { let acceptor = tokio_rustls::TlsAcceptor::from(server_config); - let listener = TcpListener::bind(local).await?; Ok(Self { - listener: listener.into(), attestation_generator, attestation_verifier, acceptor, @@ -112,39 +104,23 @@ impl AttestedTlsServer { }) } - /// Accept an incoming connection and do an attestation exchange - pub async fn accept( - &self, - ) -> Result< - ( - tokio_rustls::server::TlsStream, - Option, - AttestationType, - ), - AttestedTlsError, - > { - let (inbound, _client_addr) = self.listener.accept().await?; - - self.handle_connection(inbound).await - } - - /// Helper to get the socket address of the underlying TCP listener - pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() - } - /// Handle an incoming connection from a proxy-client - pub async fn handle_connection( + /// + /// This is transport agnostic and will work with any asynchronous stream + pub async fn handle_connection( &self, - inbound: TcpStream, + inbound: IO, ) -> Result< ( - tokio_rustls::server::TlsStream, + tokio_rustls::server::TlsStream, Option, AttestationType, ), AttestedTlsError, - > { + > + where + IO: AsyncRead + AsyncWrite + Unpin, + { tracing::debug!("attested-tls-server accepted connection"); // Do TLS handshake @@ -296,23 +272,29 @@ impl AttestedTlsClient { }) } - /// Connect to an attested-tls-server, do TLS handshake and attestation exchange - pub async fn connect( + /// Given a connection to an attested TLS server, do a TLS handshake and attestation exchange, and return the TLS + /// stream together with measurement details + /// + /// This is transport agnostic and will work with any asynchronous stream + pub async fn connect( &self, target: &str, + outbound: IO, ) -> Result< ( - tokio_rustls::client::TlsStream, + tokio_rustls::client::TlsStream, Option, AttestationType, ), AttestedTlsError, - > { - // Make a TCP client connection and TLS handshake - let out = TcpStream::connect(&target).await?; + > + where + IO: AsyncRead + AsyncWrite + Unpin, + { + // Make a TLS handshake with the given connection let mut tls_stream = self .connector - .connect(server_name_from_host(target)?, out) + .connect(server_name_from_host(target)?, outbound) .await?; let (_io, server_connection) = tls_stream.get_ref(); @@ -374,12 +356,29 @@ impl AttestedTlsClient { Ok((tls_stream, measurements, remote_attestation_type)) } - /// Connect to an attested TLS server, retrieve the remote TLS certificate and return it + /// Make a TCP connection, do a TLS handshake and attestation exchange, and return the TLS + /// stream together with measurement details + pub async fn connect_tcp( + &self, + target: &str, + ) -> Result< + ( + tokio_rustls::client::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { + let out = tokio::net::TcpStream::connect(&target).await?; + self.connect(target, out).await + } + + /// Connect to an attested TLS server using TCP, retrieve the remote TLS certificate and return it pub async fn get_tls_cert( &self, server_name: &str, ) -> Result>, AttestedTlsError> { - let (mut tls_stream, _, _) = self.connect(server_name).await?; + let (mut tls_stream, _, _) = self.connect_tcp(server_name).await?; let (_io, server_connection) = tls_stream.get_ref(); diff --git a/src/lib.rs b/src/lib.rs index 36a738c..327f6ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,15 +19,11 @@ use tracing::{error, warn}; #[cfg(test)] mod test_helpers; -use std::net::SocketAddr; -use std::num::TryFromIntError; -use std::time::Duration; +use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration}; use tokio::io; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; use tokio_rustls::rustls::pki_types::CertificateDer; -#[cfg(test)] -use std::sync::Arc; #[cfg(test)] use tokio_rustls::rustls::{ClientConfig, ServerConfig}; @@ -57,6 +53,8 @@ type Http2Sender = hyper::client::conn::http2::SendRequest, /// The address of the target service we are proxying to target: SocketAddr, } @@ -72,15 +70,17 @@ impl ProxyServer { ) -> Result { let attested_tls_server = AttestedTlsServer::new( cert_and_key, - local, attestation_generator, attestation_verifier, client_auth, ) .await?; + let listener = TcpListener::bind(local).await?; + Ok(Self { attested_tls_server, + listener: listener.into(), target, }) } @@ -100,14 +100,16 @@ impl ProxyServer { let attested_tls_server = AttestedTlsServer::new_with_tls_config( cert_chain, server_config, - local, attestation_generator, attestation_verifier, ) .await?; + let listener = TcpListener::bind(local).await?; + Ok(Self { attested_tls_server, + listener: listener.into(), target, }) } @@ -115,7 +117,7 @@ impl ProxyServer { /// Accept an incoming connection and handle it in a seperate task pub async fn accept(&self) -> Result<(), ProxyError> { let target = self.target; - let (inbound, _client_addr) = self.attested_tls_server.listener.accept().await?; + let (inbound, _client_addr) = self.listener.accept().await?; let attested_tls_server = self.attested_tls_server.clone(); tokio::spawn(async move { @@ -139,7 +141,7 @@ impl ProxyServer { /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.attested_tls_server.local_addr() + self.listener.local_addr() } /// Handle an incoming connection from a proxy-client @@ -469,7 +471,7 @@ impl ProxyClient { inner: &AttestedTlsClient, target: &str, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { - let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?; + let (tls_stream, measurements, remote_attestation_type) = inner.connect_tcp(target).await?; // The attestation exchange is now complete - setup an HTTP client From 1ef41d50e28fefb0db70a5f1571388f892a1dafd Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 12 Jan 2026 09:57:21 +0100 Subject: [PATCH 2/2] Force single test thread in CI --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4052992..2850c37 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,4 +36,4 @@ jobs: run: cargo clippy --workspace -- -D warnings - name: Run cargo test - run: cargo test --workspace --all-targets + run: cargo test --workspace --all-targets -- --test-threads=1