Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,4 @@ jobs:
run: cargo clippy --workspace -- -D warnings

- name: Run cargo test
run: cargo test --workspace --all-targets
run: cargo test --workspace --all-targets -- --test-threads=1
91 changes: 45 additions & 46 deletions src/attested_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
use x509_parser::parse_x509_certificate;

use std::num::TryFromIntError;
use std::{net::SocketAddr, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use tokio_rustls::rustls::RootCertStore;
use tokio_rustls::{
Expand All @@ -38,16 +37,14 @@ pub struct TlsCertAndKey {
pub key: PrivateKeyDer<'static>,
}

/// A TLS over TCP server which provides an attestation before forwarding traffic to a given target address
/// A TLS server which makes an attestation exchange following the TLS handshake
#[derive(Clone)]
pub struct AttestedTlsServer {
/// The underlying TCP listener
pub listener: Arc<TcpListener>,
/// Quote generation type to use (including none)
attestation_generator: AttestationGenerator,
/// Verifier for remote attestation (including none)
attestation_verifier: AttestationVerifier,
/// The certificate chain
/// The TLS certificate chain
cert_chain: Vec<CertificateDer<'static>>,
/// For accepting TLS connections
acceptor: TlsAcceptor,
Expand All @@ -56,7 +53,6 @@ pub struct AttestedTlsServer {
impl AttestedTlsServer {
pub async fn new(
cert_and_key: TlsCertAndKey,
local: impl ToSocketAddrs,
attestation_generator: AttestationGenerator,
attestation_verifier: AttestationVerifier,
client_auth: bool,
Expand All @@ -83,7 +79,6 @@ impl AttestedTlsServer {
Self::new_with_tls_config(
cert_and_key.cert_chain,
server_config.into(),
local,
attestation_generator,
attestation_verifier,
)
Expand All @@ -96,55 +91,36 @@ impl AttestedTlsServer {
pub(crate) async fn new_with_tls_config(
cert_chain: Vec<CertificateDer<'static>>,
server_config: Arc<ServerConfig>,
local: impl ToSocketAddrs,
attestation_generator: AttestationGenerator,
attestation_verifier: AttestationVerifier,
) -> Result<Self, AttestedTlsError> {
let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
let listener = TcpListener::bind(local).await?;

Ok(Self {
listener: listener.into(),
attestation_generator,
attestation_verifier,
acceptor,
cert_chain,
})
}

/// Accept an incoming connection and do an attestation exchange
pub async fn accept(
&self,
) -> Result<
(
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
Option<MultiMeasurements>,
AttestationType,
),
AttestedTlsError,
> {
let (inbound, _client_addr) = self.listener.accept().await?;

self.handle_connection(inbound).await
}

/// Helper to get the socket address of the underlying TCP listener
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.listener.local_addr()
}

/// Handle an incoming connection from a proxy-client
pub async fn handle_connection(
///
/// This is transport agnostic and will work with any asynchronous stream
pub async fn handle_connection<IO>(
&self,
inbound: TcpStream,
inbound: IO,
) -> Result<
(
tokio_rustls::server::TlsStream<tokio::net::TcpStream>,
tokio_rustls::server::TlsStream<IO>,
Option<MultiMeasurements>,
AttestationType,
),
AttestedTlsError,
> {
>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
tracing::debug!("attested-tls-server accepted connection");

// Do TLS handshake
Expand Down Expand Up @@ -296,23 +272,29 @@ impl AttestedTlsClient {
})
}

/// Connect to an attested-tls-server, do TLS handshake and attestation exchange
pub async fn connect(
/// Given a connection to an attested TLS server, do a TLS handshake and attestation exchange, and return the TLS
/// stream together with measurement details
///
/// This is transport agnostic and will work with any asynchronous stream
pub async fn connect<IO>(
&self,
target: &str,
outbound: IO,
) -> Result<
(
tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
tokio_rustls::client::TlsStream<IO>,
Option<MultiMeasurements>,
AttestationType,
),
AttestedTlsError,
> {
// Make a TCP client connection and TLS handshake
let out = TcpStream::connect(&target).await?;
>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
// Make a TLS handshake with the given connection
let mut tls_stream = self
.connector
.connect(server_name_from_host(target)?, out)
.connect(server_name_from_host(target)?, outbound)
.await?;

let (_io, server_connection) = tls_stream.get_ref();
Expand Down Expand Up @@ -374,12 +356,29 @@ impl AttestedTlsClient {
Ok((tls_stream, measurements, remote_attestation_type))
}

/// Connect to an attested TLS server, retrieve the remote TLS certificate and return it
/// Make a TCP connection, do a TLS handshake and attestation exchange, and return the TLS
/// stream together with measurement details
pub async fn connect_tcp(
&self,
target: &str,
) -> Result<
(
tokio_rustls::client::TlsStream<tokio::net::TcpStream>,
Option<MultiMeasurements>,
AttestationType,
),
AttestedTlsError,
> {
let out = tokio::net::TcpStream::connect(&target).await?;
self.connect(target, out).await
}

/// Connect to an attested TLS server using TCP, retrieve the remote TLS certificate and return it
pub async fn get_tls_cert(
&self,
server_name: &str,
) -> Result<Vec<CertificateDer<'static>>, AttestedTlsError> {
let (mut tls_stream, _, _) = self.connect(server_name).await?;
let (mut tls_stream, _, _) = self.connect_tcp(server_name).await?;

let (_io, server_connection) = tls_stream.get_ref();

Expand Down
22 changes: 12 additions & 10 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,11 @@ use tracing::{error, warn};
#[cfg(test)]
mod test_helpers;

use std::net::SocketAddr;
use std::num::TryFromIntError;
use std::time::Duration;
use std::{net::SocketAddr, num::TryFromIntError, sync::Arc, time::Duration};
use tokio::io;
use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
use tokio_rustls::rustls::pki_types::CertificateDer;

#[cfg(test)]
use std::sync::Arc;
#[cfg(test)]
use tokio_rustls::rustls::{ClientConfig, ServerConfig};

Expand Down Expand Up @@ -57,6 +53,8 @@ type Http2Sender = hyper::client::conn::http2::SendRequest<hyper::body::Incoming
pub struct ProxyServer {
/// The underlying attested TLS server
attested_tls_server: AttestedTlsServer,
/// The underlying TCP listener
listener: Arc<TcpListener>,
/// The address of the target service we are proxying to
target: SocketAddr,
}
Expand All @@ -72,15 +70,17 @@ impl ProxyServer {
) -> Result<Self, ProxyError> {
let attested_tls_server = AttestedTlsServer::new(
cert_and_key,
local,
attestation_generator,
attestation_verifier,
client_auth,
)
.await?;

let listener = TcpListener::bind(local).await?;

Ok(Self {
attested_tls_server,
listener: listener.into(),
target,
})
}
Expand All @@ -100,22 +100,24 @@ impl ProxyServer {
let attested_tls_server = AttestedTlsServer::new_with_tls_config(
cert_chain,
server_config,
local,
attestation_generator,
attestation_verifier,
)
.await?;

let listener = TcpListener::bind(local).await?;

Ok(Self {
attested_tls_server,
listener: listener.into(),
target,
})
}

/// Accept an incoming connection and handle it in a seperate task
pub async fn accept(&self) -> Result<(), ProxyError> {
let target = self.target;
let (inbound, _client_addr) = self.attested_tls_server.listener.accept().await?;
let (inbound, _client_addr) = self.listener.accept().await?;
let attested_tls_server = self.attested_tls_server.clone();

tokio::spawn(async move {
Expand All @@ -139,7 +141,7 @@ impl ProxyServer {

/// Helper to get the socket address of the underlying TCP listener
pub fn local_addr(&self) -> std::io::Result<SocketAddr> {
self.attested_tls_server.local_addr()
self.listener.local_addr()
}

/// Handle an incoming connection from a proxy-client
Expand Down Expand Up @@ -469,7 +471,7 @@ impl ProxyClient {
inner: &AttestedTlsClient,
target: &str,
) -> Result<(Http2Sender, Option<MultiMeasurements>, AttestationType), ProxyError> {
let (tls_stream, measurements, remote_attestation_type) = inner.connect(target).await?;
let (tls_stream, measurements, remote_attestation_type) = inner.connect_tcp(target).await?;

// The attestation exchange is now complete - setup an HTTP client

Expand Down