Skip to content

Commit bd763e7

Browse files
committed
Adds global and per ip connection rate limiting
Adds connection rate limiting through governor handled by burstable quotas added to the pgwire server config. Limits are checked in handle_connection, and controlled by the following system params. - PGWIRE_CONNECTION_RATE_LIMIT - PGWIRE_CONNECTION_RATE_LIMIT_BURST - PGWIRE_CONNECTION_RATE_LIMIT_PER_IP - PGWIRE_CONNECTION_RATE_LIMIT_BURST_PER_IP Environmentd must be restarted to update params.
1 parent 2e07b38 commit bd763e7

File tree

8 files changed

+287
-4
lines changed

8 files changed

+287
-4
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/environmentd/src/lib.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
//! [timely dataflow]: ../timely/index.html
1515
1616
use std::collections::BTreeMap;
17+
use std::num::NonZeroU32;
1718
use std::panic::AssertUnwindSafe;
1819
use std::path::PathBuf;
1920
use std::pin::Pin;
@@ -52,7 +53,7 @@ use mz_ore::url::SensitiveUrl;
5253
use mz_ore::{instrument, task};
5354
use mz_persist_client::cache::PersistClientCache;
5455
use mz_persist_client::usage::StorageUsageClient;
55-
use mz_pgwire::MetricsConfig;
56+
use mz_pgwire::{MetricsConfig, RateLimitConfig};
5657
use mz_pgwire_common::ConnectionCounter;
5758
use mz_repr::strconv;
5859
use mz_secrets::SecretsController;
@@ -283,6 +284,34 @@ impl Listener<SqlListenerConfig> {
283284
AuthenticatorKind::None => Authenticator::None,
284285
};
285286

287+
// Get rate limit configuration from system vars.
288+
let system_vars = adapter_client.get_system_vars().await;
289+
let rate_limit = {
290+
let global_rate = system_vars.pgwire_connection_rate_limit();
291+
let per_ip_rate = system_vars.pgwire_connection_rate_limit_per_ip();
292+
293+
// Only enable rate limiting if at least one rate limit is configured (non-zero).
294+
if global_rate > 0 || per_ip_rate > 0 {
295+
let global_burst = system_vars.pgwire_connection_rate_limit_burst();
296+
let per_ip_burst = system_vars.pgwire_connection_rate_limit_per_ip_burst();
297+
298+
// If burst is 0, use the rate limit value as the burst size.
299+
let global_rate = NonZeroU32::new(global_rate).unwrap_or(NonZeroU32::MIN);
300+
let global_burst = NonZeroU32::new(global_burst).unwrap_or(global_rate);
301+
let per_ip_rate = NonZeroU32::new(per_ip_rate).unwrap_or(NonZeroU32::MIN);
302+
let per_ip_burst = NonZeroU32::new(per_ip_burst).unwrap_or(per_ip_rate);
303+
304+
Some(RateLimitConfig {
305+
global_rate_per_second: global_rate,
306+
global_burst_size: global_burst,
307+
per_ip_rate_per_second: per_ip_rate,
308+
per_ip_burst_size: per_ip_burst,
309+
})
310+
} else {
311+
None
312+
}
313+
};
314+
286315
task::spawn(|| format!("{}_sql_server", label), {
287316
let sql_server = mz_pgwire::Server::new(mz_pgwire::Config {
288317
label,
@@ -293,6 +322,7 @@ impl Listener<SqlListenerConfig> {
293322
active_connection_counter,
294323
helm_chart_version,
295324
allowed_roles: self.config.allowed_roles,
325+
rate_limit,
296326
});
297327
mz_server_core::serve(ServeConfig {
298328
conns: self.connection_stream,

src/environmentd/tests/pgwire.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -800,3 +800,104 @@ fn test_pgtest_mz_transactions() {
800800
fn test_pgtest_mz_vars() {
801801
pg_test_inner(Path::new("../../test/pgtest-mz/vars.pt"), true);
802802
}
803+
804+
#[mz_ore::test]
805+
fn test_connection_rate_limiting() {
806+
// Start a server with a very restrictive rate limit: 1 connection per second with burst of 2.
807+
let server = test_util::TestHarness::default()
808+
.with_system_parameter_default("pgwire_connection_rate_limit".to_string(), "1".to_string())
809+
.with_system_parameter_default(
810+
"pgwire_connection_rate_limit_burst".to_string(),
811+
"2".to_string(),
812+
)
813+
.start_blocking();
814+
815+
// First two connections should succeed (burst allows 2).
816+
let client1 = server.connect(postgres::NoTls);
817+
assert!(client1.is_ok(), "first connection should succeed");
818+
819+
let client2 = server.connect(postgres::NoTls);
820+
assert!(
821+
client2.is_ok(),
822+
"second connection should succeed (within burst)"
823+
);
824+
825+
// Third connection should fail due to rate limiting (burst exhausted).
826+
// Retry to handle transient network errors.
827+
let err = Retry::default()
828+
.retry(|_| {
829+
server
830+
.connect(postgres::NoTls)
831+
.err()
832+
.unwrap()
833+
.as_db_error()
834+
.cloned()
835+
.ok_or("expected database error")
836+
})
837+
.unwrap();
838+
839+
assert_eq!(err.severity(), "FATAL");
840+
assert_eq!(*err.code(), SqlState::TOO_MANY_CONNECTIONS);
841+
assert_eq!(err.message(), "too many connections");
842+
843+
// Wait for rate limiter to replenish (1 second + buffer).
844+
std::thread::sleep(Duration::from_millis(1100));
845+
846+
// Now a new connection should succeed again.
847+
let client4 = server.connect(postgres::NoTls);
848+
assert!(
849+
client4.is_ok(),
850+
"connection after rate limit replenish should succeed"
851+
);
852+
}
853+
854+
#[mz_ore::test]
855+
fn test_connection_rate_limiting_per_ip() {
856+
// Start a server with per-IP rate limiting: 1 connection per second per IP with burst of 1.
857+
let server = test_util::TestHarness::default()
858+
.with_system_parameter_default(
859+
"pgwire_connection_rate_limit_per_ip".to_string(),
860+
"1".to_string(),
861+
)
862+
.with_system_parameter_default(
863+
"pgwire_connection_rate_limit_per_ip_burst".to_string(),
864+
"1".to_string(),
865+
)
866+
.start_blocking();
867+
868+
// First connection should succeed.
869+
let client1 = server.connect(postgres::NoTls);
870+
assert!(client1.is_ok(), "first connection should succeed");
871+
872+
// Second connection from same IP should fail due to per-IP rate limiting.
873+
// Retry to handle transient network errors.
874+
let err = Retry::default()
875+
.retry(|_| {
876+
server
877+
.connect(postgres::NoTls)
878+
.err()
879+
.unwrap()
880+
.as_db_error()
881+
.cloned()
882+
.ok_or("expected database error")
883+
})
884+
.unwrap();
885+
886+
assert_eq!(err.severity(), "FATAL");
887+
assert_eq!(*err.code(), SqlState::TOO_MANY_CONNECTIONS);
888+
assert!(
889+
err.message().starts_with("too many connections from"),
890+
"unexpected error message: {}",
891+
err.message()
892+
);
893+
894+
// Wait for rate limiter to replenish.
895+
std::thread::sleep(Duration::from_millis(1100));
896+
897+
// Now a new connection should succeed again.
898+
let client3 = server.connect(postgres::NoTls);
899+
assert!(
900+
client3.is_ok(),
901+
"connection after rate limit replenish should succeed"
902+
);
903+
}

src/pgwire/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ bytes = "1.10.1"
1818
bytesize = "2.1.0"
1919
enum-kinds = "0.5.1"
2020
futures = "0.3.31"
21+
governor = "0.10.1"
2122
itertools = "0.14.0"
2223
mz-adapter = { path = "../adapter" }
2324
mz-adapter-types = { path = "../adapter-types" }

src/pgwire/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ mod server;
3030

3131
pub use metrics::MetricsConfig;
3232
pub use protocol::match_handshake;
33-
pub use server::{Config, Server};
33+
pub use server::{Config, RateLimitConfig, Server};

src/pgwire/src/server.rs

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,28 @@
99

1010
use std::future::Future;
1111
use std::net::IpAddr;
12+
use std::num::NonZeroU32;
1213
use std::pin::Pin;
1314
use std::str::FromStr;
15+
use std::sync::Arc;
1416

1517
use anyhow::Context;
1618
use async_trait::async_trait;
19+
use governor::clock::DefaultClock;
20+
use governor::middleware::NoOpMiddleware;
21+
use governor::state::keyed::DashMapStateStore;
22+
use governor::state::{InMemoryState, NotKeyed};
23+
use governor::{Quota, RateLimiter};
1724
use mz_authenticator::Authenticator;
1825
use mz_ore::now::{SYSTEM_TIME, epoch_to_uuid_v7};
1926
use mz_pgwire_common::{
20-
ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, FrontendStartupMessage,
21-
MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
27+
ACCEPT_SSL_ENCRYPTION, CONN_UUID_KEY, Conn, ConnectionCounter, ErrorResponse,
28+
FrontendStartupMessage, MZ_FORWARDED_FOR_KEY, REJECT_ENCRYPTION, decode_startup,
2229
};
2330
use mz_server_core::listeners::AllowedRoles;
2431
use mz_server_core::{Connection, ConnectionHandler, ReloadingTlsConfig};
2532
use openssl::ssl::Ssl;
33+
use postgres::error::SqlState;
2634
use tokio::io::AsyncWriteExt;
2735
use tokio_metrics::TaskMetrics;
2836
use tokio_openssl::SslStream;
@@ -32,6 +40,26 @@ use crate::codec::FramedConn;
3240
use crate::metrics::{Metrics, MetricsConfig};
3341
use crate::protocol;
3442

43+
/// Type alias for the global rate limiter.
44+
type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock, NoOpMiddleware>;
45+
46+
/// Type alias for the per-IP rate limiter.
47+
type PerIpRateLimiter =
48+
RateLimiter<IpAddr, DashMapStateStore<IpAddr>, DefaultClock, NoOpMiddleware>;
49+
50+
/// Configuration for connection rate limiting.
51+
#[derive(Debug, Clone)]
52+
pub struct RateLimitConfig {
53+
/// Maximum connections per second globally.
54+
pub global_rate_per_second: NonZeroU32,
55+
/// Maximum burst size for global rate limiting.
56+
pub global_burst_size: NonZeroU32,
57+
/// Maximum connections per second per IP address.
58+
pub per_ip_rate_per_second: NonZeroU32,
59+
/// Maximum burst size for per-IP rate limiting.
60+
pub per_ip_burst_size: NonZeroU32,
61+
}
62+
3563
/// Configures a [`Server`].
3664
#[derive(Debug)]
3765
pub struct Config {
@@ -54,6 +82,8 @@ pub struct Config {
5482
pub helm_chart_version: Option<String>,
5583
/// Whether to allow reserved users (ie: mz_system).
5684
pub allowed_roles: AllowedRoles,
85+
/// Optional rate limiting configuration for new connections.
86+
pub rate_limit: Option<RateLimitConfig>,
5787
}
5888

5989
/// A server that communicates with clients via the pgwire protocol.
@@ -65,6 +95,8 @@ pub struct Server {
6595
active_connection_counter: ConnectionCounter,
6696
helm_chart_version: Option<String>,
6797
allowed_roles: AllowedRoles,
98+
global_rate_limiter: Option<Arc<GlobalRateLimiter>>,
99+
per_ip_rate_limiter: Option<Arc<PerIpRateLimiter>>,
68100
}
69101

70102
#[async_trait]
@@ -90,6 +122,21 @@ impl mz_server_core::Server for Server {
90122
impl Server {
91123
/// Constructs a new server.
92124
pub fn new(config: Config) -> Server {
125+
let (global_rate_limiter, per_ip_rate_limiter) = match config.rate_limit {
126+
Some(rate_limit) => {
127+
let global_quota = Quota::per_second(rate_limit.global_rate_per_second)
128+
.allow_burst(rate_limit.global_burst_size);
129+
let global_limiter = Arc::new(RateLimiter::direct(global_quota));
130+
131+
let per_ip_quota = Quota::per_second(rate_limit.per_ip_rate_per_second)
132+
.allow_burst(rate_limit.per_ip_burst_size);
133+
let per_ip_limiter = Arc::new(RateLimiter::keyed(per_ip_quota));
134+
135+
(Some(global_limiter), Some(per_ip_limiter))
136+
}
137+
None => (None, None),
138+
};
139+
93140
Server {
94141
tls: config.tls,
95142
adapter_client: config.adapter_client,
@@ -98,6 +145,8 @@ impl Server {
98145
active_connection_counter: config.active_connection_counter,
99146
helm_chart_version: config.helm_chart_version,
100147
allowed_roles: config.allowed_roles,
148+
global_rate_limiter,
149+
per_ip_rate_limiter,
101150
}
102151
}
103152

@@ -114,6 +163,8 @@ impl Server {
114163
let active_connection_counter = self.active_connection_counter.clone();
115164
let helm_chart_version = self.helm_chart_version.clone();
116165
let allowed_roles = self.allowed_roles;
166+
let global_rate_limiter = self.global_rate_limiter.clone();
167+
let per_ip_rate_limiter = self.per_ip_rate_limiter.clone();
117168

118169
// TODO(guswynn): remove this redundant_closure_call
119170
#[allow(clippy::redundant_closure_call)]
@@ -171,6 +222,49 @@ impl Server {
171222
}
172223
None => Some(direct_peer_addr)
173224
};
225+
226+
// Check global rate limit. This protects against connection floods.
227+
// We check after SSL negotiation so clients receive a proper error.
228+
if let Some(ref limiter) = global_rate_limiter {
229+
if limiter.check().is_err() {
230+
debug!("global connection rate limit exceeded");
231+
let mut conn = FramedConn::new(
232+
conn_id.clone(),
233+
peer_addr,
234+
conn,
235+
);
236+
conn.send(ErrorResponse::fatal(
237+
SqlState::TOO_MANY_CONNECTIONS,
238+
"too many connections",
239+
))
240+
.await?;
241+
conn.flush().await?;
242+
return Ok(());
243+
}
244+
}
245+
246+
// Check per-IP rate limit using the forwarded IP address.
247+
// This protects against connection floods from specific clients.
248+
if let Some(ref limiter) = per_ip_rate_limiter {
249+
// Use the forwarded IP if available, otherwise use the direct peer.
250+
let rate_limit_ip = peer_addr.unwrap_or(direct_peer_addr);
251+
if limiter.check_key(&rate_limit_ip).is_err() {
252+
debug!(%rate_limit_ip, "per-IP connection rate limit exceeded");
253+
let mut conn = FramedConn::new(
254+
conn_id.clone(),
255+
peer_addr,
256+
conn,
257+
);
258+
conn.send(ErrorResponse::fatal(
259+
SqlState::TOO_MANY_CONNECTIONS,
260+
format!("too many connections from {}", rate_limit_ip),
261+
))
262+
.await?;
263+
conn.flush().await?;
264+
return Ok(());
265+
}
266+
}
267+
174268
let mut conn = FramedConn::new(
175269
conn_id.clone(),
176270
peer_addr,

src/sql/src/session/vars.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,10 @@ impl SystemVars {
11481148
&KAFKA_PROGRESS_RECORD_FETCH_TIMEOUT,
11491149
&ENABLE_LAUNCHDARKLY,
11501150
&MAX_CONNECTIONS,
1151+
&PGWIRE_CONNECTION_RATE_LIMIT,
1152+
&PGWIRE_CONNECTION_RATE_LIMIT_BURST,
1153+
&PGWIRE_CONNECTION_RATE_LIMIT_PER_IP,
1154+
&PGWIRE_CONNECTION_RATE_LIMIT_PER_IP_BURST,
11511155
&NETWORK_POLICY,
11521156
&SUPERUSER_RESERVED_CONNECTIONS,
11531157
&KEEP_N_SOURCE_STATUS_HISTORY_ENTRIES,
@@ -1951,6 +1955,30 @@ impl SystemVars {
19511955
*self.expect_value(&SUPERUSER_RESERVED_CONNECTIONS)
19521956
}
19531957

1958+
/// Returns the `pgwire_connection_rate_limit` configuration parameter.
1959+
/// A value of 0 means rate limiting is disabled.
1960+
pub fn pgwire_connection_rate_limit(&self) -> u32 {
1961+
*self.expect_value(&PGWIRE_CONNECTION_RATE_LIMIT)
1962+
}
1963+
1964+
/// Returns the `pgwire_connection_rate_limit_burst` configuration parameter.
1965+
/// A value of 0 means use the rate limit value as the burst size.
1966+
pub fn pgwire_connection_rate_limit_burst(&self) -> u32 {
1967+
*self.expect_value(&PGWIRE_CONNECTION_RATE_LIMIT_BURST)
1968+
}
1969+
1970+
/// Returns the `pgwire_connection_rate_limit_per_ip` configuration parameter.
1971+
/// A value of 0 means rate limiting is disabled.
1972+
pub fn pgwire_connection_rate_limit_per_ip(&self) -> u32 {
1973+
*self.expect_value(&PGWIRE_CONNECTION_RATE_LIMIT_PER_IP)
1974+
}
1975+
1976+
/// Returns the `pgwire_connection_rate_limit_per_ip_burst` configuration parameter.
1977+
/// A value of 0 means use the rate limit value as the burst size.
1978+
pub fn pgwire_connection_rate_limit_per_ip_burst(&self) -> u32 {
1979+
*self.expect_value(&PGWIRE_CONNECTION_RATE_LIMIT_PER_IP_BURST)
1980+
}
1981+
19541982
pub fn keep_n_source_status_history_entries(&self) -> usize {
19551983
*self.expect_value(&KEEP_N_SOURCE_STATUS_HISTORY_ENTRIES)
19561984
}

0 commit comments

Comments
 (0)