Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ base64 = "0.22"
tower = "0.5"
futures-util = "0.3"
ammonia = "4.1.1"
chrono = "0.4"

[build-dependencies]
tonic-prost-build = "0.14"
Expand Down
2 changes: 1 addition & 1 deletion proto
8 changes: 4 additions & 4 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn default_url() -> Url {

#[derive(Parser, Debug, Deserialize)]
#[command(version)]
pub struct Config {
pub struct EnvConfig {
// port the API server will listen on
#[arg(
long,
Expand Down Expand Up @@ -77,15 +77,15 @@ pub enum ConfigError {
ParseError(#[from] toml::de::Error),
}

pub fn get_config() -> Result<Config, ConfigError> {
pub fn get_env_config() -> Result<EnvConfig, ConfigError> {
// parse CLI arguments to get config file path
let cli_config = Config::parse();
let cli_config = EnvConfig::parse();

// load config from file if one was specified
if let Some(config_path) = cli_config.config_path {
info!("Reading configuration from file: {config_path:?}");
let config_toml = read_to_string(config_path)?;
let file_config: Config = toml::from_str(&config_toml)?;
let file_config: EnvConfig = toml::from_str(&config_toml)?;
Ok(file_config)
} else {
Ok(cli_config)
Expand Down
27 changes: 4 additions & 23 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use tracing::Instrument;

use crate::{
error::ApiError,
http::GRPC_SERVER_RESTART_CHANNEL,
proto::{core_request, core_response, proxy_server, CoreRequest, CoreResponse, DeviceInfo},
MIN_CORE_VERSION, VERSION,
};
Expand All @@ -34,9 +33,9 @@ use crate::{
type ClientMap = HashMap<SocketAddr, mpsc::UnboundedSender<Result<CoreRequest, Status>>>;

#[derive(Debug, Clone, Default)]
pub(crate) struct Configuration {
pub(crate) grpc_key_pem: String,
pub(crate) grpc_cert_pem: String,
pub struct Configuration {
pub grpc_key_pem: String,
pub grpc_cert_pem: String,
}

pub(crate) struct ProxyServer {
Expand All @@ -47,7 +46,6 @@ pub(crate) struct ProxyServer {
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
config: Arc<Mutex<Option<Configuration>>>,
cookie_key: Arc<RwLock<Option<Key>>>,
setup_in_progress: Arc<AtomicBool>,
}

impl ProxyServer {
Expand All @@ -62,21 +60,9 @@ impl ProxyServer {
connected: Arc::new(AtomicBool::new(false)),
core_version: Arc::new(Mutex::new(None)),
config: Arc::new(Mutex::new(None)),
setup_in_progress: Arc::new(AtomicBool::new(false)),
}
}

pub(crate) fn set_tls_config(&self, cert_pem: String, key_pem: String) -> Result<(), ApiError> {
let mut lock = self
.config
.lock()
.expect("Failed to acquire lock on config mutex when updating TLS configuration");
let config = lock.get_or_insert_with(Configuration::default);
config.grpc_cert_pem = cert_pem;
config.grpc_key_pem = key_pem;
Ok(())
}

pub(crate) fn configure(&self, config: Configuration) {
let mut lock = self
.config
Expand Down Expand Up @@ -121,11 +107,7 @@ impl ProxyServer {

builder
.add_service(versioned_service)
.serve_with_shutdown(addr, async move {
let mut rx_lock = GRPC_SERVER_RESTART_CHANNEL.1.lock().await;
rx_lock.recv().await;
info!("Shutting down gRPC server for restart...");
})
.serve(addr)
.await
.map_err(|err| {
error!("gRPC server error: {err}");
Expand Down Expand Up @@ -194,7 +176,6 @@ impl Clone for ProxyServer {
core_version: Arc::clone(&self.core_version),
cookie_key: Arc::clone(&self.cookie_key),
config: Arc::clone(&self.config),
setup_in_progress: Arc::clone(&self.setup_in_progress),
}
}
}
Expand Down
135 changes: 58 additions & 77 deletions src/http.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::{
collections::HashMap,
fs::read_to_string,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::Path,
sync::{atomic::Ordering, Arc, LazyLock, RwLock},
sync::{atomic::Ordering, Arc, RwLock},
time::Duration,
};

Expand All @@ -21,11 +20,7 @@ use axum_extra::extract::cookie::Key;
use clap::crate_version;
use defguard_version::{server::DefguardVersionLayer, Version};
use serde::Serialize;
use tokio::{
net::TcpListener,
sync::{oneshot, Mutex},
task::JoinSet,
};
use tokio::{net::TcpListener, sync::oneshot, task::JoinSet};
use tower_governor::{
governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer,
};
Expand All @@ -35,13 +30,13 @@ use url::Url;

use crate::{
assets::{index, web_asset},
config::Config,
config::EnvConfig,
enterprise::handlers::openid_login::{self, FlowType},
error::ApiError,
grpc::{Configuration, ProxyServer},
handlers::{desktop_client_mfa, enrollment, password_reset, polling},
setup::ProxySetupServer,
CommsChannel, VERSION,
LogsReceiver, VERSION,
};

pub(crate) static ENROLLMENT_COOKIE_NAME: &str = "defguard_proxy";
Expand All @@ -51,13 +46,8 @@ const DEFGUARD_CORE_VERSION_HEADER: &str = "defguard-core-version";
const RATE_LIMITER_CLEANUP_PERIOD: Duration = Duration::from_secs(60);
const X_FORWARDED_FOR: &str = "x-forwarded-for";
const X_POWERED_BY: &str = "x-powered-by";
const GRPC_CERT_NAME: &str = "proxy_grpc_cert.pem";
const GRPC_KEY_NAME: &str = "proxy_grpc_key.pem";

pub static GRPC_SERVER_RESTART_CHANNEL: LazyLock<CommsChannel<()>> = LazyLock::new(|| {
let (tx, rx) = tokio::sync::mpsc::channel(100);
(Arc::new(Mutex::new(tx)), Arc::new(Mutex::new(rx)))
});
pub const GRPC_CERT_NAME: &str = "proxy_grpc_cert.pem";
pub const GRPC_KEY_NAME: &str = "proxy_grpc_key.pem";

#[derive(Clone)]
pub(crate) struct AppState {
Expand Down Expand Up @@ -177,6 +167,45 @@ async fn powered_by_header<B>(mut response: Response<B>) -> Response<B> {
response
}

pub async fn run_setup(
env_config: &EnvConfig,
logs_rx: LogsReceiver,
) -> anyhow::Result<Configuration> {
let setup_server = ProxySetupServer::new(logs_rx);
let cert_dir = Path::new(&env_config.cert_dir);
if !cert_dir.exists() {
tokio::fs::create_dir_all(cert_dir).await?;
}

// Only attempt setup if not already configured
info!(
"No gRPC TLS certificates found at {}, new certificates will be obtained during setup",
cert_dir.display()
);
let configuration = setup_server
.await_initial_setup(SocketAddr::new(
env_config
.grpc_bind_address
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
env_config.grpc_port,
))
.await?;
info!("Generated new gRPC TLS certificates and signed by Defguard Core");

let Configuration {
grpc_cert_pem,
grpc_key_pem,
..
} = &configuration;

let cert_path = cert_dir.join(GRPC_CERT_NAME);
let key_path = cert_dir.join(GRPC_KEY_NAME);
tokio::fs::write(&cert_path, grpc_cert_pem).await?;
tokio::fs::write(&key_path, grpc_key_pem).await?;

Ok(configuration)
}

/// Middleware that gates all HTTP endpoints except health checks until the proxy
/// is fully configured.
///
Expand Down Expand Up @@ -206,9 +235,9 @@ async fn ensure_configured(
next.run(request).await
}

pub async fn run_server(config: Config) -> anyhow::Result<()> {
pub async fn run_server(env_config: EnvConfig, config: Configuration) -> anyhow::Result<()> {
info!("Starting Defguard Proxy server");
debug!("Using config: {config:?}");
debug!("Using config: {env_config:?}");

let mut tasks = JoinSet::new();
let cookie_key = Default::default();
Expand All @@ -217,68 +246,20 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
let grpc_server = ProxyServer::new(Arc::clone(&cookie_key));

let server_clone = grpc_server.clone();

let setup_server = ProxySetupServer::new();
grpc_server.configure(config);

// Start gRPC server.
// TODO: Wait with spawning the HTTP server until gRPC server is ready.
debug!("Spawning gRPC server");
tasks.spawn(async move {
let cert_dir = Path::new(&config.cert_dir);
if !cert_dir.exists() {
debug!("Creating certs directory");
tokio::fs::create_dir_all(cert_dir).await?;
}

loop {
info!("Starting gRPC server...");
let server_to_run = server_clone.clone();

if let (Some(cert), Some(key)) = (
read_to_string(cert_dir.join(GRPC_CERT_NAME)).ok(),
read_to_string(cert_dir.join(GRPC_KEY_NAME)).ok(),
) {
info!(
"Using existing gRPC TLS certificates from {}",
cert_dir.display()
);
server_clone.set_tls_config(cert, key)?;
} else if !server_clone.setup_completed() {
// Only attempt setup if not already configured
info!(
"No gRPC TLS certificates found at {}, new certificates will be generated",
cert_dir.display()
);
let configuration = setup_server
.await_setup(SocketAddr::new(
config
.grpc_bind_address
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
config.grpc_port,
))
.await?;
info!("Generated new gRPC TLS certificates and signed by Defguard Core");

let Configuration {
grpc_cert_pem,
grpc_key_pem,
..
} = &configuration;

let cert_path = cert_dir.join(GRPC_CERT_NAME);
let key_path = cert_dir.join(GRPC_KEY_NAME);
tokio::fs::write(&cert_path, grpc_cert_pem).await?;
tokio::fs::write(&key_path, grpc_key_pem).await?;

server_to_run.configure(configuration);
} else {
info!("Proxy already configured, skipping setup phase");
}

let addr = SocketAddr::new(
config
env_config
.grpc_bind_address
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
config.grpc_port,
env_config.grpc_port,
);

if let Err(e) = server_to_run.run(addr).await {
Expand All @@ -293,18 +274,18 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
cookie_key,
grpc_server,
remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
url: config.url.clone(),
url: env_config.url.clone(),
};

// Setup tower_governor rate-limiter
debug!(
"Configuring rate limiter, per_second: {}, burst: {}",
config.rate_limit_per_second, config.rate_limit_burst
env_config.rate_limit_per_second, env_config.rate_limit_burst
);
let governor_conf = GovernorConfigBuilder::default()
.key_extractor(SmartIpKeyExtractor)
.per_second(config.rate_limit_per_second)
.burst_size(config.rate_limit_burst)
.per_second(env_config.rate_limit_per_second)
.burst_size(env_config.rate_limit_burst)
.finish();

let governor_conf = if let Some(conf) = governor_conf {
Expand All @@ -323,7 +304,7 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
});
info!(
"Configured rate limiter, per_second: {}, burst: {}",
config.rate_limit_per_second, config.rate_limit_burst
env_config.rate_limit_per_second, env_config.rate_limit_burst
);
Some(conf)
} else {
Expand Down Expand Up @@ -385,10 +366,10 @@ pub async fn run_server(config: Config) -> anyhow::Result<()> {
debug!("Spawning API web server");
tasks.spawn(async move {
let addr = SocketAddr::new(
config
env_config
.http_bind_address
.unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)),
config.http_port,
env_config.http_port,
);
let listener = TcpListener::bind(&addr).await?;
info!("API web server is listening on {addr}");
Expand Down
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ use std::sync::Arc;
use defguard_version::Version;
use tokio::sync::mpsc;

use crate::proto::LogEntry;

pub mod assets;
pub mod config;
mod enterprise;
mod error;
mod grpc;
pub mod grpc;
mod handlers;
pub mod http;
pub mod logging;
Expand All @@ -27,3 +29,5 @@ type CommsChannel<T> = (
Arc<tokio::sync::Mutex<mpsc::Sender<T>>>,
Arc<tokio::sync::Mutex<mpsc::Receiver<T>>>,
);

type LogsReceiver = Arc<tokio::sync::Mutex<mpsc::Receiver<LogEntry>>>;
Loading