From 1203369317c76a98d2b048939432fd019cf38a5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 28 Jan 2026 09:23:29 +0100 Subject: [PATCH 01/19] create new VPN client session on MFA success --- .../defguard_core/src/grpc/proxy/client_mfa.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 667314eaf..274f8b0a2 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -12,6 +12,7 @@ use defguard_common::{ models::{ BiometricAuth, BiometricChallenge, Device, DeviceNetworkInfo, User, WireguardNetwork, device::{DeviceInfo, WireguardNetworkDevice}, + vpn_client_session::VpnClientSession, wireguard::LocationMfaMode, }, }, @@ -711,6 +712,22 @@ impl ClientMfaServer { )), })?; + // create new VPN client session + let vpn_client_session = VpnClientSession::new( + location.id, + user.id, + device.id, + None, + location.location_mfa_mode.clone(), + ) + .save(&mut *transaction) + .await + .map_err(|err| { + error!("Failed to create new VPN client session for device {device} in location {location}: {err}"); + Status::internal("unexpected error") + })?; + debug!("Created new VPN client session: {vpn_client_session:?}"); + let response = ClientMfaFinishResponse { preshared_key: key.public, token: match method { From 50f309bf32937611c0a5801586f1f64ea14da401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 28 Jan 2026 12:35:08 +0100 Subject: [PATCH 02/19] handle optional enums in model derive --- crates/model_derive/src/lib.rs | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/crates/model_derive/src/lib.rs b/crates/model_derive/src/lib.rs index 21f1e20cb..edb91db3d 100644 --- a/crates/model_derive/src/lib.rs +++ b/crates/model_derive/src/lib.rs @@ -1,8 +1,8 @@ use proc_macro::TokenStream; use quote::quote; use syn::{ - Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident, Path, Type, TypePath, - meta::parser, parse::Parser, parse_macro_input, + Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, GenericArgument, Ident, Path, + PathArguments, Type, TypePath, meta::parser, parse::Parser, parse_macro_input, }; /// Try to find the value of `model` attribute, e.g. `#[model(model_type)]`. @@ -42,13 +42,34 @@ fn field_type(ty: &Type) -> Option<&Ident> { .. }) = ty { - if let Some(segment) = segments.first() { + if let Some(segment) = segments.last() { return Some(&segment.ident); } } None } +fn option_field_type(ty: &Type) -> Option<&Ident> { + if let Type::Path(TypePath { + path: Path { segments, .. }, + .. + }) = ty + { + if let Some(segment) = segments.last() { + if segment.ident == "Option" { + // Extract the generic arguments + if let PathArguments::AngleBracketed(args) = &segment.arguments { + // Get the first generic argument (the T in Option) + if let Some(GenericArgument::Type(inner_ty)) = args.args.first() { + return field_type(inner_ty); + } + } + } + } + } + None +} + #[proc_macro_derive(Model, attributes(table, model))] pub fn derive(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); @@ -125,6 +146,8 @@ pub fn derive(input: TokenStream) -> TokenStream { cs_aliased_fields.push_str("?: SecretString\""); } else if field_type == "ip" { cs_aliased_fields.push_str(": IpAddr\""); + } else if field_type == "option" { + cs_aliased_fields.push_str("?: _\""); } else { cs_aliased_fields.push_str(": _\""); } @@ -152,6 +175,10 @@ pub fn derive(input: TokenStream) -> TokenStream { if let Some(field_type) = field_type(&field.ty) { return Some(quote! { &self.#name as &#field_type }); } + } else if tokens == "option" { + if let Some(field_type) = option_field_type(&field.ty) { + return Some(quote! { &self.#name as &Option<#field_type> }); + } } else if tokens == "secret" { // FIXME: hard-coded struct name return Some(quote! { &self.#name as &Option }); From d67a2e0093b818dce59cdd58cb6cfcf79bd8a1a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 28 Jan 2026 12:35:34 +0100 Subject: [PATCH 03/19] add mfa method to session --- .../src/db/models/vpn_client_session.rs | 27 ++++++++++++------- .../src/db/models/wireguard.rs | 4 +-- .../src/grpc/proxy/client_mfa.rs | 2 +- crates/defguard_proto/src/lib.rs | 13 +++++++++ .../src/session_state.rs | 2 +- ...43_[2.0.0]_add_session_mfa_method.down.sql | 9 +++++++ ...2943_[2.0.0]_add_session_mfa_method.up.sql | 15 +++++++++++ .../src/vpn_session_stats.rs | 2 +- 8 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 migrations/20260128082943_[2.0.0]_add_session_mfa_method.down.sql create mode 100644 migrations/20260128082943_[2.0.0]_add_session_mfa_method.up.sql diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index 8050f4ecc..b080f21a6 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -4,7 +4,7 @@ use sqlx::{Error as SqlxError, Type, query_as}; use crate::db::{ Id, NoId, - models::{WireguardNetwork, vpn_session_stats::VpnSessionStats, wireguard::LocationMfaMode}, + models::{WireguardNetwork, vpn_session_stats::VpnSessionStats}, }; #[derive(Debug, Default, Type)] @@ -16,6 +16,16 @@ pub enum VpnClientSessionState { Disconnected, } +#[derive(Debug, Type)] +#[sqlx(type_name = "vpn_client_mfa_method", rename_all = "lowercase")] +pub enum VpnClientMfaMethod { + Totp, + Email, + Oidc, + Biometric, + MobileApprove, +} + /// Represents a single VPN client session from creation to eventual disconnection #[derive(Debug, Model)] #[table(vpn_client_session)] @@ -27,9 +37,8 @@ pub struct VpnClientSession { pub created_at: NaiveDateTime, pub connected_at: Option, pub disconnected_at: Option, - // TODO: use actual MFA method used to connect - #[model(enum)] - pub mfa_mode: LocationMfaMode, + #[model(option)] + pub mfa_method: Option, #[model(enum)] pub state: VpnClientSessionState, } @@ -40,7 +49,7 @@ impl VpnClientSession { user_id: Id, device_id: Id, connected_at: Option, - mfa_mode: LocationMfaMode, + mfa_method: Option, ) -> Self { // determine session state let state = if connected_at.is_some() { @@ -57,7 +66,7 @@ impl VpnClientSession { created_at: Utc::now().naive_utc(), connected_at, disconnected_at: None, - mfa_mode, + mfa_method, state, } } @@ -75,7 +84,7 @@ impl VpnClientSession { query_as!( Self, "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ - mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" \ + mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" \ FROM vpn_client_session \ WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')", location_id, @@ -111,7 +120,7 @@ impl VpnClientSession { query_as!( Self, "SELECT s.id, location_id, user_id, device_id, created_at, s.connected_at, disconnected_at, \ - mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" \ + mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" \ FROM vpn_client_session s \ LEFT JOIN LATERAL ( \ SELECT latest_handshake \ @@ -135,7 +144,7 @@ impl VpnClientSession { query_as!( Self, "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, \ - mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" \ + mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" \ FROM vpn_client_session \ WHERE location_id = $1 AND state = 'new' \ AND (NOW() - created_at) > $2 * interval '1 second'", diff --git a/crates/defguard_common/src/db/models/wireguard.rs b/crates/defguard_common/src/db/models/wireguard.rs index 6dddd4b16..672732dd5 100644 --- a/crates/defguard_common/src/db/models/wireguard.rs +++ b/crates/defguard_common/src/db/models/wireguard.rs @@ -31,7 +31,7 @@ use crate::{ db::{ Id, NoId, models::{ - vpn_client_session::{VpnClientSession, VpnClientSessionState}, + vpn_client_session::{VpnClientMfaMethod, VpnClientSession, VpnClientSessionState}, vpn_session_stats::VpnSessionStats, }, }, @@ -1019,7 +1019,7 @@ impl WireguardNetwork { query_as!( VpnClientSession, "SELECT id, location_id, user_id, device_id, \ - created_at, connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", \ + created_at, connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", \ state \"state: VpnClientSessionState\" \ FROM vpn_client_session \ WHERE location_id = $1 AND state = 'connected'::vpn_client_session_state", diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 274f8b0a2..8cbef7104 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -718,7 +718,7 @@ impl ClientMfaServer { user.id, device.id, None, - location.location_mfa_mode.clone(), + Some(method.into()), ) .save(&mut *transaction) .await diff --git a/crates/defguard_proto/src/lib.rs b/crates/defguard_proto/src/lib.rs index df37922bb..471adb7f4 100644 --- a/crates/defguard_proto/src/lib.rs +++ b/crates/defguard_proto/src/lib.rs @@ -24,6 +24,7 @@ use defguard_common::{ Id, models::{ Device, DeviceConfig, User, + vpn_client_session::VpnClientMfaMethod, wireguard::{LocationMfaMode, ServiceLocationMode}, }, }, @@ -66,6 +67,18 @@ impl Serialize for MfaMethod { } } +impl Into for MfaMethod { + fn into(self) -> VpnClientMfaMethod { + match self { + Self::Totp => VpnClientMfaMethod::Totp, + Self::Email => VpnClientMfaMethod::Email, + Self::Oidc => VpnClientMfaMethod::Oidc, + Self::Biometric => VpnClientMfaMethod::Biometric, + Self::MobileApprove => VpnClientMfaMethod::MobileApprove, + } + } +} + impl From for CoreError { fn from(status: Status) -> Self { Self { diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 422a830d7..a8b7e70cd 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -284,7 +284,7 @@ impl ActiveSessionsMap { user.id, device_id, Some(stats_update.latest_handshake), - location.location_mfa_mode.clone(), + None, ) .save(transaction) .await?; diff --git a/migrations/20260128082943_[2.0.0]_add_session_mfa_method.down.sql b/migrations/20260128082943_[2.0.0]_add_session_mfa_method.down.sql new file mode 100644 index 000000000..7bb999b47 --- /dev/null +++ b/migrations/20260128082943_[2.0.0]_add_session_mfa_method.down.sql @@ -0,0 +1,9 @@ +-- Restore MFA mode column +-- This will not restore a correct MFA mode, but it souldn't be an issue outside of development environments +ALTER TABLE vpn_client_session ADD COLUMN mfa_mode location_mfa_mode NOT NULL DEFAULT 'disabled'; + +-- Drop MFA method column +ALTER TABLE vpn_client_session DROP COLUMN mfa_method; + +-- Drop MFA method enum +DROP TYPE vpn_client_mfa_method; diff --git a/migrations/20260128082943_[2.0.0]_add_session_mfa_method.up.sql b/migrations/20260128082943_[2.0.0]_add_session_mfa_method.up.sql new file mode 100644 index 000000000..5675edc2a --- /dev/null +++ b/migrations/20260128082943_[2.0.0]_add_session_mfa_method.up.sql @@ -0,0 +1,15 @@ +-- Add enum for MFA methods +CREATE TYPE vpn_client_mfa_method AS ENUM ( + 'totp', + 'email', + 'oidc', + 'biometric', + 'mobileapprove' +); + +-- Add MFA method column to VPN client session +ALTER TABLE vpn_client_session ADD COLUMN mfa_method vpn_client_mfa_method NULL; + +-- Remove unnecessary MFA type from VPN client session +ALTER TABLE vpn_client_session DROP COLUMN mfa_mode; + diff --git a/tools/defguard_generator/src/vpn_session_stats.rs b/tools/defguard_generator/src/vpn_session_stats.rs index 7e62eb98c..907b95ac3 100644 --- a/tools/defguard_generator/src/vpn_session_stats.rs +++ b/tools/defguard_generator/src/vpn_session_stats.rs @@ -88,7 +88,7 @@ pub async fn generate_vpn_session_stats( device.user_id, device.id, Some(session_start), - LocationMfaMode::Disabled, + None, ) .save(&mut *transaction) .await?; From f0bbda9b7f9fd9cc86dd0450be04045eeb6c037e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 28 Jan 2026 12:35:41 +0100 Subject: [PATCH 04/19] update query data --- ...715ad1262d5eb8cdf76cf4234b6a4971e6769.json} | 18 ++++++++++-------- ...143fef55fb878076487fc97738f120573820f.json} | 18 ++++++++++-------- ...a57f393d5eff60c4bc6703f926bef2583a781.json} | 18 ++++++++++-------- ...00c42e268149b730b7e4f2add5cbdbe7843c8.json} | 18 ++++++++++-------- ...e829ad4a5825707e8b2f55b86801039cd784e.json} | 14 ++++++++------ ...6e315923c12c54b9446c7b18e4b4fb1791f42.json} | 18 ++++++++++-------- ...0280867b8f432fc4bd52ba7194a5895540b4e.json} | 18 ++++++++++-------- ...df997807b66e0b532da747b146513c34e15c5c.json | 8 ++++---- ...47bda3869b74fcb598ec6f2cf8352424f87d6.json} | 14 ++++++++------ ...a36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json | 8 ++++---- 10 files changed, 84 insertions(+), 68 deletions(-) rename .sqlx/{query-e9bfbd2e39ddc1cc0f95258cbd711fa4ea8d63de6e7bd7a0f5cf41119cb3bf86.json => query-2700bf01e6a2afbe3ac2a4b686f715ad1262d5eb8cdf76cf4234b6a4971e6769.json} (73%) rename .sqlx/{query-1815955c24b6178c653bd7a0e4a18dde89c59df69ede331ea25d3dbac93bc5b8.json => query-2bc56fca85ec693d2d53fe08e53143fef55fb878076487fc97738f120573820f.json} (74%) rename .sqlx/{query-b2894d8c60b044744f946ee6b9ae26c24f3778a932d744b1378abbc0e04fb8a5.json => query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json} (68%) rename .sqlx/{query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json => query-45e1c056d0868dd8072abd5cdab00c42e268149b730b7e4f2add5cbdbe7843c8.json} (74%) rename .sqlx/{query-c9808451bd3653635dfb455c6125a7d392cee6a69bd7cbf6479cc2cf5294319c.json => query-51dffd9c1ae018de2ade5fde93ae829ad4a5825707e8b2f55b86801039cd784e.json} (69%) rename .sqlx/{query-dbb3290d7ec75771a626416e3cdb8efd92525a9773fe79c17f3a19081fd932f9.json => query-73965013a68538139aadcf0c2346e315923c12c54b9446c7b18e4b4fb1791f42.json} (77%) rename .sqlx/{query-126b613d8b07d65836a20429bef3b0917f7345c9d3a054b73e1176758a4ba1a9.json => query-86b6f0f850b4b8436379a9da38c0280867b8f432fc4bd52ba7194a5895540b4e.json} (77%) rename .sqlx/{query-2ec4ae04a8cf90d7a062ce0b2c318bff70bed69bec930e7331c576073f612677.json => query-dd0190b514cf61ba7c444d14f9c47bda3869b74fcb598ec6f2cf8352424f87d6.json} (71%) diff --git a/.sqlx/query-e9bfbd2e39ddc1cc0f95258cbd711fa4ea8d63de6e7bd7a0f5cf41119cb3bf86.json b/.sqlx/query-2700bf01e6a2afbe3ac2a4b686f715ad1262d5eb8cdf76cf4234b6a4971e6769.json similarity index 73% rename from .sqlx/query-e9bfbd2e39ddc1cc0f95258cbd711fa4ea8d63de6e7bd7a0f5cf41119cb3bf86.json rename to .sqlx/query-2700bf01e6a2afbe3ac2a4b686f715ad1262d5eb8cdf76cf4234b6a4971e6769.json index f5d922538..efcd4c510 100644 --- a/.sqlx/query-e9bfbd2e39ddc1cc0f95258cbd711fa4ea8d63de6e7bd7a0f5cf41119cb3bf86.json +++ b/.sqlx/query-2700bf01e6a2afbe3ac2a4b686f715ad1262d5eb8cdf76cf4234b6a4971e6769.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND state = 'new' AND (NOW() - created_at) > $2 * interval '1 second'", + "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND state = 'new' AND (NOW() - created_at) > $2 * interval '1 second'", "describe": { "columns": [ { @@ -40,15 +40,17 @@ }, { "ordinal": 7, - "name": "mfa_mode: LocationMfaMode", + "name": "mfa_method: VpnClientMfaMethod", "type_info": { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -85,9 +87,9 @@ false, true, true, - false, + true, false ] }, - "hash": "e9bfbd2e39ddc1cc0f95258cbd711fa4ea8d63de6e7bd7a0f5cf41119cb3bf86" + "hash": "2700bf01e6a2afbe3ac2a4b686f715ad1262d5eb8cdf76cf4234b6a4971e6769" } diff --git a/.sqlx/query-1815955c24b6178c653bd7a0e4a18dde89c59df69ede331ea25d3dbac93bc5b8.json b/.sqlx/query-2bc56fca85ec693d2d53fe08e53143fef55fb878076487fc97738f120573820f.json similarity index 74% rename from .sqlx/query-1815955c24b6178c653bd7a0e4a18dde89c59df69ede331ea25d3dbac93bc5b8.json rename to .sqlx/query-2bc56fca85ec693d2d53fe08e53143fef55fb878076487fc97738f120573820f.json index be4214451..e12c06233 100644 --- a/.sqlx/query-1815955c24b6178c653bd7a0e4a18dde89c59df69ede331ea25d3dbac93bc5b8.json +++ b/.sqlx/query-2bc56fca85ec693d2d53fe08e53143fef55fb878076487fc97738f120573820f.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND state = 'connected'::vpn_client_session_state", + "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND state = 'connected'::vpn_client_session_state", "describe": { "columns": [ { @@ -40,15 +40,17 @@ }, { "ordinal": 7, - "name": "mfa_mode: LocationMfaMode", + "name": "mfa_method: VpnClientMfaMethod", "type_info": { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -84,9 +86,9 @@ false, true, true, - false, + true, false ] }, - "hash": "1815955c24b6178c653bd7a0e4a18dde89c59df69ede331ea25d3dbac93bc5b8" + "hash": "2bc56fca85ec693d2d53fe08e53143fef55fb878076487fc97738f120573820f" } diff --git a/.sqlx/query-b2894d8c60b044744f946ee6b9ae26c24f3778a932d744b1378abbc0e04fb8a5.json b/.sqlx/query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json similarity index 68% rename from .sqlx/query-b2894d8c60b044744f946ee6b9ae26c24f3778a932d744b1378abbc0e04fb8a5.json rename to .sqlx/query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json index 925029758..c39810651 100644 --- a/.sqlx/query-b2894d8c60b044744f946ee6b9ae26c24f3778a932d744b1378abbc0e04fb8a5.json +++ b/.sqlx/query-3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT s.id, location_id, user_id, device_id, created_at, s.connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" FROM vpn_client_session s LEFT JOIN LATERAL ( SELECT latest_handshake FROM vpn_session_stats WHERE session_id = s.id ORDER BY collected_at DESC LIMIT 1 ) ss ON true WHERE location_id = $1 AND state = 'connected' AND (NOW() - ss.latest_handshake) > $2 * interval '1 second'", + "query": "SELECT s.id, location_id, user_id, device_id, created_at, s.connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" FROM vpn_client_session s LEFT JOIN LATERAL ( SELECT latest_handshake FROM vpn_session_stats WHERE session_id = s.id ORDER BY collected_at DESC LIMIT 1 ) ss ON true WHERE location_id = $1 AND state = 'connected' AND (NOW() - ss.latest_handshake) > $2 * interval '1 second'", "describe": { "columns": [ { @@ -40,15 +40,17 @@ }, { "ordinal": 7, - "name": "mfa_mode: LocationMfaMode", + "name": "mfa_method: VpnClientMfaMethod", "type_info": { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -85,9 +87,9 @@ false, true, true, - false, + true, false ] }, - "hash": "b2894d8c60b044744f946ee6b9ae26c24f3778a932d744b1378abbc0e04fb8a5" + "hash": "3b063ceba4e1b38bc3bb921468ba57f393d5eff60c4bc6703f926bef2583a781" } diff --git a/.sqlx/query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json b/.sqlx/query-45e1c056d0868dd8072abd5cdab00c42e268149b730b7e4f2add5cbdbe7843c8.json similarity index 74% rename from .sqlx/query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json rename to .sqlx/query-45e1c056d0868dd8072abd5cdab00c42e268149b730b7e4f2add5cbdbe7843c8.json index 043b02067..5ead8ddc6 100644 --- a/.sqlx/query-98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37.json +++ b/.sqlx/query-45e1c056d0868dd8072abd5cdab00c42e268149b730b7e4f2add5cbdbe7843c8.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_mode \"mfa_mode: LocationMfaMode\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')", + "query": "SELECT id, location_id, user_id, device_id, created_at, connected_at, disconnected_at, mfa_method \"mfa_method: VpnClientMfaMethod\", state \"state: VpnClientSessionState\" FROM vpn_client_session WHERE location_id = $1 AND device_id = $2 AND state IN ('new', 'connected')", "describe": { "columns": [ { @@ -40,15 +40,17 @@ }, { "ordinal": 7, - "name": "mfa_mode: LocationMfaMode", + "name": "mfa_method: VpnClientMfaMethod", "type_info": { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -85,9 +87,9 @@ false, true, true, - false, + true, false ] }, - "hash": "98739c1b3049739f95c056ef871f0ed200f2b4c10707685ece61e1dbe85c5c37" + "hash": "45e1c056d0868dd8072abd5cdab00c42e268149b730b7e4f2add5cbdbe7843c8" } diff --git a/.sqlx/query-c9808451bd3653635dfb455c6125a7d392cee6a69bd7cbf6479cc2cf5294319c.json b/.sqlx/query-51dffd9c1ae018de2ade5fde93ae829ad4a5825707e8b2f55b86801039cd784e.json similarity index 69% rename from .sqlx/query-c9808451bd3653635dfb455c6125a7d392cee6a69bd7cbf6479cc2cf5294319c.json rename to .sqlx/query-51dffd9c1ae018de2ade5fde93ae829ad4a5825707e8b2f55b86801039cd784e.json index 004cbcf8c..0dc3ac7ee 100644 --- a/.sqlx/query-c9808451bd3653635dfb455c6125a7d392cee6a69bd7cbf6479cc2cf5294319c.json +++ b/.sqlx/query-51dffd9c1ae018de2ade5fde93ae829ad4a5825707e8b2f55b86801039cd784e.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE \"vpn_client_session\" SET \"location_id\" = $2,\"user_id\" = $3,\"device_id\" = $4,\"created_at\" = $5,\"connected_at\" = $6,\"disconnected_at\" = $7,\"mfa_mode\" = $8,\"state\" = $9 WHERE id = $1", + "query": "UPDATE \"vpn_client_session\" SET \"location_id\" = $2,\"user_id\" = $3,\"device_id\" = $4,\"created_at\" = $5,\"connected_at\" = $6,\"disconnected_at\" = $7,\"mfa_method\" = $8,\"state\" = $9 WHERE id = $1", "describe": { "columns": [], "parameters": { @@ -14,12 +14,14 @@ "Timestamp", { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -40,5 +42,5 @@ }, "nullable": [] }, - "hash": "c9808451bd3653635dfb455c6125a7d392cee6a69bd7cbf6479cc2cf5294319c" + "hash": "51dffd9c1ae018de2ade5fde93ae829ad4a5825707e8b2f55b86801039cd784e" } diff --git a/.sqlx/query-dbb3290d7ec75771a626416e3cdb8efd92525a9773fe79c17f3a19081fd932f9.json b/.sqlx/query-73965013a68538139aadcf0c2346e315923c12c54b9446c7b18e4b4fb1791f42.json similarity index 77% rename from .sqlx/query-dbb3290d7ec75771a626416e3cdb8efd92525a9773fe79c17f3a19081fd932f9.json rename to .sqlx/query-73965013a68538139aadcf0c2346e315923c12c54b9446c7b18e4b4fb1791f42.json index f4d8d26d8..bd6db45ee 100644 --- a/.sqlx/query-dbb3290d7ec75771a626416e3cdb8efd92525a9773fe79c17f3a19081fd932f9.json +++ b/.sqlx/query-73965013a68538139aadcf0c2346e315923c12c54b9446c7b18e4b4fb1791f42.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"location_id\",\"user_id\",\"device_id\",\"created_at\",\"connected_at\",\"disconnected_at\",\"mfa_mode\" \"mfa_mode: _\",\"state\" \"state: _\" FROM \"vpn_client_session\" WHERE id = $1", + "query": "SELECT id, \"location_id\",\"user_id\",\"device_id\",\"created_at\",\"connected_at\",\"disconnected_at\",\"mfa_method\" \"mfa_method?: _\",\"state\" \"state: _\" FROM \"vpn_client_session\" WHERE id = $1", "describe": { "columns": [ { @@ -40,15 +40,17 @@ }, { "ordinal": 7, - "name": "mfa_mode: _", + "name": "mfa_method?: _", "type_info": { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -84,9 +86,9 @@ false, true, true, - false, + true, false ] }, - "hash": "dbb3290d7ec75771a626416e3cdb8efd92525a9773fe79c17f3a19081fd932f9" + "hash": "73965013a68538139aadcf0c2346e315923c12c54b9446c7b18e4b4fb1791f42" } diff --git a/.sqlx/query-126b613d8b07d65836a20429bef3b0917f7345c9d3a054b73e1176758a4ba1a9.json b/.sqlx/query-86b6f0f850b4b8436379a9da38c0280867b8f432fc4bd52ba7194a5895540b4e.json similarity index 77% rename from .sqlx/query-126b613d8b07d65836a20429bef3b0917f7345c9d3a054b73e1176758a4ba1a9.json rename to .sqlx/query-86b6f0f850b4b8436379a9da38c0280867b8f432fc4bd52ba7194a5895540b4e.json index 045d193d6..ac7a8d312 100644 --- a/.sqlx/query-126b613d8b07d65836a20429bef3b0917f7345c9d3a054b73e1176758a4ba1a9.json +++ b/.sqlx/query-86b6f0f850b4b8436379a9da38c0280867b8f432fc4bd52ba7194a5895540b4e.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "SELECT id, \"location_id\",\"user_id\",\"device_id\",\"created_at\",\"connected_at\",\"disconnected_at\",\"mfa_mode\" \"mfa_mode: _\",\"state\" \"state: _\" FROM \"vpn_client_session\"", + "query": "SELECT id, \"location_id\",\"user_id\",\"device_id\",\"created_at\",\"connected_at\",\"disconnected_at\",\"mfa_method\" \"mfa_method?: _\",\"state\" \"state: _\" FROM \"vpn_client_session\"", "describe": { "columns": [ { @@ -40,15 +40,17 @@ }, { "ordinal": 7, - "name": "mfa_mode: _", + "name": "mfa_method?: _", "type_info": { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -82,9 +84,9 @@ false, true, true, - false, + true, false ] }, - "hash": "126b613d8b07d65836a20429bef3b0917f7345c9d3a054b73e1176758a4ba1a9" + "hash": "86b6f0f850b4b8436379a9da38c0280867b8f432fc4bd52ba7194a5895540b4e" } diff --git a/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json b/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json index b843b3c06..a2e62691a 100644 --- a/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json +++ b/.sqlx/query-d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c.json @@ -45,12 +45,12 @@ }, { "ordinal": 8, - "name": "name", + "name": "version", "type_info": "Text" }, { "ordinal": 9, - "name": "version", + "name": "name", "type_info": "Text" } ], @@ -68,8 +68,8 @@ true, false, true, - false, - true + true, + false ] }, "hash": "d10c9a7b0b391aeb8b4869f6bddf997807b66e0b532da747b146513c34e15c5c" diff --git a/.sqlx/query-2ec4ae04a8cf90d7a062ce0b2c318bff70bed69bec930e7331c576073f612677.json b/.sqlx/query-dd0190b514cf61ba7c444d14f9c47bda3869b74fcb598ec6f2cf8352424f87d6.json similarity index 71% rename from .sqlx/query-2ec4ae04a8cf90d7a062ce0b2c318bff70bed69bec930e7331c576073f612677.json rename to .sqlx/query-dd0190b514cf61ba7c444d14f9c47bda3869b74fcb598ec6f2cf8352424f87d6.json index 47d085485..238f7541a 100644 --- a/.sqlx/query-2ec4ae04a8cf90d7a062ce0b2c318bff70bed69bec930e7331c576073f612677.json +++ b/.sqlx/query-dd0190b514cf61ba7c444d14f9c47bda3869b74fcb598ec6f2cf8352424f87d6.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO \"vpn_client_session\" (\"location_id\",\"user_id\",\"device_id\",\"created_at\",\"connected_at\",\"disconnected_at\",\"mfa_mode\",\"state\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING id", + "query": "INSERT INTO \"vpn_client_session\" (\"location_id\",\"user_id\",\"device_id\",\"created_at\",\"connected_at\",\"disconnected_at\",\"mfa_method\",\"state\") VALUES ($1,$2,$3,$4,$5,$6,$7,$8) RETURNING id", "describe": { "columns": [ { @@ -19,12 +19,14 @@ "Timestamp", { "Custom": { - "name": "location_mfa_mode", + "name": "vpn_client_mfa_method", "kind": { "Enum": [ - "disabled", - "internal", - "external" + "totp", + "email", + "oidc", + "biometric", + "mobileapprove" ] } } @@ -47,5 +49,5 @@ false ] }, - "hash": "2ec4ae04a8cf90d7a062ce0b2c318bff70bed69bec930e7331c576073f612677" + "hash": "dd0190b514cf61ba7c444d14f9c47bda3869b74fcb598ec6f2cf8352424f87d6" } diff --git a/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json b/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json index d3ffd878e..db1d8414a 100644 --- a/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json +++ b/.sqlx/query-e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a.json @@ -45,12 +45,12 @@ }, { "ordinal": 8, - "name": "name", + "name": "version", "type_info": "Text" }, { "ordinal": 9, - "name": "version", + "name": "name", "type_info": "Text" } ], @@ -68,8 +68,8 @@ true, false, true, - false, - true + true, + false ] }, "hash": "e9ca71b61f7a3736ca335d90aca36ab5a93dc8a00ad622267f13b3cd4cdb4a5a" From c194cceb0d1bd6cda70e64e0e454c7bf4a74c090 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 09:23:22 +0100 Subject: [PATCH 05/19] track session state --- .../src/db/models/vpn_client_session.rs | 2 +- .../src/session_state.rs | 27 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index b080f21a6..749965e52 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -7,7 +7,7 @@ use crate::db::{ models::{WireguardNetwork, vpn_session_stats::VpnSessionStats}, }; -#[derive(Debug, Default, Type)] +#[derive(Clone, Debug, Default, Type)] #[sqlx(type_name = "vpn_client_session_state", rename_all = "lowercase")] pub enum VpnClientSessionState { #[default] diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index a8b7e70cd..751236577 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -5,8 +5,10 @@ use defguard_common::{ db::{ Id, models::{ - Device, User, WireguardNetwork, vpn_client_session::VpnClientSession, + Device, User, WireguardNetwork, + vpn_client_session::{VpnClientSession, VpnClientSessionState}, vpn_session_stats::VpnSessionStats, + wireguard::LocationMfaMode, }, }, messages::peer_stats_update::PeerStatsUpdate, @@ -87,17 +89,11 @@ impl From> for LastStatsUpdate { /// State of a specific VPN client session pub(crate) struct SessionState { session_id: Id, + state: VpnClientSessionState, last_stats_update: LastGatewayUpdate, } impl SessionState { - fn new(session_id: Id) -> Self { - Self { - session_id, - last_stats_update: LastGatewayUpdate::new(), - } - } - fn try_get_last_stats_update(&self, gateway_id: Id) -> Option<&LastStatsUpdate> { self.last_stats_update.0.get(&gateway_id) } @@ -150,6 +146,7 @@ impl From<&VpnClientSession> for SessionState { fn from(value: &VpnClientSession) -> Self { Self { session_id: value.id, + state: value.state.clone(), last_stats_update: LastGatewayUpdate::new(), } } @@ -241,6 +238,9 @@ impl ActiveSessionsMap { /// Attempts to create a new VPN client session, add it to curent state and persists it in DB /// + /// This should only happen for non-MFA sessions since MFA sessions (with `new` state) should be created once the authorization is completed + /// in the proxy handler. + /// /// We assume that at this point it's been checked that a session for this client does not exist yet, /// but we do check if given peer can be considered active based on a given locations peer disconnect threshold. pub(crate) async fn try_add_new_session( @@ -257,6 +257,15 @@ impl ActiveSessionsMap { .await? .clone(); + // check location MFA mode since MFA sessions should be created elsewhere + // once MFA auth is successful + if location.location_mfa_mode != LocationMfaMode::Disabled { + warn!( + "Received peer stats update for MFA-enabled location {location}, but VPN session does not exist yet. Skipping creating a new session..." + ); + return Ok(None); + } + // check if a given peer is considered active and should be added to active sessions if Utc::now().naive_utc() - stats_update.latest_handshake > TimeDelta::seconds(location.peer_disconnect_threshold.into()) @@ -290,7 +299,7 @@ impl ActiveSessionsMap { .await?; // add to session map - let session_state = SessionState::new(session.id); + let session_state = SessionState::from(&session); let session_map = self.get_or_create_location_session_map(location_id); let maybe_existing_session = session_map.insert(device.id, session_state); From 3b317f7d192e85980781ca9879fb65203932e634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 09:35:25 +0100 Subject: [PATCH 06/19] mark MFA sessions as connected --- .../src/db/models/vpn_client_session.rs | 2 +- crates/defguard_session_manager/src/error.rs | 2 ++ .../src/session_state.rs | 17 +++++++++++++++++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/crates/defguard_common/src/db/models/vpn_client_session.rs b/crates/defguard_common/src/db/models/vpn_client_session.rs index 749965e52..9b5c98d2c 100644 --- a/crates/defguard_common/src/db/models/vpn_client_session.rs +++ b/crates/defguard_common/src/db/models/vpn_client_session.rs @@ -7,7 +7,7 @@ use crate::db::{ models::{WireguardNetwork, vpn_session_stats::VpnSessionStats}, }; -#[derive(Clone, Debug, Default, Type)] +#[derive(Clone, Debug, Default, PartialEq, Type)] #[sqlx(type_name = "vpn_client_session_state", rename_all = "lowercase")] pub enum VpnClientSessionState { #[default] diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index e7fa25196..97d5d0284 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -22,6 +22,8 @@ pub enum SessionManagerError { DeviceDoesNotExistError(Id), #[error("Location with ID {0} does not exist")] LocationDoesNotExistError(Id), + #[error("VPN client session with ID {0} does not exist")] + SessionDoesNotExistError(Id), #[error("Received out of order peer stats update")] PeerStatsUpdateOutOfOrderError, #[error("Failed to send session manager event: {0}")] diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 751236577..89ec6f751 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -104,6 +104,23 @@ impl SessionState { transaction: &mut PgConnection, peer_stats_update: PeerStatsUpdate, ) -> Result<(), SessionManagerError> { + // mark new MFA session as connected if necessary + if self.state == VpnClientSessionState::New { + // fetch DB session + let mut db_session = VpnClientSession::find_by_id(&mut *transaction, self.session_id) + .await? + .ok_or(SessionManagerError::SessionDoesNotExistError( + self.session_id, + ))?; + // update DB session + db_session.state = VpnClientSessionState::Connected; + db_session.connected_at = Some(peer_stats_update.latest_handshake); + db_session.save(&mut *transaction).await?; + + // update local session state + self.state = VpnClientSessionState::Connected; + } + // get previous stats for a given gateway if available and calculate transfer change let (upload_diff, download_diff) = match self.try_get_last_stats_update(peer_stats_update.gateway_id) { From 8765bac2040876f15f940b95bd5ffd7d64a2e882 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 09:40:04 +0100 Subject: [PATCH 07/19] clippy fix --- crates/defguard_proto/src/lib.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/crates/defguard_proto/src/lib.rs b/crates/defguard_proto/src/lib.rs index 471adb7f4..84e78505c 100644 --- a/crates/defguard_proto/src/lib.rs +++ b/crates/defguard_proto/src/lib.rs @@ -67,14 +67,14 @@ impl Serialize for MfaMethod { } } -impl Into for MfaMethod { - fn into(self) -> VpnClientMfaMethod { - match self { - Self::Totp => VpnClientMfaMethod::Totp, - Self::Email => VpnClientMfaMethod::Email, - Self::Oidc => VpnClientMfaMethod::Oidc, - Self::Biometric => VpnClientMfaMethod::Biometric, - Self::MobileApprove => VpnClientMfaMethod::MobileApprove, +impl From for VpnClientMfaMethod { + fn from(val: MfaMethod) -> Self { + match val { + MfaMethod::Totp => VpnClientMfaMethod::Totp, + MfaMethod::Email => VpnClientMfaMethod::Email, + MfaMethod::Oidc => VpnClientMfaMethod::Oidc, + MfaMethod::Biometric => VpnClientMfaMethod::Biometric, + MfaMethod::MobileApprove => VpnClientMfaMethod::MobileApprove, } } } From d3b25d8c49986d448b75b49300818c92075e3e08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 11:23:50 +0100 Subject: [PATCH 08/19] disconnect unused sessions only for MFA locations --- crates/defguard_session_manager/src/lib.rs | 31 +++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 31563e1f6..48d85ef36 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -188,21 +188,23 @@ impl SessionManager { // get all sessions which were created but have never connected // this is only relevant for MFA locations - let unused_sessions = - VpnClientSession::get_never_connected(&mut *transaction, &location).await?; + if location.mfa_enabled() { + let unused_sessions = + VpnClientSession::get_never_connected(&mut *transaction, &location).await?; - debug!( - "Found {} new VPN sessions which have not connected within required time in location {location}", - unused_sessions.len() - ); - - for session in unused_sessions { debug!( - "Disconnecting never connected session for user {}, device {} in location {location}", - session.user_id, session.device_id + "Found {} new VPN sessions which have not connected within required time in location {location}", + unused_sessions.len() ); - self.disconnect_session(&mut transaction, session, &location) - .await?; + + for session in unused_sessions { + debug!( + "Disconnecting never connected session for user {}, device {} in location {location}", + session.user_id, session.device_id + ); + self.disconnect_session(&mut transaction, session, &location) + .await?; + } } } @@ -239,6 +241,11 @@ impl SessionManager { session.device_id, ))?; + // remove peers from GW for MFA locations + if location.mfa_enabled() { + unimplemented!() + } + // emit event let context = SessionManagerEventContext { timestamp: disconnect_timestamp, From c1dd85c61527cb8d7a562a15323feaea3310a480 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 12:08:25 +0100 Subject: [PATCH 09/19] send session disconnect event to gateway handler --- Cargo.lock | 1 + crates/defguard/src/main.rs | 18 ++++++----- .../defguard_core/src/grpc/gateway/events.rs | 21 ++++++++---- crates/defguard_core/src/grpc/gateway/mod.rs | 7 ++++ crates/defguard_session_manager/Cargo.toml | 3 ++ crates/defguard_session_manager/src/error.rs | 10 +++++- crates/defguard_session_manager/src/lib.rs | 32 ++++++++++++++++--- 7 files changed, 73 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 335981a00..159ef921f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1379,6 +1379,7 @@ version = "0.0.0" dependencies = [ "chrono", "defguard_common", + "defguard_core", "sqlx", "thiserror 2.0.18", "tokio", diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 5041d7978..9236114b2 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -117,7 +117,8 @@ async fn main() -> Result<(), anyhow::Error> { // setup communication channels for services let (webhook_tx, webhook_rx) = unbounded_channel::(); - let (wireguard_tx, _wireguard_rx) = broadcast::channel::(256); + // RX is discarded here since it can be derived from TX later on + let (gateway_tx, _gateway_rx) = broadcast::channel::(256); let (mail_tx, mail_rx) = unbounded_channel::(); let (event_logger_tx, event_logger_rx) = unbounded_channel::(); let (peer_stats_tx, peer_stats_rx) = unbounded_channel::(); @@ -181,7 +182,7 @@ async fn main() -> Result<(), anyhow::Error> { } let (proxy_control_tx, proxy_control_rx) = channel::(100); - let proxy_tx = ProxyTxSet::new(wireguard_tx.clone(), mail_tx.clone(), bidi_event_tx.clone()); + let proxy_tx = ProxyTxSet::new(gateway_tx.clone(), mail_tx.clone(), bidi_event_tx.clone()); let proxy_manager = ProxyManager::new( pool.clone(), proxy_tx, @@ -194,7 +195,7 @@ async fn main() -> Result<(), anyhow::Error> { res = proxy_manager.run() => error!("ProxyManager returned early: {res:?}"), res = run_grpc_gateway_stream( pool.clone(), - wireguard_tx.clone(), + gateway_tx.clone(), mail_tx.clone(), peer_stats_tx, ) => error!("Gateway gRPC stream returned early: {res:?}"), @@ -209,7 +210,7 @@ async fn main() -> Result<(), anyhow::Error> { worker_state, webhook_tx, webhook_rx, - wireguard_tx.clone(), + gateway_tx.clone(), mail_tx.clone(), pool.clone(), failed_logins, @@ -220,7 +221,7 @@ async fn main() -> Result<(), anyhow::Error> { res = run_mail_handler(mail_rx) => error!("Mail handler returned early: {res:?}"), res = run_periodic_peer_disconnect( pool.clone(), - wireguard_tx.clone(), + gateway_tx.clone(), internal_event_tx.clone() ) => error!("Periodic peer disconnect task returned early: {res:?}"), res = run_periodic_stats_purge( @@ -231,7 +232,7 @@ async fn main() -> Result<(), anyhow::Error> { error!("Periodic stats purge task returned early: {res:?}"), res = run_periodic_license_check(&pool) => error!("Periodic license check task returned early: {res:?}"), - res = run_utility_thread(&pool, wireguard_tx.clone()) => + res = run_utility_thread(&pool, gateway_tx.clone()) => error!("Utility thread returned early: {res:?}"), res = run_event_router( RouterReceiverSet::new( @@ -241,7 +242,7 @@ async fn main() -> Result<(), anyhow::Error> { session_manager_event_rx ), event_logger_tx, - wireguard_tx, + gateway_tx.clone(), mail_tx, activity_log_stream_reload_notify.clone() ) => error!("Event router returned early: {res:?}"), @@ -255,7 +256,8 @@ async fn main() -> Result<(), anyhow::Error> { res = run_session_manager( pool.clone(), peer_stats_rx, - session_manager_event_tx + session_manager_event_tx, + gateway_tx ) => error!("VPN client session manager returned early: {res:?}"), } diff --git a/crates/defguard_core/src/grpc/gateway/events.rs b/crates/defguard_core/src/grpc/gateway/events.rs index 9f4513bb4..6ea648c11 100644 --- a/crates/defguard_core/src/grpc/gateway/events.rs +++ b/crates/defguard_core/src/grpc/gateway/events.rs @@ -1,17 +1,26 @@ use defguard_common::db::{ Id, - models::{WireguardNetwork, device::DeviceInfo}, + models::{Device, WireguardNetwork, device::DeviceInfo}, }; use defguard_proto::{enterprise::firewall::FirewallConfig, gateway::Peer}; +type LocationId = Id; + +// TODO: move this to common crate #[derive(Clone, Debug)] pub enum GatewayEvent { - NetworkCreated(Id, WireguardNetwork), - NetworkModified(Id, WireguardNetwork, Vec, Option), - NetworkDeleted(Id, String), + NetworkCreated(LocationId, WireguardNetwork), + NetworkModified( + LocationId, + WireguardNetwork, + Vec, + Option, + ), + NetworkDeleted(LocationId, String), DeviceCreated(DeviceInfo), DeviceModified(DeviceInfo), DeviceDeleted(DeviceInfo), - FirewallConfigChanged(Id, FirewallConfig), - FirewallDisabled(Id), + FirewallConfigChanged(LocationId, FirewallConfig), + FirewallDisabled(LocationId), + MfaSessionDisconnected(LocationId, Device), } diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 49acc4391..c63052cd6 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -475,6 +475,13 @@ impl GatewayUpdatesHandler { Ok(()) } } + GatewayEvent::MfaSessionDisconnected(location_id, device) => { + if location_id == self.network_id { + self.send_peer_delete(&device.wireguard_pubkey) + } else { + Ok(()) + } + } }; if result.is_err() { error!( diff --git a/crates/defguard_session_manager/Cargo.toml b/crates/defguard_session_manager/Cargo.toml index e72d7ce99..a28bc0ec1 100644 --- a/crates/defguard_session_manager/Cargo.toml +++ b/crates/defguard_session_manager/Cargo.toml @@ -9,6 +9,9 @@ rust-version.workspace = true [dependencies] defguard_common.workspace = true +# TODO: remove this dependency once gateway events are moved +defguard_core.workspace = true + chrono.workspace = true sqlx.workspace = true thiserror.workspace = true diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 97d5d0284..4b2150756 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -1,6 +1,7 @@ use defguard_common::db::Id; +use defguard_core::grpc::gateway::events::GatewayEvent; use thiserror::Error; -use tokio::sync::mpsc::error::SendError; +use tokio::sync::{broadcast::error::SendError as BroadcastSendError, mpsc::error::SendError}; use crate::events::SessionManagerEvent; @@ -28,6 +29,8 @@ pub enum SessionManagerError { PeerStatsUpdateOutOfOrderError, #[error("Failed to send session manager event: {0}")] SessionManagerEventError(Box>), + #[error("Failed to send gateway manager event: {0}")] + GatewayManagerEventError(Box>), } impl From> for SessionManagerError { @@ -35,3 +38,8 @@ impl From> for SessionManagerError { Self::SessionManagerEventError(Box::new(error)) } } +impl From> for SessionManagerError { + fn from(error: BroadcastSendError) -> Self { + Self::GatewayManagerEventError(Box::new(error)) + } +} diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 48d85ef36..84bbc8f84 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -8,9 +8,13 @@ use defguard_common::{ }, messages::peer_stats_update::PeerStatsUpdate, }; +use defguard_core::grpc::gateway::events::GatewayEvent; use sqlx::{PgConnection, PgPool}; use tokio::{ - sync::mpsc::{UnboundedReceiver, UnboundedSender}, + sync::{ + broadcast::Sender, + mpsc::{UnboundedReceiver, UnboundedSender}, + }, time::{Duration, interval}, }; use tracing::{debug, error, info, trace, warn}; @@ -32,12 +36,13 @@ pub async fn run_session_manager( pool: PgPool, mut peer_stats_rx: UnboundedReceiver, session_manager_event_tx: UnboundedSender, + gateway_tx: Sender, ) -> Result<(), SessionManagerError> { info!("Starting VPN client session manager service"); let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); // initialize session manager - let mut session_manager = SessionManager::new(pool, session_manager_event_tx); + let mut session_manager = SessionManager::new(pool, session_manager_event_tx, gateway_tx); loop { // receive next batch of peer stats messages @@ -69,13 +74,19 @@ pub async fn run_session_manager( struct SessionManager { pool: PgPool, session_manager_event_tx: UnboundedSender, + gateway_tx: Sender, } impl SessionManager { - fn new(pool: PgPool, session_manager_event_tx: UnboundedSender) -> Self { + fn new( + pool: PgPool, + session_manager_event_tx: UnboundedSender, + gateway_tx: Sender, + ) -> Self { Self { pool, session_manager_event_tx, + gateway_tx, } } @@ -243,7 +254,7 @@ impl SessionManager { // remove peers from GW for MFA locations if location.mfa_enabled() { - unimplemented!() + self.send_peer_disconnect_message(location, &device)?; } // emit event @@ -263,4 +274,17 @@ impl SessionManager { Ok(()) } + + fn send_peer_disconnect_message( + &self, + location: &WireguardNetwork, + device: &Device, + ) -> Result<(), SessionManagerError> { + debug!( + "Sending MFA session disconnect event for device {device} in location {location} to gateway manager" + ); + let event = GatewayEvent::MfaSessionDisconnected(location.id, device.clone()); + self.gateway_tx.send(event)?; + Ok(()) + } } From 4a2f1129732224b1f1a173de1b5622b0c28dc1f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 12:11:38 +0100 Subject: [PATCH 10/19] rename MFA event --- crates/defguard_core/src/events.rs | 2 +- crates/defguard_core/src/grpc/proxy/client_mfa.rs | 2 +- crates/defguard_event_router/src/handlers/bidi.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index 70bae976c..cda067989 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -374,7 +374,7 @@ pub type ClientMFAMethod = MfaMethod; #[derive(Debug)] pub enum DesktopClientMfaEvent { - Connected { + Success { device: Device, location: WireguardNetwork, method: ClientMFAMethod, diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 8cbef7104..9440e5026 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -704,7 +704,7 @@ impl ClientMfaServer { self.emit_event(BidiStreamEvent { context, event: BidiStreamEventType::DesktopClientMfa(Box::new( - DesktopClientMfaEvent::Connected { + DesktopClientMfaEvent::Success { location: location.clone(), device: device.clone(), method, diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index 0dac49b40..e7dcd41a0 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -45,7 +45,7 @@ impl EventRouter { ), }, BidiStreamEventType::DesktopClientMfa(event) => match *event { - DesktopClientMfaEvent::Connected { + DesktopClientMfaEvent::Success { location, device, method, From 492c6fdf6f7d4de69e4b509b956e086750bb9766 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 13:20:15 +0100 Subject: [PATCH 11/19] simplify MFA-related VPN client events --- crates/defguard/src/main.rs | 5 +-- .../src/db/models/activity_log/mod.rs | 3 +- crates/defguard_core/src/events.rs | 35 ------------------- .../src/wireguard_peer_disconnect.rs | 16 +-------- .../defguard_event_logger/src/description.rs | 9 ++--- crates/defguard_event_logger/src/lib.rs | 10 ++---- crates/defguard_event_logger/src/message.rs | 30 ++-------------- crates/defguard_event_router/src/events.rs | 3 +- .../src/handlers/bidi.rs | 4 +-- .../src/handlers/internal.rs | 27 -------------- .../defguard_event_router/src/handlers/mod.rs | 1 - crates/defguard_event_router/src/lib.rs | 10 +----- 12 files changed, 16 insertions(+), 137 deletions(-) delete mode 100644 crates/defguard_event_router/src/handlers/internal.rs diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 9236114b2..0a58ac04c 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -25,7 +25,7 @@ use defguard_core::{ license::{License, run_periodic_license_check, set_cached_license}, limits::update_counts, }, - events::{ApiEvent, BidiStreamEvent, InternalEvent}, + events::{ApiEvent, BidiStreamEvent}, grpc::{ WorkerState, gateway::{events::GatewayEvent, run_grpc_gateway_stream}, @@ -107,7 +107,6 @@ async fn main() -> Result<(), anyhow::Error> { // create event channels for services let (api_event_tx, api_event_rx) = unbounded_channel::(); let (bidi_event_tx, bidi_event_rx) = unbounded_channel::(); - let (internal_event_tx, internal_event_rx) = unbounded_channel::(); let (session_manager_event_tx, session_manager_event_rx) = unbounded_channel::(); @@ -222,7 +221,6 @@ async fn main() -> Result<(), anyhow::Error> { res = run_periodic_peer_disconnect( pool.clone(), gateway_tx.clone(), - internal_event_tx.clone() ) => error!("Periodic peer disconnect task returned early: {res:?}"), res = run_periodic_stats_purge( pool.clone(), @@ -238,7 +236,6 @@ async fn main() -> Result<(), anyhow::Error> { RouterReceiverSet::new( api_event_rx, bidi_event_rx, - internal_event_rx, session_manager_event_rx ), event_logger_tx, diff --git a/crates/defguard_core/src/db/models/activity_log/mod.rs b/crates/defguard_core/src/db/models/activity_log/mod.rs index cf9314cf4..f50e23f9e 100644 --- a/crates/defguard_core/src/db/models/activity_log/mod.rs +++ b/crates/defguard_core/src/db/models/activity_log/mod.rs @@ -75,8 +75,7 @@ pub enum EventType { // VPN client events VpnClientConnected, VpnClientDisconnected, - VpnClientConnectedMfa, - VpnClientDisconnectedMfa, + VpnClientMfaSuccess, VpnClientMfaFailed, // Enrollment events EnrollmentTokenAdded, diff --git a/crates/defguard_core/src/events.rs b/crates/defguard_core/src/events.rs index cda067989..eb9b4987c 100644 --- a/crates/defguard_core/src/events.rs +++ b/crates/defguard_core/src/events.rs @@ -386,38 +386,3 @@ pub enum DesktopClientMfaEvent { message: String, }, } - -/// Shared context for every internally-triggered event. -/// -/// Similarly to `ApiRequestContexts` at the moment it's mostly meant to populate the activity log. -#[derive(Debug)] -pub struct InternalEventContext { - pub timestamp: NaiveDateTime, - pub user_id: Id, - pub username: String, - pub ip: IpAddr, - pub device: Device, -} - -impl InternalEventContext { - #[must_use] - pub fn new(user_id: Id, username: String, ip: IpAddr, device: Device) -> Self { - let timestamp = Utc::now().naive_utc(); - Self { - timestamp, - user_id, - username, - ip, - device, - } - } -} - -/// Events emmited by background threads, not triggered directly by users -#[derive(Debug)] -pub enum InternalEvent { - DesktopClientMfaDisconnected { - context: InternalEventContext, - location: WireguardNetwork, - }, -} diff --git a/crates/defguard_core/src/wireguard_peer_disconnect.rs b/crates/defguard_core/src/wireguard_peer_disconnect.rs index 661eaed04..b211e07b7 100644 --- a/crates/defguard_core/src/wireguard_peer_disconnect.rs +++ b/crates/defguard_core/src/wireguard_peer_disconnect.rs @@ -29,10 +29,7 @@ use tokio::{ time::sleep, }; -use crate::{ - events::{InternalEvent, InternalEventContext}, - grpc::gateway::events::GatewayEvent, -}; +use crate::grpc::gateway::events::GatewayEvent; // How long to sleep between loop iterations const DISCONNECT_LOOP_SLEEP: Duration = Duration::from_secs(60); // 1 minute @@ -47,8 +44,6 @@ pub enum PeerDisconnectError { WireguardError(#[from] WireguardNetworkError), #[error("Failed to send gateway event: {0}")] GatewayEventError(#[from] broadcast::error::SendError), - #[error("Failed to send internal event: {0}")] - InternalEventError(#[from] mpsc::error::SendError), } #[derive(Debug)] @@ -86,7 +81,6 @@ impl From for Device { pub async fn run_periodic_peer_disconnect( pool: PgPool, wireguard_tx: Sender, - internal_event_tx: UnboundedSender, ) -> Result<(), PeerDisconnectError> { info!("Starting periodic disconnect of inactive devices in MFA-protected locations"); loop { @@ -174,14 +168,6 @@ pub async fn run_periodic_peer_disconnect( // endpoint is a `text` column in the db so we have to // handle potential parsing issues here .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); - let event = InternalEvent::DesktopClientMfaDisconnected { - context: InternalEventContext::new(user.id, user.username, ip, device), - location: location.clone(), - }; - internal_event_tx.send(event).map_err(|err| { - error!("Error sending internal event: {err}"); - PeerDisconnectError::InternalEventError(err) - })?; } else { error!( "Network config for device {device} in location {location} not found. Skipping device..." diff --git a/crates/defguard_event_logger/src/description.rs b/crates/defguard_event_logger/src/description.rs index 1ca637719..99618d0de 100644 --- a/crates/defguard_event_logger/src/description.rs +++ b/crates/defguard_event_logger/src/description.rs @@ -263,17 +263,14 @@ pub fn get_defguard_event_description(event: &DefguardEvent) -> Option { #[must_use] pub fn get_vpn_event_description(event: &VpnEvent) -> Option { match event { - VpnEvent::ConnectedToMfaLocation { + VpnEvent::ClientMfaSuccess { location, device, method, } => Some(format!( - "Device {device} connected to MFA location {location} using {method}" + "Device {device} completed MFA authorization for location {location} using {method}" )), - VpnEvent::DisconnectedFromMfaLocation { location, device } => Some(format!( - "Device {device} disconnected from MFA location {location}" - )), - VpnEvent::MfaFailed { + VpnEvent::ClientMfaFailed { location, device, method, diff --git a/crates/defguard_event_logger/src/lib.rs b/crates/defguard_event_logger/src/lib.rs index 29c612395..aa5a2b194 100644 --- a/crates/defguard_event_logger/src/lib.rs +++ b/crates/defguard_event_logger/src/lib.rs @@ -476,7 +476,7 @@ pub async fn run_event_logger( let description = get_vpn_event_description(&event); let (event_type, metadata) = match *event { - VpnEvent::MfaFailed { + VpnEvent::ClientMfaFailed { location, device, method, @@ -491,12 +491,12 @@ pub async fn run_event_logger( }) .ok(), ), - VpnEvent::ConnectedToMfaLocation { + VpnEvent::ClientMfaSuccess { location, device, method, } => ( - EventType::VpnClientConnectedMfa, + EventType::VpnClientMfaSuccess, serde_json::to_value(VpnClientMfaMetadata { location, device, @@ -504,10 +504,6 @@ pub async fn run_event_logger( }) .ok(), ), - VpnEvent::DisconnectedFromMfaLocation { location, device } => ( - EventType::VpnClientDisconnectedMfa, - serde_json::to_value(VpnClientMetadata { location, device }).ok(), - ), VpnEvent::ConnectedToLocation { location, device } => ( EventType::VpnClientConnected, serde_json::to_value(VpnClientMetadata { location, device }).ok(), diff --git a/crates/defguard_event_logger/src/message.rs b/crates/defguard_event_logger/src/message.rs index 3d736281e..80bf1216e 100644 --- a/crates/defguard_event_logger/src/message.rs +++ b/crates/defguard_event_logger/src/message.rs @@ -14,10 +14,7 @@ use defguard_core::{ activity_log_stream::ActivityLogStream, api_tokens::ApiToken, openid_provider::OpenIdProvider, snat::UserSnatBinding, }, - events::{ - ApiRequestContext, BidiRequestContext, ClientMFAMethod, GrpcRequestContext, - InternalEventContext, - }, + events::{ApiRequestContext, BidiRequestContext, ClientMFAMethod, GrpcRequestContext}, }; use defguard_session_manager::events::SessionManagerEventContext; @@ -86,23 +83,6 @@ impl EventContext { } } - #[must_use] - pub fn from_internal_context( - val: InternalEventContext, - location: Option>, - ) -> Self { - let location = location.map(|location| location.name); - - Self { - timestamp: val.timestamp, - user_id: val.user_id, - username: val.username, - location, - ip: val.ip, - device: format!("{} (ID {})", val.device.name, val.device.id), - } - } - #[must_use] pub fn from_session_manager_context(val: SessionManagerEventContext) -> Self { Self { @@ -347,16 +327,12 @@ pub enum ClientEvent { /// Represents activity log events related to VPN pub enum VpnEvent { - ConnectedToMfaLocation { + ClientMfaSuccess { location: WireguardNetwork, device: Device, method: ClientMFAMethod, }, - DisconnectedFromMfaLocation { - location: WireguardNetwork, - device: Device, - }, - MfaFailed { + ClientMfaFailed { location: WireguardNetwork, device: Device, method: ClientMFAMethod, diff --git a/crates/defguard_event_router/src/events.rs b/crates/defguard_event_router/src/events.rs index ccbf7c794..02cf4e22d 100644 --- a/crates/defguard_event_router/src/events.rs +++ b/crates/defguard_event_router/src/events.rs @@ -1,4 +1,4 @@ -use defguard_core::events::{ApiEvent, BidiStreamEvent, InternalEvent}; +use defguard_core::events::{ApiEvent, BidiStreamEvent}; use defguard_session_manager::events::SessionManagerEvent; /// Enum representing all possible events that can be generated in the system. @@ -9,6 +9,5 @@ use defguard_session_manager::events::SessionManagerEvent; pub enum Event { Api(ApiEvent), Bidi(BidiStreamEvent), - Internal(Box), SessionManager(Box), } diff --git a/crates/defguard_event_router/src/handlers/bidi.rs b/crates/defguard_event_router/src/handlers/bidi.rs index e7dcd41a0..669a0e051 100644 --- a/crates/defguard_event_router/src/handlers/bidi.rs +++ b/crates/defguard_event_router/src/handlers/bidi.rs @@ -50,7 +50,7 @@ impl EventRouter { device, method, } => ( - LoggerEvent::Vpn(Box::new(VpnEvent::ConnectedToMfaLocation { + LoggerEvent::Vpn(Box::new(VpnEvent::ClientMfaSuccess { location: location.clone(), device, method, @@ -63,7 +63,7 @@ impl EventRouter { method, message, } => ( - LoggerEvent::Vpn(Box::new(VpnEvent::MfaFailed { + LoggerEvent::Vpn(Box::new(VpnEvent::ClientMfaFailed { location: location.clone(), device, method, diff --git a/crates/defguard_event_router/src/handlers/internal.rs b/crates/defguard_event_router/src/handlers/internal.rs deleted file mode 100644 index 2e1e3781b..000000000 --- a/crates/defguard_event_router/src/handlers/internal.rs +++ /dev/null @@ -1,27 +0,0 @@ -use defguard_core::events::InternalEvent; -use defguard_event_logger::message::{EventContext, LoggerEvent, VpnEvent}; -use tracing::debug; - -use crate::{EventRouter, error::EventRouterError}; - -impl EventRouter { - pub(crate) fn handle_internal_event( - &self, - event: InternalEvent, - ) -> Result<(), EventRouterError> { - debug!("Processing internal event: {event:?}"); - - match event { - InternalEvent::DesktopClientMfaDisconnected { context, location } => { - let device = context.device.clone(); - self.log_event( - EventContext::from_internal_context(context, Some(location.clone())), - LoggerEvent::Vpn(Box::new(VpnEvent::DisconnectedFromMfaLocation { - device, - location, - })), - ) - } - } - } -} diff --git a/crates/defguard_event_router/src/handlers/mod.rs b/crates/defguard_event_router/src/handlers/mod.rs index 0c40a8387..ad5c568be 100644 --- a/crates/defguard_event_router/src/handlers/mod.rs +++ b/crates/defguard_event_router/src/handlers/mod.rs @@ -1,4 +1,3 @@ pub(crate) mod api; pub(crate) mod bidi; -pub(crate) mod internal; pub(crate) mod session_manager; diff --git a/crates/defguard_event_router/src/lib.rs b/crates/defguard_event_router/src/lib.rs index 132f56d75..ecac1bd13 100644 --- a/crates/defguard_event_router/src/lib.rs +++ b/crates/defguard_event_router/src/lib.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use defguard_core::{ - events::{ApiEvent, BidiStreamEvent, InternalEvent}, + events::{ApiEvent, BidiStreamEvent}, grpc::gateway::events::GatewayEvent, }; use defguard_event_logger::message::{EventContext, EventLoggerMessage, LoggerEvent}; @@ -42,7 +42,6 @@ mod handlers; pub struct RouterReceiverSet { api: UnboundedReceiver, bidi: UnboundedReceiver, - internal: UnboundedReceiver, session_manager: UnboundedReceiver, } @@ -51,13 +50,11 @@ impl RouterReceiverSet { pub fn new( api: UnboundedReceiver, bidi: UnboundedReceiver, - internal: UnboundedReceiver, session_manager: UnboundedReceiver, ) -> Self { Self { api, bidi, - internal, session_manager, } } @@ -120,10 +117,6 @@ impl EventRouter { error!("Bidi gRPC stream event channel closed"); return Err(EventRouterError::BidiEventChannelClosed); }, - event = self.receivers.internal.recv() => if let Some(internal_event) = event { Event::Internal(Box::new(internal_event)) } else { - error!("Internal event channel closed"); - return Err(EventRouterError::InternalEventChannelClosed); - }, event = self.receivers.session_manager.recv() => if let Some(session_manager_event) = event { Event::SessionManager(Box::new(session_manager_event)) } else { error!("Internal event channel closed"); return Err(EventRouterError::InternalEventChannelClosed); @@ -136,7 +129,6 @@ impl EventRouter { match event { Event::Api(api_event) => self.handle_api_event(api_event)?, Event::Bidi(bidi_event) => self.handle_bidi_event(bidi_event)?, - Event::Internal(internal_event) => self.handle_internal_event(*internal_event)?, Event::SessionManager(session_manager_event) => { self.handle_session_manager_event(*session_manager_event)? } From 244553f8f69bd2e22de34ee271facd4264079ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 13:22:59 +0100 Subject: [PATCH 12/19] update frontend messages --- web/messages/en/activity.json | 3 +-- web/src/shared/api/activity-log-types.ts | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/web/messages/en/activity.json b/web/messages/en/activity.json index e13c9eefb..01d698df6 100644 --- a/web/messages/en/activity.json +++ b/web/messages/en/activity.json @@ -29,8 +29,7 @@ "activity_event_activity_log_stream_removed": "Activity log stream removed", "activity_event_vpn_client_connected": "VPN client connected", "activity_event_vpn_client_disconnected": "VPN client disconnected", - "activity_event_vpn_client_connected_mfa": "VPN client connected with MFA", - "activity_event_vpn_client_disconnected_mfa": "VPN client disconnected with MFA", + "activity_event_vpn_client_mfa_success": "VPN client MFA success", "activity_event_vpn_client_mfa_failed": "VPN client MFA failed", "activity_event_enrollment_token_added": "Enrollment token added", "activity_event_enrollment_started": "Enrollment started", diff --git a/web/src/shared/api/activity-log-types.ts b/web/src/shared/api/activity-log-types.ts index 6f495af78..33d74c79b 100644 --- a/web/src/shared/api/activity-log-types.ts +++ b/web/src/shared/api/activity-log-types.ts @@ -46,8 +46,7 @@ export const ActivityLogEventType = { VpnClientConnected: 'vpn_client_connected', VpnClientDisconnected: 'vpn_client_disconnected', - VpnClientConnectedMfa: 'vpn_client_connected_mfa', - VpnClientDisconnectedMfa: 'vpn_client_disconnected_mfa', + VpnClientMfaSuccess: 'vpn_client_mfa_success', VpnClientMfaFailed: 'vpn_client_mfa_failed', EnrollmentTokenAdded: 'enrollment_token_added', From 8576f896bd9d2737d402f01f4a96016843dfaa89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 13:25:36 +0100 Subject: [PATCH 13/19] remove separate peer disconnect service --- crates/defguard/src/main.rs | 5 - crates/defguard_core/src/lib.rs | 1 - .../src/wireguard_peer_disconnect.rs | 187 ------------------ 3 files changed, 193 deletions(-) delete mode 100644 crates/defguard_core/src/wireguard_peer_disconnect.rs diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 0a58ac04c..7c9a0fc7c 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -34,7 +34,6 @@ use defguard_core::{ init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, version::IncompatibleComponents, - wireguard_peer_disconnect::run_periodic_peer_disconnect, wireguard_stats_purge::run_periodic_stats_purge, }; use defguard_event_logger::{message::EventLoggerMessage, run_event_logger}; @@ -218,10 +217,6 @@ async fn main() -> Result<(), anyhow::Error> { proxy_control_tx ) => error!("Web server returned early: {res:?}"), res = run_mail_handler(mail_rx) => error!("Mail handler returned early: {res:?}"), - res = run_periodic_peer_disconnect( - pool.clone(), - gateway_tx.clone(), - ) => error!("Periodic peer disconnect task returned early: {res:?}"), res = run_periodic_stats_purge( pool.clone(), config.stats_purge_frequency.into(), diff --git a/crates/defguard_core/src/lib.rs b/crates/defguard_core/src/lib.rs index 04611ca25..122ba2b60 100644 --- a/crates/defguard_core/src/lib.rs +++ b/crates/defguard_core/src/lib.rs @@ -177,7 +177,6 @@ pub mod user_management; pub mod utility_thread; pub mod version; pub mod wg_config; -pub mod wireguard_peer_disconnect; pub mod wireguard_stats_purge; #[macro_use] diff --git a/crates/defguard_core/src/wireguard_peer_disconnect.rs b/crates/defguard_core/src/wireguard_peer_disconnect.rs deleted file mode 100644 index b211e07b7..000000000 --- a/crates/defguard_core/src/wireguard_peer_disconnect.rs +++ /dev/null @@ -1,187 +0,0 @@ -//! This module implements a functionality of disconnecting inactive peers -//! in MFA-protected locations. -//! If a device does not disconnect explicitly and just becomes inactive -//! it should be removed from gateway configuration and marked as "not allowed", -//! which enforces an authentication requirement to connect again. - -use std::{ - net::{IpAddr, Ipv4Addr}, - str::FromStr, - time::Duration, -}; - -use chrono::NaiveDateTime; -use defguard_common::db::{ - Id, - models::{ - Device, DeviceNetworkInfo, DeviceType, ModelError, WireguardNetwork, WireguardNetworkError, - device::{DeviceInfo, WireguardNetworkDevice}, - wireguard::{LocationMfaMode, ServiceLocationMode}, - }, -}; -use sqlx::{Error as SqlxError, PgPool, query_as}; -use thiserror::Error; -use tokio::{ - sync::{ - broadcast::{self, Sender}, - mpsc::{self, UnboundedSender}, - }, - time::sleep, -}; - -use crate::grpc::gateway::events::GatewayEvent; - -// How long to sleep between loop iterations -const DISCONNECT_LOOP_SLEEP: Duration = Duration::from_secs(60); // 1 minute - -#[derive(Debug, Error)] -pub enum PeerDisconnectError { - #[error(transparent)] - DbError(#[from] SqlxError), - #[error(transparent)] - ModelError(#[from] ModelError), - #[error(transparent)] - WireguardError(#[from] WireguardNetworkError), - #[error("Failed to send gateway event: {0}")] - GatewayEventError(#[from] broadcast::error::SendError), -} - -#[derive(Debug)] -struct DeviceWithEndpoint { - pub id: Id, - pub name: String, - pub wireguard_pubkey: String, - pub user_id: Id, - pub created: NaiveDateTime, - pub device_type: DeviceType, - pub description: Option, - pub configured: bool, - pub endpoint: Option, -} - -impl From for Device { - fn from(device: DeviceWithEndpoint) -> Self { - Self { - id: device.id, - name: device.name, - wireguard_pubkey: device.wireguard_pubkey, - user_id: device.user_id, - created: device.created, - device_type: device.device_type, - description: device.description, - configured: device.configured, - } - } -} - -/// Run periodic disconnect task -/// -/// Run with a specified frequency and disconnect all inactive peers in MFA-protected locations. -#[instrument(skip_all)] -pub async fn run_periodic_peer_disconnect( - pool: PgPool, - wireguard_tx: Sender, -) -> Result<(), PeerDisconnectError> { - info!("Starting periodic disconnect of inactive devices in MFA-protected locations"); - loop { - debug!("Starting periodic inactive device disconnect"); - - // get all MFA-protected locations - let locations = query_as!( - WireguardNetwork::, - "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, \ - allowed_ips, connected_at, keepalive_interval, peer_disconnect_threshold, acl_enabled, \ - acl_default_allow, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", \ - service_location_mode \"service_location_mode: ServiceLocationMode\" \ - FROM wireguard_network WHERE location_mfa_mode != 'disabled'::location_mfa_mode", - ) - .fetch_all(&pool) - .await?; - - // loop over all locations - for location in locations { - debug!("Fetching inactive devices for location {location}"); - let devices = query_as!( - DeviceWithEndpoint, - "WITH stats AS ( \ - SELECT DISTINCT ON (device_id) device_id, endpoint, latest_handshake \ - FROM wireguard_peer_stats \ - WHERE network = $1 \ - ORDER BY device_id, collected_at DESC \ - ) \ - SELECT d.id, d.name, d.wireguard_pubkey, d.user_id, d.created, d.description, - d.device_type \"device_type: DeviceType\", configured, stats.endpoint \ - FROM device d \ - JOIN wireguard_network_device wnd ON wnd.device_id = d.id \ - LEFT JOIN stats on d.id = stats.device_id \ - WHERE wnd.wireguard_network_id = $1 AND wnd.is_authorized = true \ - AND d.configured = true \ - AND (NOW() - wnd.authorized_at) > $2 * interval '1 second' \ - AND (NOW() - stats.latest_handshake) > $2 * interval '1 second'", - location.id, - f64::from(location.peer_disconnect_threshold) - ) - .fetch_all(&pool) - .await?; - - for device_with_endpoint in devices { - debug!("Processing inactive device {device_with_endpoint:?}"); - let endpoint = device_with_endpoint.endpoint.clone(); - let device: Device = device_with_endpoint.into(); - - // start transaction - let mut transaction = pool.begin().await?; - - // get network config for device - if let Some(mut device_network_config) = - WireguardNetworkDevice::find(&mut *transaction, device.id, location.id).await? - { - info!( - "Marking device {device} as not authorized to connect to location {location}" - ); - // change `is_authorized` value for device - device_network_config.is_authorized = false; - // clear `preshared_key` value - device_network_config.preshared_key = None; - device_network_config.update(&mut *transaction).await?; - - debug!("Sending `peer_delete` message to gateway"); - let device_info = DeviceInfo { - device: device.clone(), - network_info: vec![DeviceNetworkInfo { - network_id: location.id, - device_wireguard_ips: device_network_config.wireguard_ips, - preshared_key: device_network_config.preshared_key, - is_authorized: device_network_config.is_authorized, - }], - }; - let event = GatewayEvent::DeviceDeleted(device_info); - wireguard_tx.send(event).map_err(|err| { - error!("Error sending WireGuard event: {err}"); - PeerDisconnectError::GatewayEventError(err) - })?; - let user = device.get_owner(&mut *transaction).await?; - let ip = endpoint - .as_ref() - .and_then(|endpoint| endpoint.split_once(':')) - .and_then(|(ip, _)| IpAddr::from_str(ip).ok()) - // endpoint is a `text` column in the db so we have to - // handle potential parsing issues here - .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); - } else { - error!( - "Network config for device {device} in location {location} not found. Skipping device..." - ); - continue; - } - - // commit transaction - transaction.commit().await?; - } - } - - // wait till next iteration - debug!("Sleeping until next iteration"); - sleep(DISCONNECT_LOOP_SLEEP).await; - } -} From e9af7b06f4384278588e4d4f7dbc2b0a22f6dd3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 13:31:32 +0100 Subject: [PATCH 14/19] don't store legacy stats --- .../defguard_core/src/grpc/gateway/handler.rs | 55 +------------------ 1 file changed, 3 insertions(+), 52 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 4aa617099..d24c0d061 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -3,23 +3,17 @@ use std::{ sync::atomic::{AtomicU64, Ordering}, }; -use chrono::{DateTime, Utc}; use defguard_certs::der_to_pem; use defguard_common::{ VERSION, db::{ - Id, NoId, - models::{ - Device, Settings, WireguardNetwork, gateway::Gateway, - wireguard_peer_stats::WireguardPeerStats, - }, + Id, + models::{Device, Settings, WireguardNetwork, gateway::Gateway}, }, messages::peer_stats_update::PeerStatsUpdate, }; use defguard_mail::Mail; -use defguard_proto::gateway::{ - CoreResponse, PeerStats, core_request, core_response, gateway_client, -}; +use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; use defguard_version::client::ClientVersionInterceptor; use reqwest::Url; use semver::Version; @@ -60,26 +54,6 @@ impl Scheme { } } -fn peer_stats_from_proto(stats: PeerStats, network_id: Id, device_id: Id) -> WireguardPeerStats { - let endpoint = match stats.endpoint { - endpoint if endpoint.is_empty() => None, - _ => Some(stats.endpoint), - }; - WireguardPeerStats { - id: NoId, - network: network_id, - endpoint, - device_id, - collected_at: Utc::now().naive_utc(), - upload: stats.upload as i64, - download: stats.download as i64, - latest_handshake: DateTime::from_timestamp(stats.latest_handshake as i64, 0) - .unwrap_or_default() - .naive_utc(), - allowed_ips: Some(stats.allowed_ips), - } -} - /// One instance per connected Gateway. pub(crate) struct GatewayHandler { // Gateway server endpoint URL. @@ -382,7 +356,6 @@ impl GatewayHandler { let public_key = peer_stats.public_key.clone(); // Fetch device from database. - // TODO: fetch only when device has changed and use client state // otherwise let Ok(Some(device)) = self.fetch_device_from_db(&public_key).await else { @@ -396,14 +369,6 @@ impl GatewayHandler { // copy device ID for easier reference later let device_id = device.id; - // Convert stats to database storage format. - // FIXME: remove once legacy table is removed - let stats = peer_stats_from_proto( - peer_stats.clone(), - self.gateway.network_id, - device_id, - ); - // convert stats to DB storage format match try_protos_into_stats_message( peer_stats.clone(), @@ -425,20 +390,6 @@ impl GatewayHandler { }; } }; - - // Save stats to database. - // FIXME: remove once legacy table is removed - let stats = match stats.save(&self.pool).await { - Ok(stats) => stats, - Err(err) => { - error!( - "Saving WireGuard peer stats to database failed: {err}" - ); - continue; - } - }; - info!("Saved WireGuard peer stats to database."); - debug!("WireGuard peer stats: {stats:?}"); } None => (), } From 846fa293712eee8a9f0bc126eb5c4d147cf967f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Thu, 29 Jan 2026 15:40:58 +0100 Subject: [PATCH 15/19] avoid fetching device from DB multiple times --- .../src/messages/peer_stats_update.rs | 6 ++-- .../defguard_core/src/grpc/gateway/handler.rs | 28 +----------------- crates/defguard_core/src/grpc/gateway/mod.rs | 3 +- crates/defguard_session_manager/src/error.rs | 2 ++ crates/defguard_session_manager/src/lib.rs | 17 ++++++++--- .../src/session_state.rs | 29 +++++++++++++------ 6 files changed, 40 insertions(+), 45 deletions(-) diff --git a/crates/defguard_common/src/messages/peer_stats_update.rs b/crates/defguard_common/src/messages/peer_stats_update.rs index 9db662bd3..b9f0ac2df 100644 --- a/crates/defguard_common/src/messages/peer_stats_update.rs +++ b/crates/defguard_common/src/messages/peer_stats_update.rs @@ -10,7 +10,7 @@ use crate::db::Id; pub struct PeerStatsUpdate { pub location_id: Id, pub gateway_id: Id, - pub device_id: Id, + pub device_pubkey: String, pub collected_at: NaiveDateTime, pub endpoint: SocketAddr, // bytes sent to peer @@ -24,7 +24,7 @@ impl PeerStatsUpdate { pub fn new( location_id: Id, gateway_id: Id, - device_id: Id, + device_pubkey: String, endpoint: SocketAddr, upload: u64, download: u64, @@ -34,7 +34,7 @@ impl PeerStatsUpdate { Self { location_id, gateway_id, - device_id, + device_pubkey, collected_at, endpoint, upload, diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index d24c0d061..a744518fa 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -8,7 +8,7 @@ use defguard_common::{ VERSION, db::{ Id, - models::{Device, Settings, WireguardNetwork, gateway::Gateway}, + models::{Settings, WireguardNetwork, gateway::Gateway}, }, messages::peer_stats_update::PeerStatsUpdate, }; @@ -239,15 +239,6 @@ impl GatewayHandler { } } - /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors - async fn fetch_device_from_db( - &self, - public_key: &str, - ) -> Result>, GatewayError> { - let device = Device::find_by_pubkey(&self.pool, public_key).await?; - Ok(device) - } - /// Connect to Gateway and handle its messages through gRPC. pub(crate) async fn handle_connection(&mut self) -> Result<(), GatewayError> { let endpoint = self.endpoint(Scheme::Https)?; @@ -353,28 +344,11 @@ impl GatewayHandler { continue; } - let public_key = peer_stats.public_key.clone(); - - // Fetch device from database. - // otherwise - let Ok(Some(device)) = self.fetch_device_from_db(&public_key).await - else { - warn!( - "Received stats update for a device which does not \ - exist: {public_key}, skipping." - ); - continue; - }; - - // copy device ID for easier reference later - let device_id = device.id; - // convert stats to DB storage format match try_protos_into_stats_message( peer_stats.clone(), self.gateway.network_id, self.gateway.id, - device_id, ) { None => { warn!( diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index c63052cd6..85c614265 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -68,7 +68,6 @@ fn try_protos_into_stats_message( proto_stats: PeerStats, location_id: Id, gateway_id: Id, - device_id: Id, ) -> Option { // try to parse endpoint let endpoint = proto_stats.endpoint.parse().ok()?; @@ -80,7 +79,7 @@ fn try_protos_into_stats_message( Some(PeerStatsUpdate::new( location_id, gateway_id, - device_id, + proto_stats.public_key, endpoint, proto_stats.upload, proto_stats.download, diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 4b2150756..8e4c6bca7 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -21,6 +21,8 @@ pub enum SessionManagerError { UserDoesNotExistError(Id), #[error("Device with ID {0} does not exist")] DeviceDoesNotExistError(Id), + #[error("Device with pubkey {0} does not exist")] + DevicePubkeyDoesNotExistError(String), #[error("Location with ID {0} does not exist")] LocationDoesNotExistError(Id), #[error("VPN client session with ID {0} does not exist")] diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 84bbc8f84..ed3079270 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -135,17 +135,26 @@ impl SessionManager { // check if a session exists already for a given peer // and attempt to add one if necessary let maybe_session = match active_sessions - .try_get_peer_session(transaction, message.location_id, message.device_id) + .try_get_peer_session( + transaction, + message.location_id, + message.device_pubkey.clone(), + ) .await? { Some(session) => Some(session), None => { debug!( - "No active session found for device {} in location {}. Creating a new session", - message.device_id, message.location_id + "No active session found for device with pubkey {} in location {}. Creating a new session", + message.device_pubkey, message.location_id ); active_sessions - .try_add_new_session(transaction, &message, &self.session_manager_event_tx) + .try_add_new_session( + transaction, + &message, + &message.device_pubkey, + &self.session_manager_event_tx, + ) .await? } }; diff --git a/crates/defguard_session_manager/src/session_state.rs b/crates/defguard_session_manager/src/session_state.rs index 89ec6f751..1b87acfb7 100644 --- a/crates/defguard_session_manager/src/session_state.rs +++ b/crates/defguard_session_manager/src/session_state.rs @@ -194,7 +194,7 @@ pub(crate) struct ActiveSessionsMap { sessions: HashMap, locations: HashMap>, users: HashMap>, - devices: HashMap>, + devices: HashMap>, } impl ActiveSessionsMap { @@ -216,8 +216,11 @@ impl ActiveSessionsMap { &mut self, transaction: &mut PgConnection, location_id: Id, - device_id: Id, + device_pubkey: String, ) -> Result, SessionManagerError> { + // translate pubkey into device ID + let device_id = self.get_device(&mut *transaction, device_pubkey).await?.id; + // try to get session from current map let session_map = self.get_or_create_location_session_map(location_id); if session_map.0.contains_key(&device_id) { @@ -264,6 +267,7 @@ impl ActiveSessionsMap { &mut self, transaction: &mut PgConnection, stats_update: &PeerStatsUpdate, + device_pubkey: &str, event_tx: &UnboundedSender, ) -> Result, SessionManagerError> { // fetch location @@ -295,8 +299,11 @@ impl ActiveSessionsMap { // fetch other related objects from DB // clone them because we'll need those for event context - let device_id = stats_update.device_id; - let device = self.get_device(&mut *transaction, device_id).await?.clone(); + let device = self + .get_device(&mut *transaction, device_pubkey.into()) + .await? + .clone(); + let device_id = device.id; let user = self .get_user(&mut *transaction, device.user_id) .await? @@ -382,16 +389,20 @@ impl ActiveSessionsMap { async fn get_device<'e, E: sqlx::PgExecutor<'e>>( &mut self, executor: E, - device_id: Id, + device_pubkey: String, ) -> Result<&Device, SessionManagerError> { // first try to find device in object cache - let device_entry = match self.devices.entry(device_id) { + let device_entry = match self.devices.entry(device_pubkey.clone()) { Entry::Occupied(occupied_entry) => occupied_entry, Entry::Vacant(vacant_entry) => { - debug!("Device {device_id} not found in object cache. Trying to fetch from DB."); - let device = Device::find_by_id(executor, device_id) + debug!( + "Device {device_pubkey} not found in object cache. Trying to fetch from DB." + ); + let device = Device::find_by_pubkey(executor, &device_pubkey) .await? - .ok_or(SessionManagerError::DeviceDoesNotExistError(device_id))?; + .ok_or(SessionManagerError::DevicePubkeyDoesNotExistError( + device_pubkey, + ))?; // update object cache vacant_entry.insert_entry(device) } From 8bb115bce14e9e2d2fea62aecd57c70437189de4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 30 Jan 2026 07:17:36 +0100 Subject: [PATCH 16/19] update query data --- ...8bfbed999267ca3ced60f27f1e3c08b2e9574.json | 144 ------------------ ...d2ae962ab8f9b3e33fafa9906392824411190.json | 81 ---------- 2 files changed, 225 deletions(-) delete mode 100644 .sqlx/query-0e2e17c49ac83e9b8b70c8bcadc8bfbed999267ca3ced60f27f1e3c08b2e9574.json delete mode 100644 .sqlx/query-8e93732789b7a3da041e0382d91d2ae962ab8f9b3e33fafa9906392824411190.json diff --git a/.sqlx/query-0e2e17c49ac83e9b8b70c8bcadc8bfbed999267ca3ced60f27f1e3c08b2e9574.json b/.sqlx/query-0e2e17c49ac83e9b8b70c8bcadc8bfbed999267ca3ced60f27f1e3c08b2e9574.json deleted file mode 100644 index 3ee27a31e..000000000 --- a/.sqlx/query-0e2e17c49ac83e9b8b70c8bcadc8bfbed999267ca3ced60f27f1e3c08b2e9574.json +++ /dev/null @@ -1,144 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT id, name, address, port, pubkey, prvkey, endpoint, dns, mtu, fwmark, allowed_ips, connected_at, keepalive_interval, peer_disconnect_threshold, acl_enabled, acl_default_allow, location_mfa_mode \"location_mfa_mode: LocationMfaMode\", service_location_mode \"service_location_mode: ServiceLocationMode\" FROM wireguard_network WHERE location_mfa_mode != 'disabled'::location_mfa_mode", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "name", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "address", - "type_info": "InetArray" - }, - { - "ordinal": 3, - "name": "port", - "type_info": "Int4" - }, - { - "ordinal": 4, - "name": "pubkey", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "prvkey", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "endpoint", - "type_info": "Text" - }, - { - "ordinal": 7, - "name": "dns", - "type_info": "Text" - }, - { - "ordinal": 8, - "name": "mtu", - "type_info": "Int4" - }, - { - "ordinal": 9, - "name": "fwmark", - "type_info": "Int8" - }, - { - "ordinal": 10, - "name": "allowed_ips", - "type_info": "InetArray" - }, - { - "ordinal": 11, - "name": "connected_at", - "type_info": "Timestamp" - }, - { - "ordinal": 12, - "name": "keepalive_interval", - "type_info": "Int4" - }, - { - "ordinal": 13, - "name": "peer_disconnect_threshold", - "type_info": "Int4" - }, - { - "ordinal": 14, - "name": "acl_enabled", - "type_info": "Bool" - }, - { - "ordinal": 15, - "name": "acl_default_allow", - "type_info": "Bool" - }, - { - "ordinal": 16, - "name": "location_mfa_mode: LocationMfaMode", - "type_info": { - "Custom": { - "name": "location_mfa_mode", - "kind": { - "Enum": [ - "disabled", - "internal", - "external" - ] - } - } - } - }, - { - "ordinal": 17, - "name": "service_location_mode: ServiceLocationMode", - "type_info": { - "Custom": { - "name": "service_location_mode", - "kind": { - "Enum": [ - "disabled", - "prelogon", - "alwayson" - ] - } - } - } - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false, - false, - false, - false, - false, - true, - false, - false, - false, - true, - false, - false, - false, - false, - false, - false - ] - }, - "hash": "0e2e17c49ac83e9b8b70c8bcadc8bfbed999267ca3ced60f27f1e3c08b2e9574" -} diff --git a/.sqlx/query-8e93732789b7a3da041e0382d91d2ae962ab8f9b3e33fafa9906392824411190.json b/.sqlx/query-8e93732789b7a3da041e0382d91d2ae962ab8f9b3e33fafa9906392824411190.json deleted file mode 100644 index 59d3acf18..000000000 --- a/.sqlx/query-8e93732789b7a3da041e0382d91d2ae962ab8f9b3e33fafa9906392824411190.json +++ /dev/null @@ -1,81 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "WITH stats AS ( SELECT DISTINCT ON (device_id) device_id, endpoint, latest_handshake FROM wireguard_peer_stats WHERE network = $1 ORDER BY device_id, collected_at DESC ) SELECT d.id, d.name, d.wireguard_pubkey, d.user_id, d.created, d.description,\n d.device_type \"device_type: DeviceType\", configured, stats.endpoint FROM device d JOIN wireguard_network_device wnd ON wnd.device_id = d.id LEFT JOIN stats on d.id = stats.device_id WHERE wnd.wireguard_network_id = $1 AND wnd.is_authorized = true AND d.configured = true AND (NOW() - wnd.authorized_at) > $2 * interval '1 second' AND (NOW() - stats.latest_handshake) > $2 * interval '1 second'", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "name", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "wireguard_pubkey", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "user_id", - "type_info": "Int8" - }, - { - "ordinal": 4, - "name": "created", - "type_info": "Timestamp" - }, - { - "ordinal": 5, - "name": "description", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "device_type: DeviceType", - "type_info": { - "Custom": { - "name": "device_type", - "kind": { - "Enum": [ - "user", - "network" - ] - } - } - } - }, - { - "ordinal": 7, - "name": "configured", - "type_info": "Bool" - }, - { - "ordinal": 8, - "name": "endpoint", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Int8", - "Float8" - ] - }, - "nullable": [ - false, - false, - false, - false, - false, - true, - false, - false, - true - ] - }, - "hash": "8e93732789b7a3da041e0382d91d2ae962ab8f9b3e33fafa9906392824411190" -} From af9e118d35f8536161d0a084991b00d62f4431db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 30 Jan 2026 07:26:20 +0100 Subject: [PATCH 17/19] remove duplicate allowed peers method --- crates/defguard_core/src/grpc/gateway/handler.rs | 5 +++-- .../defguard_core/src/location_management/allowed_peers.rs | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index a744518fa..b8aeb883b 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -32,9 +32,10 @@ use crate::{ enterprise::firewall::try_get_location_firewall_config, grpc::{ TEN_SECS, - gateway::{GatewayError, events::GatewayEvent, get_peers, try_protos_into_stats_message}, + gateway::{GatewayError, events::GatewayEvent, try_protos_into_stats_message}, }, handlers::mail::send_gateway_disconnected_email, + location_management::allowed_peers::get_location_allowed_peers, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -166,7 +167,7 @@ impl GatewayHandler { ); } - let peers = get_peers(&network, &self.pool).await?; + let peers = get_location_allowed_peers(&network, &self.pool).await?; let maybe_firewall_config = try_get_location_firewall_config(&network, &mut conn).await?; let payload = Some(core_response::Payload::Config(super::gen_config( diff --git a/crates/defguard_core/src/location_management/allowed_peers.rs b/crates/defguard_core/src/location_management/allowed_peers.rs index 94195212c..8017ca1c8 100644 --- a/crates/defguard_core/src/location_management/allowed_peers.rs +++ b/crates/defguard_core/src/location_management/allowed_peers.rs @@ -17,7 +17,7 @@ pub async fn get_location_allowed_peers<'e, E>( where E: PgExecutor<'e>, { - debug!("Fetching all peers for network {}", location.id); + debug!("Fetching all allowed peers for location {}", location.id); if should_prevent_service_location_usage(location) { warn!( From 16ddec7101c52d7a9de282bf4c5784b77d444972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 30 Jan 2026 09:52:25 +0100 Subject: [PATCH 18/19] add dedicated gateway event for successful client MFA --- .../defguard_core/src/grpc/gateway/events.rs | 6 ++- crates/defguard_core/src/grpc/gateway/mod.rs | 39 +++++++++++++++++++ .../src/grpc/proxy/client_mfa.rs | 33 ++++++---------- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/events.rs b/crates/defguard_core/src/grpc/gateway/events.rs index 6ea648c11..68596f21b 100644 --- a/crates/defguard_core/src/grpc/gateway/events.rs +++ b/crates/defguard_core/src/grpc/gateway/events.rs @@ -1,6 +1,9 @@ use defguard_common::db::{ Id, - models::{Device, WireguardNetwork, device::DeviceInfo}, + models::{ + Device, WireguardNetwork, + device::{DeviceInfo, WireguardNetworkDevice}, + }, }; use defguard_proto::{enterprise::firewall::FirewallConfig, gateway::Peer}; @@ -22,5 +25,6 @@ pub enum GatewayEvent { DeviceDeleted(DeviceInfo), FirewallConfigChanged(LocationId, FirewallConfig), FirewallDisabled(LocationId), + MfaSessionAuthorized(LocationId, Device, WireguardNetworkDevice), MfaSessionDisconnected(LocationId, Device), } diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 85c614265..32f1ec7b5 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -387,6 +387,8 @@ impl GatewayUpdatesHandler { .find(|info| info.network_id == self.network_id) { Some(network_info) => { + // FIXME: this shouldn't happen, since when the device is created + // it's impossible for MFA authorization to already be completed if self.network.mfa_enabled() && !network_info.is_authorized { debug!( "Created WireGuard device {} is not authorized to connect to \ @@ -481,6 +483,43 @@ impl GatewayUpdatesHandler { Ok(()) } } + GatewayEvent::MfaSessionAuthorized(location_id, device, network_device) => { + if location_id == self.network_id { + // validate that network info is for the correct location + if network_device.wireguard_network_id != location_id { + error!( + "Received MFA authorization success event for location {location_id} with invalid device config: {network_device:?}" + ); + continue; + } + + // FIXME: at this point the device authorization should already have been verified + if self.network.mfa_enabled() && !network_device.is_authorized { + debug!( + "Created WireGuard device {} is not authorized to connect to \ + MFA enabled location {}", + device.name, self.network.name + ); + continue; + } + + self.send_peer_update( + Peer { + pubkey: device.wireguard_pubkey, + allowed_ips: network_device + .wireguard_ips + .iter() + .map(IpAddr::to_string) + .collect(), + preshared_key: network_device.preshared_key.clone(), + keepalive_interval: Some(self.network.keepalive_interval as u32), + }, + 0, + ) + } else { + Ok(()) + } + } }; if result.is_err() { error!( diff --git a/crates/defguard_core/src/grpc/proxy/client_mfa.rs b/crates/defguard_core/src/grpc/proxy/client_mfa.rs index 9440e5026..0ba29f102 100644 --- a/crates/defguard_core/src/grpc/proxy/client_mfa.rs +++ b/crates/defguard_core/src/grpc/proxy/client_mfa.rs @@ -10,9 +10,8 @@ use defguard_common::{ db::{ Id, models::{ - BiometricAuth, BiometricChallenge, Device, DeviceNetworkInfo, User, WireguardNetwork, - device::{DeviceInfo, WireguardNetworkDevice}, - vpn_client_session::VpnClientSession, + BiometricAuth, BiometricChallenge, Device, User, WireguardNetwork, + device::WireguardNetworkDevice, vpn_client_session::VpnClientSession, wireguard::LocationMfaMode, }, }, @@ -680,16 +679,7 @@ impl ClientMfaServer { // send gateway event debug!("Sending `peer_create` message to gateway"); - let device_info = DeviceInfo { - device: device.clone(), - network_info: vec![DeviceNetworkInfo { - network_id: location.id, - device_wireguard_ips: network_device.wireguard_ips, - preshared_key: network_device.preshared_key.clone(), - is_authorized: network_device.is_authorized, - }], - }; - let event = GatewayEvent::DeviceCreated(device_info); + let event = GatewayEvent::MfaSessionAuthorized(location.id, device.clone(), network_device); self.wireguard_tx.send(event).map_err(|err| { error!("Error sending WireGuard event: {err}"); Status::internal("unexpected error") @@ -729,7 +719,7 @@ impl ClientMfaServer { debug!("Created new VPN client session: {vpn_client_session:?}"); let response = ClientMfaFinishResponse { - preshared_key: key.public, + preshared_key: key.public.clone(), token: match method { MfaMethod::MobileApprove => Some(request.token.clone()), _ => None, @@ -749,14 +739,13 @@ impl ClientMfaServer { })?; // If there is a desktop client websocket waiting for the preshared key, send it. - if let (Some(tx), Some(ref preshared_key)) = ( - self.remote_mfa_responses - .write() - .expect("Failed to write-lock ClientMfaServer::remote_mfa_responses") - .remove(&request.token), - network_device.preshared_key, - ) { - let _ = tx.send(preshared_key.clone()); + if let Some(tx) = self + .remote_mfa_responses + .write() + .expect("Failed to write-lock ClientMfaServer::remote_mfa_responses") + .remove(&request.token) + { + let _ = tx.send(key.public.clone()); } Ok(response) From b3eeb0840f2a1d329d0e58d8c0a0f142e2a94a68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Fri, 30 Jan 2026 22:47:13 +0100 Subject: [PATCH 19/19] clear preshared key when disconnecting MFA device --- crates/defguard_session_manager/src/lib.rs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 2c818850f..451e90c0e 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -4,7 +4,10 @@ use chrono::Utc; use defguard_common::{ db::{ Id, - models::{Device, User, WireguardNetwork, vpn_client_session::VpnClientSession}, + models::{ + Device, User, WireguardNetwork, device::WireguardNetworkDevice, + vpn_client_session::VpnClientSession, + }, }, messages::peer_stats_update::PeerStatsUpdate, }; @@ -263,6 +266,15 @@ impl SessionManager { // remove peers from GW for MFA locations if location.mfa_enabled() { + // FIXME: remove one MFA-related data is no longer stored here + // update device network config + if let Some(mut device_network_info) = + WireguardNetworkDevice::find(&mut *transaction, device.id, location.id).await? + { + device_network_info.is_authorized = false; + device_network_info.preshared_key = None; + device_network_info.update(&mut *transaction).await?; + }; self.send_peer_disconnect_message(location, &device)?; }