From 1aed046e51e813c576a81029dcf47885bf26c705 Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 9 Jan 2026 09:55:28 +0100 Subject: [PATCH 1/7] Add attested TLS module --- src/attested_tls.rs | 555 ++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 556 insertions(+) create mode 100644 src/attested_tls.rs diff --git a/src/attested_tls.rs b/src/attested_tls.rs new file mode 100644 index 0000000..ea26568 --- /dev/null +++ b/src/attested_tls.rs @@ -0,0 +1,555 @@ +use crate::attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationGenerator, AttestationType, +}; +use parity_scale_codec::{Decode, Encode}; +use sha2::{Digest, Sha256}; +use thiserror::Error; +use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; +use tracing::{error, warn}; +use x509_parser::parse_x509_certificate; + +use std::num::TryFromIntError; +use std::{net::SocketAddr, sync::Arc}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; +use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use tokio_rustls::rustls::RootCertStore; +use tokio_rustls::{ + rustls::{ClientConfig, ServerConfig}, + TlsAcceptor, TlsConnector, +}; + +use crate::attestation::{AttestationExchangeMessage, AttestationVerifier}; + +/// This makes it possible to add breaking protocol changes and provide backwards compatibility. +/// When adding more supported versions, note that ordering is important. ALPN will pick the first +/// protocol which both parties support - so newer supported versions should come first. +pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"]; + +/// The label used when exporting key material from a TLS session +const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; + +/// TLS Credentials +pub struct TlsCertAndKey { + /// Der-encoded TLS certificate chain + pub cert_chain: Vec>, + /// Der-encoded TLS private key + pub key: PrivateKeyDer<'static>, +} + +/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address +pub struct AttestedTlsServer { + /// The underlying TCP listener + listener: TcpListener, + /// Quote generation type to use (including none) + attestation_generator: AttestationGenerator, + /// Verifier for remote attestation (including none) + attestation_verifier: AttestationVerifier, + /// The certificate chain + cert_chain: Vec>, + /// For accepting TLS connections + acceptor: TlsAcceptor, +} + +impl AttestedTlsServer { + pub async fn new( + cert_and_key: TlsCertAndKey, + local: impl ToSocketAddrs, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + client_auth: bool, + ) -> Result { + let mut server_config = if client_auth { + let root_store = + RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; + + ServerConfig::builder() + .with_client_cert_verifier(verifier) + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? + } else { + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? + }; + + server_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + Self::new_with_tls_config( + cert_and_key.cert_chain, + server_config.into(), + local, + attestation_generator, + attestation_verifier, + ) + .await + } + + /// Start with preconfigured TLS + /// + /// This is not public as it allows dangerous configuration + async fn new_with_tls_config( + cert_chain: Vec>, + server_config: Arc, + local: impl ToSocketAddrs, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + ) -> Result { + let acceptor = tokio_rustls::TlsAcceptor::from(server_config); + let listener = TcpListener::bind(local).await?; + + Ok(Self { + listener, + attestation_generator, + attestation_verifier, + acceptor, + cert_chain, + }) + } + + /// Accept an incoming connection and handle it in a seperate task + pub async fn accept(&self) -> Result<(), AttestedTlsError> { + let (inbound, _client_addr) = self.listener.accept().await?; + + let acceptor = self.acceptor.clone(); + let cert_chain = self.cert_chain.clone(); + let attestation_generator = self.attestation_generator.clone(); + let attestation_verifier = self.attestation_verifier.clone(); + tokio::spawn(async move { + if let Err(err) = Self::handle_connection( + inbound, + acceptor, + cert_chain, + attestation_generator, + attestation_verifier, + ) + .await + { + warn!("Failed to handle connection: {err}"); + } + }); + + Ok(()) + } + + /// Helper to get the socket address of the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() + } + + /// Handle an incoming connection from a proxy-client + async fn handle_connection( + inbound: TcpStream, + acceptor: TlsAcceptor, + cert_chain: Vec>, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + ) -> Result< + ( + tokio_rustls::server::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { + tracing::debug!("attested-tls-server accepted connection"); + + // Do TLS handshake + let mut tls_stream = acceptor.accept(inbound).await?; + let (_io, connection) = tls_stream.get_ref(); + + // Ensure that we agreed a protocol + let _negotiated_protocol = connection + .alpn_protocol() + .ok_or(AttestedTlsError::AlpnFailed)?; + + // Compute an exporter unique to the session + let mut exporter = [0u8; 32]; + connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + let input_data = compute_report_input(Some(&cert_chain), exporter)?; + + // Get the TLS certficate chain of the client, if there is one + let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); + + // If we are in a CVM, generate an attestation + let attestation = attestation_generator + .generate_attestation(input_data) + .await? + .encode(); + + // Write our attestation to the channel, with length prefix + let attestation_length_prefix = length_prefix(&attestation); + tls_stream.write_all(&attestation_length_prefix).await?; + tls_stream.write_all(&attestation).await?; + + // Now read a length-prefixed attestation from the remote peer + // In the case of no client attestation this will be zero bytes + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; + let remote_attestation_type = remote_attestation_message.attestation_type; + + // If we expect an attestaion from the client, verify it and get measurements + let measurements = if attestation_verifier.has_remote_attestion() { + let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?; + + attestation_verifier + .verify_attestation(remote_attestation_message, remote_input_data) + .await? + } else { + None + }; + + Ok((tls_stream, measurements, remote_attestation_type)) + } +} + +/// A proxy client which forwards http traffic to a proxy-server +pub struct AttestedTlsClient { + /// The underlying TCP listener + listener: TcpListener, + /// The connector for making TLS connections with out configuration + connector: TlsConnector, + /// Quote generation type to use (including none) + attestation_generator: AttestationGenerator, + /// Verifier for remote attestation (including none) + attestation_verifier: AttestationVerifier, + /// The certificate chain for client auth + cert_chain: Option>>, +} + +impl std::fmt::Debug for AttestedTlsClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AttestedTlsClient") + .field("listener", &self.listener) + .finish() + } +} + +impl AttestedTlsClient { + /// Start with optional TLS client auth + pub async fn new( + cert_and_key: Option, + address: impl ToSocketAddrs, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + remote_certificate: Option>, + ) -> Result { + // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots + let root_store = match remote_certificate { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate)?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + // Setup TLS client configuration, with or without client auth + let mut client_config = if let Some(ref cert_and_key) = cert_and_key { + ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert( + cert_and_key.cert_chain.clone(), + cert_and_key.key.clone_key(), + )? + } else { + ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + Self::new_with_tls_config( + client_config.into(), + address, + attestation_generator, + attestation_verifier, + cert_and_key.map(|c| c.cert_chain), + ) + .await + } + + /// Create a new proxy client with given TLS configuration + /// + /// This is private as it allows dangerous configuration but is used in tests + async fn new_with_tls_config( + client_config: Arc, + local: impl ToSocketAddrs, + attestation_generator: AttestationGenerator, + attestation_verifier: AttestationVerifier, + cert_chain: Option>>, + ) -> Result { + // Setup TCP server and TLS client + let listener = TcpListener::bind(local).await?; + let connector = TlsConnector::from(client_config.clone()); + + Ok(Self { + listener, + connector, + attestation_generator, + attestation_verifier, + cert_chain, + }) + } + + /// Helper to return the local socket address from the underlying TCP listener + pub fn local_addr(&self) -> std::io::Result { + self.listener.local_addr() + } + + /// Connect to the attested-tls-server, do TLS handshake and remote attestation + pub async fn connect( + &self, + target: String, + ) -> Result< + ( + tokio_rustls::client::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { + // Make a TCP client connection and TLS handshake + let out = TcpStream::connect(&target).await?; + let mut tls_stream = self + .connector + .connect(server_name_from_host(&target)?, out) + .await?; + + let (_io, server_connection) = tls_stream.get_ref(); + + // Ensure that we agreed a protocol + let _negotiated_protocol = server_connection + .alpn_protocol() + .ok_or(AttestedTlsError::AlpnFailed)?; + + // Compute an exporter unique to the channel + let mut exporter = [0u8; 32]; + server_connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + // Get the TLS certificate chain of the server + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(AttestedTlsError::NoCertificate)? + .to_owned(); + + let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; + + // Read a length prefixed attestation from the proxy-server + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; + let remote_attestation_type = remote_attestation_message.attestation_type; + + // Verify the remote attestation against our accepted measurements + let measurements = self + .attestation_verifier + .verify_attestation(remote_attestation_message, remote_input_data) + .await?; + + // If we are in a CVM, provide an attestation + let attestation = if self.attestation_generator.attestation_type != AttestationType::None { + let local_input_data = compute_report_input(self.cert_chain.as_deref(), exporter)?; + self.attestation_generator + .generate_attestation(local_input_data) + .await? + .encode() + } else { + AttestationExchangeMessage::without_attestation().encode() + }; + + // Send our attestation (or zero bytes) prefixed with length + let attestation_length_prefix = length_prefix(&attestation); + tls_stream.write_all(&attestation_length_prefix).await?; + tls_stream.write_all(&attestation).await?; + + Ok((tls_stream, measurements, remote_attestation_type)) + } +} + +/// Just get the attested remote certificate, with no client authentication +pub async fn get_tls_cert( + server_name: String, + attestation_verifier: AttestationVerifier, + remote_certificate: Option>, +) -> Result>, AttestedTlsError> { + tracing::debug!("Getting remote TLS cert"); + // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots + let root_store = match remote_certificate { + Some(remote_certificate) => { + let mut root_store = RootCertStore::empty(); + root_store.add(remote_certificate)?; + root_store + } + None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), + }; + + let mut client_config = ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS + .into_iter() + .map(|p| p.to_vec()) + .collect(); + + get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await +} + +async fn get_tls_cert_with_config( + server_name: String, + attestation_verifier: AttestationVerifier, + client_config: Arc, +) -> Result>, AttestedTlsError> { + let connector = TlsConnector::from(client_config); + + let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?; + let mut tls_stream = connector + .connect(server_name_from_host(&server_name)?, out) + .await?; + + let (_io, server_connection) = tls_stream.get_ref(); + + let mut exporter = [0u8; 32]; + server_connection.export_keying_material( + &mut exporter, + EXPORTER_LABEL, + None, // context + )?; + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(AttestedTlsError::NoCertificate)? + .to_owned(); + + let mut length_bytes = [0; 4]; + tls_stream.read_exact(&mut length_bytes).await?; + let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + + let mut buf = vec![0; length]; + tls_stream.read_exact(&mut buf).await?; + + let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; + + let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; + + let _measurements = attestation_verifier + .verify_attestation(remote_attestation_message, remote_input_data) + .await?; + + tls_stream.shutdown().await?; + + Ok(remote_cert_chain) +} + +/// Given a certificate chain and an exporter (session key material), build the quote input value +/// SHA256(pki) || exporter +pub fn compute_report_input( + cert_chain: Option<&[CertificateDer<'_>]>, + exporter: [u8; 32], +) -> Result<[u8; 64], AttestationError> { + let mut quote_input = [0u8; 64]; + if let Some(cert_chain) = cert_chain { + let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; + quote_input[..32].copy_from_slice(&pki_hash); + } + quote_input[32..].copy_from_slice(&exporter); + Ok(quote_input) +} + +/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate +fn get_pki_hash_from_certificate_chain( + cert_chain: &[CertificateDer<'_>], +) -> Result<[u8; 32], AttestationError> { + let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; + let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; + let public_key = &cert.tbs_certificate.subject_pki; + let key_bytes = public_key.subject_public_key.as_ref(); + + let mut hasher = Sha256::new(); + hasher.update(key_bytes); + Ok(hasher.finalize().into()) +} + +/// An error when running an attested TLS client or server +#[derive(Error, Debug)] +pub enum AttestedTlsError { + #[error("Failed to get server ceritifcate")] + NoCertificate, + #[error("TLS: {0}")] + Rustls(#[from] tokio_rustls::rustls::Error), + #[error("Verifier builder: {0}")] + VerifierBuilder(#[from] VerifierBuilderError), + #[error("IO: {0}")] + Io(#[from] std::io::Error), + #[error("Attestation: {0}")] + Attestation(#[from] AttestationError), + #[error("Integer conversion: {0}")] + IntConversion(#[from] TryFromIntError), + #[error("Bad host name: {0}")] + BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), + #[error("HTTP: {0}")] + Hyper(#[from] hyper::Error), + #[error("JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("Serialization: {0}")] + Serialization(#[from] parity_scale_codec::Error), + #[error("Protocol negotiation failed - remote peer does not support this protocol")] + AlpnFailed, +} + +/// Given a byte array, encode its length as a 4 byte big endian u32 +fn length_prefix(input: &[u8]) -> [u8; 4] { + let len = input.len() as u32; + len.to_be_bytes() +} + +/// If no port was provided, default to 443 +fn host_to_host_with_port(host: &str) -> String { + if host.contains(':') { + host.to_string() + } else { + format!("{host}:443") + } +} + +/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part +fn server_name_from_host( + host: &str, +) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { + // If host contains ':', try to split off the port. + let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); + + // If the host is an IPv6 literal in brackets like "[::1]:443", + // remove the brackets for SNI (SNI allows bare IPv6 too). + let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); + + ServerName::try_from(host_part.to_string()) +} diff --git a/src/lib.rs b/src/lib.rs index e319fe9..46ddbf8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod attestation; pub mod attested_get; +pub mod attested_tls; pub mod file_server; pub mod health_check; From a4690d3ce5088baa8b30da38b60f9039b0d6350b Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 9 Jan 2026 10:23:17 +0100 Subject: [PATCH 2/7] Use attested-tls-server from refactored module --- src/attested_tls.rs | 42 +++++++------ src/lib.rs | 144 ++++++++++---------------------------------- src/main.rs | 3 +- 3 files changed, 56 insertions(+), 133 deletions(-) diff --git a/src/attested_tls.rs b/src/attested_tls.rs index ea26568..55037f5 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -5,7 +5,7 @@ use parity_scale_codec::{Decode, Encode}; use sha2::{Digest, Sha256}; use thiserror::Error; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; -use tracing::{error, warn}; +use tracing::error; use x509_parser::parse_x509_certificate; use std::num::TryFromIntError; @@ -27,7 +27,7 @@ use crate::attestation::{AttestationExchangeMessage, AttestationVerifier}; pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"]; /// The label used when exporting key material from a TLS session -const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +pub(crate) const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; /// TLS Credentials pub struct TlsCertAndKey { @@ -90,8 +90,8 @@ impl AttestedTlsServer { /// Start with preconfigured TLS /// - /// This is not public as it allows dangerous configuration - async fn new_with_tls_config( + /// This is not fully public as it allows dangerous configuration + pub(crate) async fn new_with_tls_config( cert_chain: Vec>, server_config: Arc, local: impl ToSocketAddrs, @@ -111,28 +111,30 @@ impl AttestedTlsServer { } /// Accept an incoming connection and handle it in a seperate task - pub async fn accept(&self) -> Result<(), AttestedTlsError> { + pub async fn accept( + &self, + ) -> Result< + ( + tokio_rustls::server::TlsStream, + Option, + AttestationType, + ), + AttestedTlsError, + > { let (inbound, _client_addr) = self.listener.accept().await?; let acceptor = self.acceptor.clone(); let cert_chain = self.cert_chain.clone(); let attestation_generator = self.attestation_generator.clone(); let attestation_verifier = self.attestation_verifier.clone(); - tokio::spawn(async move { - if let Err(err) = Self::handle_connection( - inbound, - acceptor, - cert_chain, - attestation_generator, - attestation_verifier, - ) - .await - { - warn!("Failed to handle connection: {err}"); - } - }); - - Ok(()) + Ok(Self::handle_connection( + inbound, + acceptor, + cert_chain, + attestation_generator, + attestation_verifier, + ) + .await?) } /// Helper to get the socket address of the underlying TCP listener diff --git a/src/lib.rs b/src/lib.rs index 46ddbf8..161b088 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,7 @@ pub mod file_server; pub mod health_check; pub use attestation::AttestationGenerator; -use attestation::{measurements::MultiMeasurements, AttestationError, AttestationType}; + use bytes::Bytes; use http::HeaderValue; use http_body_util::{combinators::BoxBody, BodyExt}; @@ -27,22 +27,21 @@ use std::time::Duration; use std::{net::SocketAddr, sync::Arc}; use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName}; +use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName}; use tokio_rustls::rustls::RootCertStore; use tokio_rustls::{ rustls::{ClientConfig, ServerConfig}, - TlsAcceptor, TlsConnector, + TlsConnector, }; -use crate::attestation::{AttestationExchangeMessage, AttestationVerifier}; - -/// This makes it possible to add breaking protocol changes and provide backwards compatibility. -/// When adding more supported versions, note that ordering is important. ALPN will pick the first -/// protocol which both parties support - so newer supported versions should come first. -pub const SUPPORTED_ALPN_PROTOCOL_VERSIONS: [&[u8]; 1] = [b"flashbots-ratls/1"]; - -/// The label used when exporting key material from a TLS session -const EXPORTER_LABEL: &[u8; 24] = b"EXPORTER-Channel-Binding"; +use crate::attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, AttestationType, + AttestationVerifier, +}; +use crate::attested_tls::{ + AttestedTlsError, AttestedTlsServer, TlsCertAndKey, EXPORTER_LABEL, + SUPPORTED_ALPN_PROTOCOL_VERSIONS, +}; /// The header name for giving attestation type const ATTESTATION_TYPE_HEADER: &str = "X-Flashbots-Attestation-Type"; @@ -59,26 +58,9 @@ type RequestWithResponseSender = ( ); type Http2Sender = hyper::client::conn::http2::SendRequest; -/// TLS Credentials -pub struct TlsCertAndKey { - /// Der-encoded TLS certificate chain - pub cert_chain: Vec>, - /// Der-encoded TLS private key - pub key: PrivateKeyDer<'static>, -} - /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address pub struct ProxyServer { - /// The underlying TCP listener - listener: TcpListener, - /// Quote generation type to use (including none) - attestation_generator: AttestationGenerator, - /// Verifier for remote attestation (including none) - attestation_verifier: AttestationVerifier, - /// The certificate chain - cert_chain: Vec>, - /// For accepting TLS connections - acceptor: TlsAcceptor, + attested_tls_server: AttestedTlsServer, /// The address of the target service we are proxying to target: SocketAddr, } @@ -133,38 +115,30 @@ impl ProxyServer { attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, ) -> Result { - let acceptor = tokio_rustls::TlsAcceptor::from(server_config); - let listener = TcpListener::bind(local).await?; - - Ok(Self { - listener, + let attested_tls_server = AttestedTlsServer::new_with_tls_config( + cert_chain, + server_config, + local, attestation_generator, attestation_verifier, - acceptor, + ) + .await?; + + Ok(Self { + attested_tls_server, target, - cert_chain, }) } /// Accept an incoming connection and handle it in a seperate task pub async fn accept(&self) -> Result<(), ProxyError> { - let (inbound, _client_addr) = self.listener.accept().await?; + let target = self.target.clone(); + let (tls_stream, measurements, attestation_type) = + self.attested_tls_server.accept().await?; - let acceptor = self.acceptor.clone(); - let target = self.target; - let cert_chain = self.cert_chain.clone(); - let attestation_generator = self.attestation_generator.clone(); - let attestation_verifier = self.attestation_verifier.clone(); tokio::spawn(async move { - if let Err(err) = Self::handle_connection( - inbound, - acceptor, - target, - cert_chain, - attestation_generator, - attestation_verifier, - ) - .await + if let Err(err) = + Self::handle_connection(tls_stream, measurements, attestation_type, target).await { warn!("Failed to handle connection: {err}"); } @@ -175,74 +149,18 @@ impl ProxyServer { /// Helper to get the socket address of the underlying TCP listener pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() + self.attested_tls_server.local_addr() } /// Handle an incoming connection from a proxy-client async fn handle_connection( - inbound: TcpStream, - acceptor: TlsAcceptor, + tls_stream: tokio_rustls::server::TlsStream, + measurements: Option, + remote_attestation_type: AttestationType, target: SocketAddr, - cert_chain: Vec>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, ) -> Result<(), ProxyError> { tracing::debug!("proxy-server accepted connection"); - // Do TLS handshake - let mut tls_stream = acceptor.accept(inbound).await?; - let (_io, connection) = tls_stream.get_ref(); - - // Ensure that we agreed a protocol - let _negotiated_protocol = connection.alpn_protocol().ok_or(ProxyError::AlpnFailed)?; - - // Compute an exporter unique to the session - let mut exporter = [0u8; 32]; - connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - let input_data = compute_report_input(Some(&cert_chain), exporter)?; - - // Get the TLS certficate chain of the client, if there is one - let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); - - // If we are in a CVM, generate an attestation - let attestation = attestation_generator - .generate_attestation(input_data) - .await? - .encode(); - - // Write our attestation to the channel, with length prefix - let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; - tls_stream.write_all(&attestation).await?; - - // Now read a length-prefixed attestation from the remote peer - // In the case of no client attestation this will be zero bytes - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - let remote_attestation_type = remote_attestation_message.attestation_type; - - // If we expect an attestaion from the client, verify it and get measurements - let measurements = if attestation_verifier.has_remote_attestion() { - let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?; - - attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await? - } else { - None - }; - // Setup an HTTP server let http = hyper::server::conn::http2::Builder::new(TokioExecutor); @@ -819,6 +737,8 @@ pub enum ProxyError { Serialization(#[from] parity_scale_codec::Error), #[error("Protocol negotiation failed - remote peer does not support this protocol")] AlpnFailed, + #[error("Attested TLS: {0}")] + AttestedTls(#[from] AttestedTlsError), } impl From> for ProxyError { diff --git a/src/main.rs b/src/main.rs index d9cb5b2..7859f7c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,8 +8,9 @@ use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier}, attested_get::attested_get, + attested_tls::TlsCertAndKey, file_server::attested_file_server, - get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, TlsCertAndKey, + get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, }; #[derive(Parser, Debug, Clone)] From 4ffcf95f4af796ad4bbabb69d5c4d51e8eb83bef Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 9 Jan 2026 14:17:48 +0100 Subject: [PATCH 3/7] Use attested-tls-client from refactored module --- src/attested_tls.rs | 65 +++----- src/lib.rs | 368 +++++++++++--------------------------------- src/main.rs | 4 +- src/test_helpers.rs | 3 +- 4 files changed, 114 insertions(+), 326 deletions(-) diff --git a/src/attested_tls.rs b/src/attested_tls.rs index 55037f5..df6cda6 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -38,9 +38,10 @@ pub struct TlsCertAndKey { } /// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address +#[derive(Clone)] pub struct AttestedTlsServer { /// The underlying TCP listener - listener: TcpListener, + pub listener: Arc, /// Quote generation type to use (including none) attestation_generator: AttestationGenerator, /// Verifier for remote attestation (including none) @@ -102,7 +103,7 @@ impl AttestedTlsServer { let listener = TcpListener::bind(local).await?; Ok(Self { - listener, + listener: listener.into(), attestation_generator, attestation_verifier, acceptor, @@ -110,7 +111,7 @@ impl AttestedTlsServer { }) } - /// Accept an incoming connection and handle it in a seperate task + /// Accept an incoming connection and do an attestation exchange pub async fn accept( &self, ) -> Result< @@ -123,18 +124,7 @@ impl AttestedTlsServer { > { let (inbound, _client_addr) = self.listener.accept().await?; - let acceptor = self.acceptor.clone(); - let cert_chain = self.cert_chain.clone(); - let attestation_generator = self.attestation_generator.clone(); - let attestation_verifier = self.attestation_verifier.clone(); - Ok(Self::handle_connection( - inbound, - acceptor, - cert_chain, - attestation_generator, - attestation_verifier, - ) - .await?) + self.handle_connection(inbound).await } /// Helper to get the socket address of the underlying TCP listener @@ -143,12 +133,13 @@ impl AttestedTlsServer { } /// Handle an incoming connection from a proxy-client - async fn handle_connection( + pub async fn handle_connection( + &self, inbound: TcpStream, - acceptor: TlsAcceptor, - cert_chain: Vec>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, + // acceptor: TlsAcceptor, + // cert_chain: Vec>, + // attestation_generator: AttestationGenerator, + // attestation_verifier: AttestationVerifier, ) -> Result< ( tokio_rustls::server::TlsStream, @@ -160,7 +151,7 @@ impl AttestedTlsServer { tracing::debug!("attested-tls-server accepted connection"); // Do TLS handshake - let mut tls_stream = acceptor.accept(inbound).await?; + let mut tls_stream = self.acceptor.accept(inbound).await?; let (_io, connection) = tls_stream.get_ref(); // Ensure that we agreed a protocol @@ -176,13 +167,14 @@ impl AttestedTlsServer { None, // context )?; - let input_data = compute_report_input(Some(&cert_chain), exporter)?; + let input_data = compute_report_input(Some(&self.cert_chain), exporter)?; // Get the TLS certficate chain of the client, if there is one let remote_cert_chain = connection.peer_certificates().map(|c| c.to_owned()); // If we are in a CVM, generate an attestation - let attestation = attestation_generator + let attestation = self + .attestation_generator .generate_attestation(input_data) .await? .encode(); @@ -205,10 +197,10 @@ impl AttestedTlsServer { let remote_attestation_type = remote_attestation_message.attestation_type; // If we expect an attestaion from the client, verify it and get measurements - let measurements = if attestation_verifier.has_remote_attestion() { + let measurements = if self.attestation_verifier.has_remote_attestion() { let remote_input_data = compute_report_input(remote_cert_chain.as_deref(), exporter)?; - attestation_verifier + self.attestation_verifier .verify_attestation(remote_attestation_message, remote_input_data) .await? } else { @@ -220,9 +212,8 @@ impl AttestedTlsServer { } /// A proxy client which forwards http traffic to a proxy-server +#[derive(Clone)] pub struct AttestedTlsClient { - /// The underlying TCP listener - listener: TcpListener, /// The connector for making TLS connections with out configuration connector: TlsConnector, /// Quote generation type to use (including none) @@ -234,9 +225,10 @@ pub struct AttestedTlsClient { } impl std::fmt::Debug for AttestedTlsClient { + // TODO add other fields fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AttestedTlsClient") - .field("listener", &self.listener) + .field("attestation_verifier", &self.attestation_verifier) .finish() } } @@ -245,7 +237,6 @@ impl AttestedTlsClient { /// Start with optional TLS client auth pub async fn new( cert_and_key: Option, - address: impl ToSocketAddrs, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, remote_certificate: Option>, @@ -281,7 +272,6 @@ impl AttestedTlsClient { Self::new_with_tls_config( client_config.into(), - address, attestation_generator, attestation_verifier, cert_and_key.map(|c| c.cert_chain), @@ -291,20 +281,16 @@ impl AttestedTlsClient { /// Create a new proxy client with given TLS configuration /// - /// This is private as it allows dangerous configuration but is used in tests - async fn new_with_tls_config( + /// This not fully public as it allows dangerous configuration but is used in tests + pub(crate) async fn new_with_tls_config( client_config: Arc, - local: impl ToSocketAddrs, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - // Setup TCP server and TLS client - let listener = TcpListener::bind(local).await?; let connector = TlsConnector::from(client_config.clone()); Ok(Self { - listener, connector, attestation_generator, attestation_verifier, @@ -312,11 +298,6 @@ impl AttestedTlsClient { }) } - /// Helper to return the local socket address from the underlying TCP listener - pub fn local_addr(&self) -> std::io::Result { - self.listener.local_addr() - } - /// Connect to the attested-tls-server, do TLS handshake and remote attestation pub async fn connect( &self, @@ -425,7 +406,7 @@ pub async fn get_tls_cert( get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await } -async fn get_tls_cert_with_config( +pub(crate) async fn get_tls_cert_with_config( server_name: String, attestation_verifier: AttestationVerifier, client_config: Arc, diff --git a/src/lib.rs b/src/lib.rs index 161b088..339bc7b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,36 +11,33 @@ use http::HeaderValue; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{service::service_fn, Response}; use hyper_util::rt::TokioIo; -use parity_scale_codec::{Decode, Encode}; use sha2::{Digest, Sha256}; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; -use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; +use tokio_rustls::rustls::server::VerifierBuilderError; use tracing::{error, warn}; use x509_parser::parse_x509_certificate; #[cfg(test)] mod test_helpers; +use std::net::SocketAddr; use std::num::TryFromIntError; use std::time::Duration; -use std::{net::SocketAddr, sync::Arc}; -use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; +use tokio::io; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; -use tokio_rustls::rustls::pki_types::{CertificateDer, ServerName}; -use tokio_rustls::rustls::RootCertStore; -use tokio_rustls::{ - rustls::{ClientConfig, ServerConfig}, - TlsConnector, -}; +use tokio_rustls::rustls::pki_types::CertificateDer; -use crate::attestation::{ - measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, AttestationType, - AttestationVerifier, -}; -use crate::attested_tls::{ - AttestedTlsError, AttestedTlsServer, TlsCertAndKey, EXPORTER_LABEL, - SUPPORTED_ALPN_PROTOCOL_VERSIONS, +#[cfg(test)] +use std::sync::Arc; +#[cfg(test)] +use tokio_rustls::rustls::{ClientConfig, ServerConfig}; + +use crate::{ + attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationType, AttestationVerifier, + }, + attested_tls::{AttestedTlsClient, AttestedTlsError, AttestedTlsServer, TlsCertAndKey}, }; /// The header name for giving attestation type @@ -60,6 +57,7 @@ type Http2Sender = hyper::client::conn::http2::SendRequest Result { - let mut server_config = if client_auth { - let root_store = - RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?; - - ServerConfig::builder() - .with_client_cert_verifier(verifier) - .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? - } else { - ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)? - }; - - server_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - Self::new_with_tls_config( - cert_and_key.cert_chain, - server_config.into(), + let attested_tls_server = AttestedTlsServer::new( + cert_and_key, local, - target, attestation_generator, attestation_verifier, + client_auth, ) - .await + .await?; + + Ok(Self { + attested_tls_server, + target, + }) } /// Start with preconfigured TLS /// /// This is not public as it allows dangerous configuration + #[cfg(test)] async fn new_with_tls_config( cert_chain: Vec>, server_config: Arc, @@ -132,15 +116,23 @@ impl ProxyServer { /// Accept an incoming connection and handle it in a seperate task pub async fn accept(&self) -> Result<(), ProxyError> { - let target = self.target.clone(); - let (tls_stream, measurements, attestation_type) = - self.attested_tls_server.accept().await?; + let target = self.target; + let (inbound, _client_addr) = self.attested_tls_server.listener.accept().await?; + let attested_tls_server = self.attested_tls_server.clone(); tokio::spawn(async move { - if let Err(err) = - Self::handle_connection(tls_stream, measurements, attestation_type, target).await - { - warn!("Failed to handle connection: {err}"); + match attested_tls_server.handle_connection(inbound).await { + Ok((tls_stream, measurements, attestation_type)) => { + if let Err(err) = + Self::handle_connection(tls_stream, measurements, attestation_type, target) + .await + { + warn!("Failed to handle connection: {err}"); + } + } + Err(err) => { + warn!("Attestation exchange failed: {err}"); + } } }); @@ -267,60 +259,49 @@ impl ProxyClient { attestation_verifier: AttestationVerifier, remote_certificate: Option>, ) -> Result { - // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots - let root_store = match remote_certificate { - Some(remote_certificate) => { - let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; - root_store - } - None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), - }; - - // Setup TLS client configuration, with or without client auth - let mut client_config = if let Some(ref cert_and_key) = cert_and_key { - ClientConfig::builder() - .with_root_certificates(root_store) - .with_client_auth_cert( - cert_and_key.cert_chain.clone(), - cert_and_key.key.clone_key(), - )? - } else { - ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth() - }; - - client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - Self::new_with_tls_config( - client_config.into(), - address, - server_name, + let attested_tls_client = AttestedTlsClient::new( + cert_and_key, attestation_generator, attestation_verifier, - cert_and_key.map(|c| c.cert_chain), + remote_certificate, ) - .await + .await?; + + Self::new_with_inner(address, attested_tls_client, server_name).await } /// Create a new proxy client with given TLS configuration /// /// This is private as it allows dangerous configuration but is used in tests + #[cfg(test)] async fn new_with_tls_config( client_config: Arc, - local: impl ToSocketAddrs, + address: impl ToSocketAddrs, target_name: String, attestation_generator: AttestationGenerator, attestation_verifier: AttestationVerifier, cert_chain: Option>>, ) -> Result { - // Setup TCP server and TLS client - let listener = TcpListener::bind(local).await?; - let connector = TlsConnector::from(client_config.clone()); + let attested_tls_client = AttestedTlsClient::new_with_tls_config( + client_config, + attestation_generator, + attestation_verifier, + cert_chain, + ) + .await?; + + Self::new_with_inner(address, attested_tls_client, target_name).await + } + + /// Create a new proxy client with given TLS configuration + /// + /// This is private as it allows dangerous configuration but is used in tests + async fn new_with_inner( + address: impl ToSocketAddrs, + attested_tls_client: AttestedTlsClient, + target_name: String, + ) -> Result { + let listener = TcpListener::bind(address).await?; // Process the hostname / port provided by the user let target = host_to_host_with_port(&target_name); @@ -335,16 +316,9 @@ impl ProxyClient { // Connect to the proxy server and provide / verify attestation let (mut sender, mut measurements, mut remote_attestation_type) = - Self::setup_connection_with_backoff( - connector.clone(), - target.clone(), - cert_chain.clone(), - attestation_generator.clone(), - attestation_verifier.clone(), - true, - ) - .await?; + Self::setup_connection_with_backoff(target.clone(), &attested_tls_client, true).await?; + let attested_tls_client_clone = attested_tls_client.clone(); tokio::spawn(async move { // Read an incoming request from the channel (from the source client) while let Some((req, response_tx)) = requests_rx.recv().await { @@ -393,11 +367,8 @@ impl ProxyClient { // Reconnect to the server - retrying indefinately with a backoff (sender, measurements, remote_attestation_type) = Self::setup_connection_with_backoff( - connector.clone(), target.clone(), - cert_chain.clone(), - attestation_generator.clone(), - attestation_verifier.clone(), + &attested_tls_client_clone, false, ) .await @@ -467,26 +438,15 @@ impl ProxyClient { // Attempt connection and handshake with the proxy-server // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( - connector: TlsConnector, target: String, - cert_chain: Option>>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, + attested_tls_client: &AttestedTlsClient, should_bail: bool, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { let mut delay = Duration::from_secs(1); let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); loop { - match Self::setup_connection( - connector.clone(), - target.clone(), - cert_chain.clone(), - attestation_generator.clone(), - attestation_verifier.clone(), - ) - .await - { + match Self::setup_connection(attested_tls_client, target.clone()).await { Ok(output) => { return Ok(output); } @@ -508,74 +468,12 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( - connector: TlsConnector, + inner: &AttestedTlsClient, target: String, - cert_chain: Option>>, - attestation_generator: AttestationGenerator, - attestation_verifier: AttestationVerifier, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { - // Make a TCP client connection and TLS handshake - let out = TcpStream::connect(&target).await?; - let mut tls_stream = connector - .connect(server_name_from_host(&target)?, out) - .await?; - - let (_io, server_connection) = tls_stream.get_ref(); - - // Ensure that we agreed a protocol - let _negotiated_protocol = server_connection - .alpn_protocol() - .ok_or(ProxyError::AlpnFailed)?; - - // Compute an exporter unique to the channel - let mut exporter = [0u8; 32]; - server_connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - // Get the TLS certificate chain of the server - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)? - .to_owned(); - - let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; - - // Read a length prefixed attestation from the proxy-server - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; + let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?; - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - let remote_attestation_type = remote_attestation_message.attestation_type; - - // Verify the remote attestation against our accepted measurements - let measurements = attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await?; - - // If we are in a CVM, provide an attestation - let attestation = if attestation_generator.attestation_type != AttestationType::None { - let local_input_data = compute_report_input(cert_chain.as_deref(), exporter)?; - attestation_generator - .generate_attestation(local_input_data) - .await? - .encode() - } else { - AttestationExchangeMessage::without_attestation().encode() - }; - - // Send our attestation (or zero bytes) prefixed with length - let attestation_length_prefix = length_prefix(&attestation); - tls_stream.write_all(&attestation_length_prefix).await?; - tls_stream.write_all(&attestation).await?; - - // The attestation exchange is now complete - now setup an HTTP client + // The attestation exchange is now complete - setup an HTTP client let outbound_io = TokioIo::new(tls_stream); let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor) @@ -604,81 +502,6 @@ impl ProxyClient { } } -/// Just get the attested remote certificate, with no client authentication -pub async fn get_tls_cert( - server_name: String, - attestation_verifier: AttestationVerifier, - remote_certificate: Option>, -) -> Result>, ProxyError> { - tracing::debug!("Getting remote TLS cert"); - // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots - let root_store = match remote_certificate { - Some(remote_certificate) => { - let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; - root_store - } - None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), - }; - - let mut client_config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - - client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await -} - -async fn get_tls_cert_with_config( - server_name: String, - attestation_verifier: AttestationVerifier, - client_config: Arc, -) -> Result>, ProxyError> { - let connector = TlsConnector::from(client_config); - - let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?; - let mut tls_stream = connector - .connect(server_name_from_host(&server_name)?, out) - .await?; - - let (_io, server_connection) = tls_stream.get_ref(); - - let mut exporter = [0u8; 32]; - server_connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(ProxyError::NoCertificate)? - .to_owned(); - - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - - let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; - - let _measurements = attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await?; - - tls_stream.shutdown().await?; - - Ok(remote_cert_chain) -} - /// Given a certificate chain and an exporter (session key material), build the quote input value /// SHA256(pki) || exporter pub fn compute_report_input( @@ -733,10 +556,6 @@ pub enum ProxyError { OneShotRecv(#[from] oneshot::error::RecvError), #[error("Failed to send request, connection to proxy-server dropped")] MpscSend, - #[error("Serialization: {0}")] - Serialization(#[from] parity_scale_codec::Error), - #[error("Protocol negotiation failed - remote peer does not support this protocol")] - AlpnFailed, #[error("Attested TLS: {0}")] AttestedTls(#[from] AttestedTlsError), } @@ -747,12 +566,6 @@ impl From> for ProxyError { } } -/// Given a byte array, encode its length as a 4 byte big endian u32 -fn length_prefix(input: &[u8]) -> [u8; 4] { - let len = input.len() as u32; - len.to_be_bytes() -} - /// If no port was provided, default to 443 fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { @@ -762,20 +575,6 @@ fn host_to_host_with_port(host: &str) -> String { } } -/// Given a hostname with or without port number, create a TLS [ServerName] with just the host part -fn server_name_from_host( - host: &str, -) -> Result, tokio_rustls::rustls::pki_types::InvalidDnsNameError> { - // If host contains ':', try to split off the port. - let host_part = host.rsplit_once(':').map(|(h, _)| h).unwrap_or(host); - - // If the host is an IPv6 literal in brackets like "[::1]:443", - // remove the brackets for SNI (SNI allows bare IPv6 too). - let host_part = host_part.trim_matches(|c| c == '[' || c == ']'); - - ServerName::try_from(host_part.to_string()) -} - /// An Executor for hyper that uses the tokio runtime #[derive(Clone)] struct TokioExecutor; @@ -796,8 +595,11 @@ where mod tests { use std::collections::HashMap; - use crate::attestation::measurements::{ - DcapMeasurementRegister, MeasurementPolicy, MeasurementRecord, MultiMeasurements, + use crate::{ + attestation::measurements::{ + DcapMeasurementRegister, MeasurementPolicy, MeasurementRecord, MultiMeasurements, + }, + attested_tls::get_tls_cert_with_config, }; use super::*; @@ -1208,7 +1010,9 @@ mod tests { assert!(matches!( proxy_client_result.unwrap_err(), - ProxyError::Attestation(AttestationError::AttestationTypeNotAccepted) + ProxyError::AttestedTls(AttestedTlsError::Attestation( + AttestationError::AttestationTypeNotAccepted + )) )); } @@ -1267,7 +1071,9 @@ mod tests { assert!(matches!( proxy_client_result.unwrap_err(), - ProxyError::Attestation(AttestationError::MeasurementsNotAccepted) + ProxyError::AttestedTls(AttestedTlsError::Attestation( + AttestationError::MeasurementsNotAccepted + )) )); } } diff --git a/src/main.rs b/src/main.rs index 7859f7c..9bdff8f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,9 +8,9 @@ use tracing::level_filters::LevelFilter; use attested_tls_proxy::{ attestation::{measurements::MeasurementPolicy, AttestationType, AttestationVerifier}, attested_get::attested_get, - attested_tls::TlsCertAndKey, + attested_tls::{get_tls_cert, TlsCertAndKey}, file_server::attested_file_server, - get_tls_cert, health_check, AttestationGenerator, ProxyClient, ProxyServer, + health_check, AttestationGenerator, ProxyClient, ProxyServer, }; #[derive(Parser, Debug, Clone)] diff --git a/src/test_helpers.rs b/src/test_helpers.rs index b783dff..c7df30e 100644 --- a/src/test_helpers.rs +++ b/src/test_helpers.rs @@ -13,7 +13,8 @@ use tokio_rustls::rustls::{ use crate::{ attestation::measurements::{DcapMeasurementRegister, MultiMeasurements}, - MEASUREMENT_HEADER, SUPPORTED_ALPN_PROTOCOL_VERSIONS, + attested_tls::SUPPORTED_ALPN_PROTOCOL_VERSIONS, + MEASUREMENT_HEADER, }; /// Helper to generate a self-signed certificate for testing From 28e7a26ec1ec8192c9baee6e1a8611810f2f4eef Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 9 Jan 2026 14:20:48 +0100 Subject: [PATCH 4/7] Rm unused import --- src/attested_tls.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/attested_tls.rs b/src/attested_tls.rs index df6cda6..372c8cb 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -5,7 +5,6 @@ use parity_scale_codec::{Decode, Encode}; use sha2::{Digest, Sha256}; use thiserror::Error; use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier}; -use tracing::error; use x509_parser::parse_x509_certificate; use std::num::TryFromIntError; From 82b9bf5a628a975aac5dbf933ab776f729a0009d Mon Sep 17 00:00:00 2001 From: peg Date: Fri, 9 Jan 2026 15:08:34 +0100 Subject: [PATCH 5/7] Tidy --- src/attestation/mod.rs | 2 +- src/attested_tls.rs | 16 +++++----------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/src/attestation/mod.rs b/src/attestation/mod.rs index bc78624..4f8cd75 100644 --- a/src/attestation/mod.rs +++ b/src/attestation/mod.rs @@ -115,7 +115,7 @@ impl Display for AttestationType { } /// Can generate a local attestation based on attestation type -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct AttestationGenerator { pub attestation_type: AttestationType, dummy_dcap_url: Option, diff --git a/src/attested_tls.rs b/src/attested_tls.rs index 372c8cb..486d226 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -135,10 +135,6 @@ impl AttestedTlsServer { pub async fn handle_connection( &self, inbound: TcpStream, - // acceptor: TlsAcceptor, - // cert_chain: Vec>, - // attestation_generator: AttestationGenerator, - // attestation_verifier: AttestationVerifier, ) -> Result< ( tokio_rustls::server::TlsStream, @@ -224,10 +220,11 @@ pub struct AttestedTlsClient { } impl std::fmt::Debug for AttestedTlsClient { - // TODO add other fields fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("AttestedTlsClient") .field("attestation_verifier", &self.attestation_verifier) + .field("attestation_generator", &self.attestation_generator) + .field("cert_chain", &self.cert_chain) .finish() } } @@ -297,7 +294,7 @@ impl AttestedTlsClient { }) } - /// Connect to the attested-tls-server, do TLS handshake and remote attestation + /// Connect to an attested-tls-server, do TLS handshake and attestation exchange pub async fn connect( &self, target: String, @@ -376,7 +373,7 @@ impl AttestedTlsClient { } } -/// Just get the attested remote certificate, with no client authentication +/// A client which just gets the attested remote certificate, with no client authentication pub async fn get_tls_cert( server_name: String, attestation_verifier: AttestationVerifier, @@ -405,6 +402,7 @@ pub async fn get_tls_cert( get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await } +// TODO this could use AttestedTlsClient to avoid repeating code pub(crate) async fn get_tls_cert_with_config( server_name: String, attestation_verifier: AttestationVerifier, @@ -497,10 +495,6 @@ pub enum AttestedTlsError { IntConversion(#[from] TryFromIntError), #[error("Bad host name: {0}")] BadDnsName(#[from] tokio_rustls::rustls::pki_types::InvalidDnsNameError), - #[error("HTTP: {0}")] - Hyper(#[from] hyper::Error), - #[error("JSON: {0}")] - Json(#[from] serde_json::Error), #[error("Serialization: {0}")] Serialization(#[from] parity_scale_codec::Error), #[error("Protocol negotiation failed - remote peer does not support this protocol")] From d872a006ea9f8a54f12cf369b536283af76e4e0a Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 12 Jan 2026 09:04:29 +0100 Subject: [PATCH 6/7] Deduplicate get_tls_cert --- src/attested_tls.rs | 126 +++++++++++++++++--------------------------- src/lib.rs | 22 ++++---- 2 files changed, 60 insertions(+), 88 deletions(-) diff --git a/src/attested_tls.rs b/src/attested_tls.rs index 486d226..6460371 100644 --- a/src/attested_tls.rs +++ b/src/attested_tls.rs @@ -1,5 +1,9 @@ -use crate::attestation::{ - measurements::MultiMeasurements, AttestationError, AttestationGenerator, AttestationType, +use crate::{ + attestation::{ + measurements::MultiMeasurements, AttestationError, AttestationExchangeMessage, + AttestationGenerator, AttestationType, AttestationVerifier, + }, + host_to_host_with_port, }; use parity_scale_codec::{Decode, Encode}; use sha2::{Digest, Sha256}; @@ -18,8 +22,6 @@ use tokio_rustls::{ TlsAcceptor, TlsConnector, }; -use crate::attestation::{AttestationExchangeMessage, AttestationVerifier}; - /// This makes it possible to add breaking protocol changes and provide backwards compatibility. /// When adding more supported versions, note that ordering is important. ALPN will pick the first /// protocol which both parties support - so newer supported versions should come first. @@ -297,7 +299,7 @@ impl AttestedTlsClient { /// Connect to an attested-tls-server, do TLS handshake and attestation exchange pub async fn connect( &self, - target: String, + target: &str, ) -> Result< ( tokio_rustls::client::TlsStream, @@ -310,7 +312,7 @@ impl AttestedTlsClient { let out = TcpStream::connect(&target).await?; let mut tls_stream = self .connector - .connect(server_name_from_host(&target)?, out) + .connect(server_name_from_host(target)?, out) .await?; let (_io, server_connection) = tls_stream.get_ref(); @@ -371,82 +373,61 @@ impl AttestedTlsClient { Ok((tls_stream, measurements, remote_attestation_type)) } + + /// Connect to an attested TLS server, retrieve the remote TLS certificate and return it + pub async fn get_tls_cert( + &self, + server_name: &str, + ) -> Result>, AttestedTlsError> { + let (mut tls_stream, _, _) = self.connect(server_name).await?; + + let (_io, server_connection) = tls_stream.get_ref(); + + let remote_cert_chain = server_connection + .peer_certificates() + .ok_or(AttestedTlsError::NoCertificate)? + .to_owned(); + + tls_stream.shutdown().await?; + + Ok(remote_cert_chain) + } } /// A client which just gets the attested remote certificate, with no client authentication pub async fn get_tls_cert( server_name: String, attestation_verifier: AttestationVerifier, - remote_certificate: Option>, + remote_certificate: Option>, ) -> Result>, AttestedTlsError> { tracing::debug!("Getting remote TLS cert"); - // If a remote CA cert was given, use it as the root store, otherwise use webpki_roots - let root_store = match remote_certificate { - Some(remote_certificate) => { - let mut root_store = RootCertStore::empty(); - root_store.add(remote_certificate)?; - root_store - } - None => RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()), - }; - - let mut client_config = ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - - client_config.alpn_protocols = SUPPORTED_ALPN_PROTOCOL_VERSIONS - .into_iter() - .map(|p| p.to_vec()) - .collect(); - - get_tls_cert_with_config(server_name, attestation_verifier, client_config.into()).await + let attested_tls_client = AttestedTlsClient::new( + None, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + remote_certificate, + ) + .await?; + attested_tls_client + .get_tls_cert(&host_to_host_with_port(&server_name)) + .await } -// TODO this could use AttestedTlsClient to avoid repeating code +/// Helper for testing getting remote certificate +#[cfg(test)] pub(crate) async fn get_tls_cert_with_config( - server_name: String, + server_name: &str, attestation_verifier: AttestationVerifier, client_config: Arc, ) -> Result>, AttestedTlsError> { - let connector = TlsConnector::from(client_config); - - let out = TcpStream::connect(host_to_host_with_port(&server_name)).await?; - let mut tls_stream = connector - .connect(server_name_from_host(&server_name)?, out) - .await?; - - let (_io, server_connection) = tls_stream.get_ref(); - - let mut exporter = [0u8; 32]; - server_connection.export_keying_material( - &mut exporter, - EXPORTER_LABEL, - None, // context - )?; - - let remote_cert_chain = server_connection - .peer_certificates() - .ok_or(AttestedTlsError::NoCertificate)? - .to_owned(); - - let mut length_bytes = [0; 4]; - tls_stream.read_exact(&mut length_bytes).await?; - let length: usize = u32::from_be_bytes(length_bytes).try_into()?; - - let mut buf = vec![0; length]; - tls_stream.read_exact(&mut buf).await?; - - let remote_attestation_message = AttestationExchangeMessage::decode(&mut &buf[..])?; - - let remote_input_data = compute_report_input(Some(&remote_cert_chain), exporter)?; - - let _measurements = attestation_verifier - .verify_attestation(remote_attestation_message, remote_input_data) - .await?; - - tls_stream.shutdown().await?; - - Ok(remote_cert_chain) + let attested_tls_client = AttestedTlsClient::new_with_tls_config( + client_config, + AttestationGenerator::with_no_attestation(), + attestation_verifier, + None, + ) + .await?; + attested_tls_client.get_tls_cert(server_name).await } /// Given a certificate chain and an exporter (session key material), build the quote input value @@ -507,15 +488,6 @@ fn length_prefix(input: &[u8]) -> [u8; 4] { len.to_be_bytes() } -/// If no port was provided, default to 443 -fn host_to_host_with_port(host: &str) -> String { - if host.contains(':') { - host.to_string() - } else { - format!("{host}:443") - } -} - /// Given a hostname with or without port number, create a TLS [ServerName] with just the host part fn server_name_from_host( host: &str, diff --git a/src/lib.rs b/src/lib.rs index 339bc7b..00db557 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -267,7 +267,7 @@ impl ProxyClient { ) .await?; - Self::new_with_inner(address, attested_tls_client, server_name).await + Self::new_with_inner(address, attested_tls_client, &server_name).await } /// Create a new proxy client with given TLS configuration @@ -290,7 +290,7 @@ impl ProxyClient { ) .await?; - Self::new_with_inner(address, attested_tls_client, target_name).await + Self::new_with_inner(address, attested_tls_client, &target_name).await } /// Create a new proxy client with given TLS configuration @@ -299,12 +299,12 @@ impl ProxyClient { async fn new_with_inner( address: impl ToSocketAddrs, attested_tls_client: AttestedTlsClient, - target_name: String, + target_name: &str, ) -> Result { let listener = TcpListener::bind(address).await?; // Process the hostname / port provided by the user - let target = host_to_host_with_port(&target_name); + let target = host_to_host_with_port(target_name); // Channel for getting incoming requests from the source client let (requests_tx, mut requests_rx) = mpsc::channel::<( @@ -316,7 +316,7 @@ impl ProxyClient { // Connect to the proxy server and provide / verify attestation let (mut sender, mut measurements, mut remote_attestation_type) = - Self::setup_connection_with_backoff(target.clone(), &attested_tls_client, true).await?; + Self::setup_connection_with_backoff(&target, &attested_tls_client, true).await?; let attested_tls_client_clone = attested_tls_client.clone(); tokio::spawn(async move { @@ -367,7 +367,7 @@ impl ProxyClient { // Reconnect to the server - retrying indefinately with a backoff (sender, measurements, remote_attestation_type) = Self::setup_connection_with_backoff( - target.clone(), + &target, &attested_tls_client_clone, false, ) @@ -438,7 +438,7 @@ impl ProxyClient { // Attempt connection and handshake with the proxy-server // If it fails retry with a backoff (indefinately) async fn setup_connection_with_backoff( - target: String, + target: &str, attested_tls_client: &AttestedTlsClient, should_bail: bool, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { @@ -446,7 +446,7 @@ impl ProxyClient { let max_delay = Duration::from_secs(SERVER_RECONNECT_MAX_BACKOFF_SECS); loop { - match Self::setup_connection(attested_tls_client, target.clone()).await { + match Self::setup_connection(attested_tls_client, target).await { Ok(output) => { return Ok(output); } @@ -469,7 +469,7 @@ impl ProxyClient { /// Connect to the proxy-server, do TLS handshake and remote attestation async fn setup_connection( inner: &AttestedTlsClient, - target: String, + target: &str, ) -> Result<(Http2Sender, Option, AttestationType), ProxyError> { let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?; @@ -567,7 +567,7 @@ impl From> for ProxyError { } /// If no port was provided, default to 443 -fn host_to_host_with_port(host: &str) -> String { +pub(crate) fn host_to_host_with_port(host: &str) -> String { if host.contains(':') { host.to_string() } else { @@ -962,7 +962,7 @@ mod tests { }); let retrieved_chain = get_tls_cert_with_config( - proxy_server_addr.to_string(), + &proxy_server_addr.to_string(), AttestationVerifier::mock(), client_config, ) From 3931cca4498ab9e28bf9bd8f709fa6ce912e2ad3 Mon Sep 17 00:00:00 2001 From: peg Date: Mon, 12 Jan 2026 09:07:49 +0100 Subject: [PATCH 7/7] Rm deplicated helpers --- src/lib.rs | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 00db557..36a738c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,12 +11,10 @@ use http::HeaderValue; use http_body_util::{combinators::BoxBody, BodyExt}; use hyper::{service::service_fn, Response}; use hyper_util::rt::TokioIo; -use sha2::{Digest, Sha256}; use thiserror::Error; use tokio::sync::{mpsc, oneshot}; use tokio_rustls::rustls::server::VerifierBuilderError; use tracing::{error, warn}; -use x509_parser::parse_x509_certificate; #[cfg(test)] mod test_helpers; @@ -502,35 +500,6 @@ impl ProxyClient { } } -/// Given a certificate chain and an exporter (session key material), build the quote input value -/// SHA256(pki) || exporter -pub fn compute_report_input( - cert_chain: Option<&[CertificateDer<'_>]>, - exporter: [u8; 32], -) -> Result<[u8; 64], AttestationError> { - let mut quote_input = [0u8; 64]; - if let Some(cert_chain) = cert_chain { - let pki_hash = get_pki_hash_from_certificate_chain(cert_chain)?; - quote_input[..32].copy_from_slice(&pki_hash); - } - quote_input[32..].copy_from_slice(&exporter); - Ok(quote_input) -} - -/// Given a certificate chain, get the [Sha256] hash of the public key of the leaf certificate -fn get_pki_hash_from_certificate_chain( - cert_chain: &[CertificateDer<'_>], -) -> Result<[u8; 32], AttestationError> { - let leaf_certificate = cert_chain.first().ok_or(AttestationError::NoCertificate)?; - let (_, cert) = parse_x509_certificate(leaf_certificate.as_ref())?; - let public_key = &cert.tbs_certificate.subject_pki; - let key_bytes = public_key.subject_public_key.as_ref(); - - let mut hasher = Sha256::new(); - hasher.update(key_bytes); - Ok(hasher.finalize().into()) -} - /// An error when running a proxy client or server #[derive(Error, Debug)] pub enum ProxyError {