diff --git a/proto b/proto index 41343581..0b982922 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 4134358160e4f819515c9a6e5c014434cfb46d74 +Subproject commit 0b982922c4dab3304a8cb01aed1d8cee806600b7 diff --git a/src/enterprise/handlers/desktop_client_mfa.rs b/src/enterprise/handlers/desktop_client_mfa.rs index fc756143..60b5a674 100644 --- a/src/enterprise/handlers/desktop_client_mfa.rs +++ b/src/enterprise/handlers/desktop_client_mfa.rs @@ -90,7 +90,7 @@ pub(super) async fn mfa_auth_callback( device_info, )?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::Empty(()) = payload { info!("MFA authentication callback completed successfully"); diff --git a/src/enterprise/handlers/openid_login.rs b/src/enterprise/handlers/openid_login.rs index 49ad8385..848c2818 100644 --- a/src/enterprise/handlers/openid_login.rs +++ b/src/enterprise/handlers/openid_login.rs @@ -76,7 +76,7 @@ async fn auth_info( let rx = state .grpc_server .send(core_request::Payload::AuthInfo(request), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::AuthInfo(response) = payload { debug!("Received auth info response"); @@ -164,7 +164,7 @@ async fn auth_callback( let rx = state .grpc_server .send(core_request::Payload::AuthCallback(request), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::AuthCallback(AuthCallbackResponse { url, token }) = payload { debug!("Received auth callback response {url:?} {token:?}"); diff --git a/src/grpc.rs b/src/grpc.rs index 775f477c..38ac682c 100644 --- a/src/grpc.rs +++ b/src/grpc.rs @@ -40,8 +40,8 @@ pub struct Configuration { pub(crate) struct ProxyServer { current_id: Arc, - clients: Arc>, - results: Arc>>>, + clients: Arc>, + results: Arc>>>, pub(crate) connected: Arc, pub(crate) core_version: Arc>>, config: Arc>>, @@ -55,8 +55,8 @@ impl ProxyServer { Self { cookie_key, current_id: Arc::new(AtomicU64::new(1)), - clients: Arc::new(Mutex::new(HashMap::new())), - results: Arc::new(Mutex::new(HashMap::new())), + clients: Arc::new(RwLock::new(HashMap::new())), + results: Arc::new(RwLock::new(HashMap::new())), connected: Arc::new(AtomicBool::new(false)), core_version: Arc::new(Mutex::new(None)), config: Arc::new(Mutex::new(None)), @@ -126,7 +126,7 @@ impl ProxyServer { ) -> Result, ApiError> { if let Some(client_tx) = self .clients - .lock() + .read() .expect("Failed to acquire lock on clients hashmap when sending message to core") .values() .next() @@ -143,7 +143,7 @@ impl ProxyServer { } let (tx, rx) = oneshot::channel(); self.results - .lock() + .write() .expect("Failed to acquire lock on results hashmap when sending CoreRequest") .insert(id, tx); self.connected.store(true, Ordering::Relaxed); @@ -214,7 +214,7 @@ impl proxy_server::Proxy for ProxyServer { info!("Defguard Core gRPC client connected from: {address}"); let (tx, rx) = mpsc::unbounded_channel(); self.clients - .lock() + .write() .expect( "Failed to acquire lock on clients hashmap when registering new core connection", ) @@ -241,7 +241,7 @@ impl proxy_server::Proxy for ProxyServer { *cookie_key.write().unwrap() = Some(key); }, _ => { - let maybe_rx = results.lock().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id); + let maybe_rx = results.write().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id); if let Some(rx) = maybe_rx { if let Err(err) = rx.send(payload) { error!("Failed to send message to rx {:?}", err.type_id()); @@ -265,7 +265,7 @@ impl proxy_server::Proxy for ProxyServer { } info!("Defguard core client disconnected: {address}"); connected.store(false, Ordering::Relaxed); - clients.lock().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address); + clients.write().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address); } .instrument(tracing::Span::current()), ); diff --git a/src/handlers/desktop_client_mfa.rs b/src/handlers/desktop_client_mfa.rs index 55a62b68..9f278ff8 100644 --- a/src/handlers/desktop_client_mfa.rs +++ b/src/handlers/desktop_client_mfa.rs @@ -1,4 +1,4 @@ -use std::collections::hash_map::Entry; +use std::time::Duration; use axum::{ extract::{ @@ -12,18 +12,23 @@ use axum::{ use futures_util::{sink::SinkExt, stream::StreamExt}; use serde::Deserialize; use serde_json::json; -use tokio::{sync::oneshot, task::JoinSet}; +use tokio::task::JoinSet; use crate::{ error::ApiError, handlers::get_core_response, http::AppState, proto::{ - core_request, core_response, ClientMfaFinishRequest, ClientMfaFinishResponse, + core_request, + core_response::{self, Payload}, + AwaitRemoteMfaFinishRequest, ClientMfaFinishRequest, ClientMfaFinishResponse, ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo, }, }; +// How much time the user has to approve remote MFA with mobile device +const REMOTE_AUTH_TIMEOUT: Duration = Duration::from_secs(60); + pub(crate) fn router() -> Router { Router::new() .route("/start", post(start_client_mfa)) @@ -53,66 +58,74 @@ async fn await_remote_auth( token: token.clone(), }, ), - device_info, + device_info.clone(), )?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, Some(REMOTE_AUTH_TIMEOUT)).await?; if let core_response::Payload::ClientMfaTokenValidation(response) = payload { if !response.token_valid { return Err(ApiError::Unauthorized(String::new())); } - // check if its already in the map - let contains_key = { - let sessions = state.remote_mfa_sessions.lock().await; - sessions.contains_key(&token) - }; - if contains_key { - return Err(ApiError::Unauthorized(String::new())); - } - Ok(ws.on_upgrade(move |socket| handle_remote_auth_socket(socket, state.clone(), token))) + + Ok(ws.on_upgrade(move |socket| { + handle_remote_auth_socket(socket, state.clone(), token, device_info) + })) } else { Err(ApiError::InvalidResponseType) } } /// Handle axum web socket upgrade for `await_remote_auth`. -async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: String) { - let (tx, rx) = oneshot::channel(); - - { - let mut sessions = state.remote_mfa_sessions.lock().await; - match sessions.entry(token.clone()) { - Entry::Occupied(_) => { - return; - } - Entry::Vacant(v) => { - v.insert(tx); - } - } - } - +async fn handle_remote_auth_socket( + socket: WebSocket, + state: AppState, + token: String, + device_info: DeviceInfo, +) { let (mut ws_tx, mut ws_rx) = socket.split(); let mut set = JoinSet::new(); + let request = AwaitRemoteMfaFinishRequest { token }; + let rx = match state.grpc_server.send( + core_request::Payload::AwaitRemoteMfaFinish(request), + device_info, + ) { + Ok(rx) => rx, + Err(err) => { + error!("Failed to send ClientRemoteMfaFinishRequest: {err:?}"); + return; + } + }; + + // Response to ClientRemoteMfaFinishRequest comes once the user concludes MFA with mobile device. + // This task then sends the preshared key to the WebSocket where desktop client awaits for it. set.spawn(async move { - if let Ok(msg) = rx.await { - let payload = json!({ - "type": "mfa_success", - "preshared_key": &msg, - }); - if let Ok(serialized) = serde_json::to_string(&payload) { - let message = Message::Text(serialized.into()); - if ws_tx.send(message).await.is_err() { - error!("Failed to send preshared key via ws"); + match rx.await { + Ok(Payload::AwaitRemoteMfaFinish(response)) => { + let ws_response = json!({ + "type": "mfa_success", + "preshared_key": &response.preshared_key, + }); + if let Ok(serialized) = serde_json::to_string(&ws_response) { + let message = Message::Text(serialized.into()); + if let Err(err) = ws_tx.send(message).await { + error!("Failed to send preshared key via ws: {err:?}"); + } } - } else { - error!("Failed to serialize remote mfa ws client response message"); } - } else { - error!("Failed to receive preshared key from receiver"); - } + Ok(_) => { + error!("Received wrong response type, expected ClientRemoteMfaFinish"); + } + Err(err) => { + error!("Failed to receive preshared key from receiver: {err:?}"); + } + }; + + // Close the websocket once we're done. let _ = ws_tx.close().await; }); + // Another task to monitor the websocket connection in case desktop client disconnects + // or the connection errors-out. set.spawn(async move { while let Some(msg_result) = ws_rx.next().await { match msg_result { @@ -129,10 +142,9 @@ async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: St } }); + // Wait for whichever task finishes first and kill the other one. let _ = set.join_next().await; set.shutdown().await; - // This will remove token, if it's still there. - state.remote_mfa_sessions.lock().await.remove(&token); } #[instrument(level = "debug", skip(state, req))] @@ -146,7 +158,7 @@ async fn start_client_mfa( core_request::Payload::ClientMfaStart(req.clone()), device_info, )?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::ClientMfaStart(response) = payload { info!("Started desktop client authorization {req:?}"); @@ -167,7 +179,7 @@ async fn finish_client_mfa( let rx = state .grpc_server .send(core_request::Payload::ClientMfaFinish(req), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::ClientMfaFinish(response) = payload { Ok(Json(response)) } else { @@ -186,32 +198,10 @@ async fn finish_remote_mfa( let rx = state .grpc_server .send(core_request::Payload::ClientMfaFinish(req), device_info)?; - let payload = get_core_response(rx).await?; - if let core_response::Payload::ClientMfaFinish(response) = payload { - // Check if this needs to be forwarded. - if let Some(token) = response.token { - let sender_option = { - let mut sessions = state.remote_mfa_sessions.lock().await; - sessions.remove(&token) - }; - if let Some(sender) = sender_option { - let _ = sender.send(response.preshared_key); - } - // If desktop stopped listening for the result, there will be no place to send the - // result. - else { - error!("Remote MFA approve finished but session was not found."); - return Err(ApiError::Unexpected(String::new())); - } - - info!("Finished desktop client authorization via mobile device"); - Ok(Json(json!({}))) - } else { - error!("Remote MFA Unexpected core response, token was not returned"); - Err(ApiError::Unexpected(String::new())) - } + if let core_response::Payload::ClientMfaFinish(_response) = get_core_response(rx, None).await? { + Ok(Json(json!({}))) } else { - error!("Received invalid gRPC response type"); + error!("Received invalid gRPC response type, expected ClientMfaFinish"); Err(ApiError::InvalidResponseType) } } diff --git a/src/handlers/enrollment.rs b/src/handlers/enrollment.rs index b8c0e73c..4979f354 100644 --- a/src/handlers/enrollment.rs +++ b/src/handlers/enrollment.rs @@ -45,7 +45,7 @@ async fn start_enrollment_process( let rx = state .grpc_server .send(core_request::Payload::EnrollmentStart(req), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; debug!("Receving payload from the core service. Try to set private cookie for starting enrollment process."); if let core_response::Payload::EnrollmentStart(response) = payload { info!( @@ -83,7 +83,7 @@ async fn activate_user( let rx = state .grpc_server .send(core_request::Payload::ActivateUser(req), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; debug!("Receiving payload from the core service. Trying to remove private cookie..."); if let core_response::Payload::Empty(()) = payload { info!("Activated user - phone number {phone:?}"); @@ -116,7 +116,7 @@ async fn create_device( let rx = state .grpc_server .send(core_request::Payload::NewDevice(req), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::DeviceConfig(response) = payload { info!("Added new device {name} {pubkey}"); Ok(Json(response)) @@ -144,7 +144,7 @@ async fn get_network_info( let rx = state .grpc_server .send(core_request::Payload::ExistingDevice(req), device_info)?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::DeviceConfig(response) = payload { info!("Got network info for device {pubkey}"); Ok(Json(response)) diff --git a/src/handlers/mobile_client.rs b/src/handlers/mobile_client.rs index b39c2e98..c825d897 100644 --- a/src/handlers/mobile_client.rs +++ b/src/handlers/mobile_client.rs @@ -53,7 +53,7 @@ pub(crate) async fn register_mobile_auth( core_request::Payload::RegisterMobileAuth(send_data), device_info, )?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; if let core_response::Payload::Empty(()) = payload { info!("Registered mobile device for auth"); Ok(()) diff --git a/src/handlers/mod.rs b/src/handlers/mod.rs index 629cbc2d..a1557cb9 100644 --- a/src/handlers/mod.rs +++ b/src/handlers/mod.rs @@ -3,7 +3,7 @@ use std::time::Duration; use axum::{extract::FromRequestParts, http::request::Parts}; use axum_client_ip::{InsecureClientIp, LeftmostXForwardedFor}; use axum_extra::{headers::UserAgent, TypedHeader}; -use tokio::{sync::oneshot::Receiver, time::timeout}; +use tokio::{sync::oneshot::Receiver, time}; use tonic::Code; use super::proto::DeviceInfo; @@ -69,9 +69,12 @@ where /// Helper which awaits core response /// /// Waits for core response with a given timeout and returns the response payload. -pub(crate) async fn get_core_response(rx: Receiver) -> Result { +pub(crate) async fn get_core_response( + rx: Receiver, + timeout: Option, +) -> Result { debug!("Fetching core response."); - if let Ok(core_response) = timeout(CORE_RESPONSE_TIMEOUT, rx).await { + if let Ok(core_response) = time::timeout(timeout.unwrap_or(CORE_RESPONSE_TIMEOUT), rx).await { debug!("Got gRPC response from Defguard Core"); if let Ok(Payload::CoreError(core_error)) = core_response { if core_error.status_code == Code::FailedPrecondition as i32 @@ -92,7 +95,10 @@ pub(crate) async fn get_core_response(rx: Receiver) -> Result Ok(Json(response)), _ => Err(ApiError::InvalidResponseType), @@ -90,7 +90,7 @@ async fn register_code_mfa_finish( }), device_info, )?; - let payload = get_core_response(rx).await?; + let payload = get_core_response(rx, None).await?; match payload { core_response::Payload::CodeMfaSetupFinishResponse(response) => Ok(Json(response)), _ => Err(ApiError::InvalidResponseType), diff --git a/src/http.rs b/src/http.rs index 423ffe74..73ecd61d 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,5 +1,4 @@ use std::{ - collections::HashMap, net::{IpAddr, Ipv4Addr, SocketAddr}, path::Path, sync::{atomic::Ordering, Arc, RwLock}, @@ -20,7 +19,7 @@ use axum_extra::extract::cookie::Key; use clap::crate_version; use defguard_version::{server::DefguardVersionLayer, Version}; use serde::Serialize; -use tokio::{net::TcpListener, sync::oneshot, task::JoinSet}; +use tokio::{net::TcpListener, task::JoinSet}; use tower_governor::{ governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, GovernorLayer, }; @@ -52,8 +51,6 @@ pub const GRPC_KEY_NAME: &str = "proxy_grpc_key.pem"; #[derive(Clone)] pub(crate) struct AppState { pub(crate) grpc_server: ProxyServer, - pub(crate) remote_mfa_sessions: - Arc>>>, cookie_key: Arc>>, url: Url, } @@ -273,7 +270,6 @@ pub async fn run_server(env_config: EnvConfig, config: Configuration) -> anyhow: let shared_state = AppState { cookie_key, grpc_server, - remote_mfa_sessions: Arc::new(tokio::sync::Mutex::new(HashMap::new())), url: env_config.url.clone(), };