Skip to content
Merged
2 changes: 1 addition & 1 deletion proto
Submodule proto updated 1 files
+10 −0 core/proxy.proto
2 changes: 1 addition & 1 deletion src/enterprise/handlers/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub(super) async fn mfa_auth_callback(
device_info,
)?;

let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;

if let core_response::Payload::Empty(()) = payload {
info!("MFA authentication callback completed successfully");
Expand Down
4 changes: 2 additions & 2 deletions src/enterprise/handlers/openid_login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async fn auth_info(
let rx = state
.grpc_server
.send(core_request::Payload::AuthInfo(request), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
if let core_response::Payload::AuthInfo(response) = payload {
debug!("Received auth info response");

Expand Down Expand Up @@ -164,7 +164,7 @@ async fn auth_callback(
let rx = state
.grpc_server
.send(core_request::Payload::AuthCallback(request), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;

if let core_response::Payload::AuthCallback(AuthCallbackResponse { url, token }) = payload {
debug!("Received auth callback response {url:?} {token:?}");
Expand Down
18 changes: 9 additions & 9 deletions src/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ pub struct Configuration {

pub(crate) struct ProxyServer {
current_id: Arc<AtomicU64>,
clients: Arc<Mutex<ClientMap>>,
results: Arc<Mutex<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
clients: Arc<RwLock<ClientMap>>,
results: Arc<RwLock<HashMap<u64, oneshot::Sender<core_response::Payload>>>>,
pub(crate) connected: Arc<AtomicBool>,
pub(crate) core_version: Arc<Mutex<Option<Version>>>,
config: Arc<Mutex<Option<Configuration>>>,
Expand All @@ -55,8 +55,8 @@ impl ProxyServer {
Self {
cookie_key,
current_id: Arc::new(AtomicU64::new(1)),
clients: Arc::new(Mutex::new(HashMap::new())),
results: Arc::new(Mutex::new(HashMap::new())),
clients: Arc::new(RwLock::new(HashMap::new())),
results: Arc::new(RwLock::new(HashMap::new())),
connected: Arc::new(AtomicBool::new(false)),
core_version: Arc::new(Mutex::new(None)),
config: Arc::new(Mutex::new(None)),
Expand Down Expand Up @@ -126,7 +126,7 @@ impl ProxyServer {
) -> Result<oneshot::Receiver<core_response::Payload>, ApiError> {
if let Some(client_tx) = self
.clients
.lock()
.read()
.expect("Failed to acquire lock on clients hashmap when sending message to core")
.values()
.next()
Expand All @@ -143,7 +143,7 @@ impl ProxyServer {
}
let (tx, rx) = oneshot::channel();
self.results
.lock()
.write()
.expect("Failed to acquire lock on results hashmap when sending CoreRequest")
.insert(id, tx);
self.connected.store(true, Ordering::Relaxed);
Expand Down Expand Up @@ -214,7 +214,7 @@ impl proxy_server::Proxy for ProxyServer {
info!("Defguard Core gRPC client connected from: {address}");
let (tx, rx) = mpsc::unbounded_channel();
self.clients
.lock()
.write()
.expect(
"Failed to acquire lock on clients hashmap when registering new core connection",
)
Expand All @@ -241,7 +241,7 @@ impl proxy_server::Proxy for ProxyServer {
*cookie_key.write().unwrap() = Some(key);
},
_ => {
let maybe_rx = results.lock().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
let maybe_rx = results.write().expect("Failed to acquire lock on results hashmap when processing response").remove(&response.id);
if let Some(rx) = maybe_rx {
if let Err(err) = rx.send(payload) {
error!("Failed to send message to rx {:?}", err.type_id());
Expand All @@ -265,7 +265,7 @@ impl proxy_server::Proxy for ProxyServer {
}
info!("Defguard core client disconnected: {address}");
connected.store(false, Ordering::Relaxed);
clients.lock().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address);
clients.write().expect("Failed to acquire lock on clients hashmap when removing disconnected client").remove(&address);
}
.instrument(tracing::Span::current()),
);
Expand Down
134 changes: 62 additions & 72 deletions src/handlers/desktop_client_mfa.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::hash_map::Entry;
use std::time::Duration;

use axum::{
extract::{
Expand All @@ -12,18 +12,23 @@ use axum::{
use futures_util::{sink::SinkExt, stream::StreamExt};
use serde::Deserialize;
use serde_json::json;
use tokio::{sync::oneshot, task::JoinSet};
use tokio::task::JoinSet;

use crate::{
error::ApiError,
handlers::get_core_response,
http::AppState,
proto::{
core_request, core_response, ClientMfaFinishRequest, ClientMfaFinishResponse,
core_request,
core_response::{self, Payload},
AwaitRemoteMfaFinishRequest, ClientMfaFinishRequest, ClientMfaFinishResponse,
ClientMfaStartRequest, ClientMfaStartResponse, DeviceInfo,
},
};

// How much time the user has to approve remote MFA with mobile device
const REMOTE_AUTH_TIMEOUT: Duration = Duration::from_secs(60);

pub(crate) fn router() -> Router<AppState> {
Router::new()
.route("/start", post(start_client_mfa))
Expand Down Expand Up @@ -53,66 +58,74 @@ async fn await_remote_auth(
token: token.clone(),
},
),
device_info,
device_info.clone(),
)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, Some(REMOTE_AUTH_TIMEOUT)).await?;
if let core_response::Payload::ClientMfaTokenValidation(response) = payload {
if !response.token_valid {
return Err(ApiError::Unauthorized(String::new()));
}
// check if its already in the map
let contains_key = {
let sessions = state.remote_mfa_sessions.lock().await;
sessions.contains_key(&token)
};
if contains_key {
return Err(ApiError::Unauthorized(String::new()));
}
Ok(ws.on_upgrade(move |socket| handle_remote_auth_socket(socket, state.clone(), token)))

Ok(ws.on_upgrade(move |socket| {
handle_remote_auth_socket(socket, state.clone(), token, device_info)
}))
} else {
Err(ApiError::InvalidResponseType)
}
}

/// Handle axum web socket upgrade for `await_remote_auth`.
async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: String) {
let (tx, rx) = oneshot::channel();

{
let mut sessions = state.remote_mfa_sessions.lock().await;
match sessions.entry(token.clone()) {
Entry::Occupied(_) => {
return;
}
Entry::Vacant(v) => {
v.insert(tx);
}
}
}

async fn handle_remote_auth_socket(
socket: WebSocket,
state: AppState,
token: String,
device_info: DeviceInfo,
) {
let (mut ws_tx, mut ws_rx) = socket.split();
let mut set = JoinSet::new();

let request = AwaitRemoteMfaFinishRequest { token };
let rx = match state.grpc_server.send(
core_request::Payload::AwaitRemoteMfaFinish(request),
device_info,
) {
Ok(rx) => rx,
Err(err) => {
error!("Failed to send ClientRemoteMfaFinishRequest: {err:?}");
return;
}
};

// Response to ClientRemoteMfaFinishRequest comes once the user concludes MFA with mobile device.
// This task then sends the preshared key to the WebSocket where desktop client awaits for it.
set.spawn(async move {
if let Ok(msg) = rx.await {
let payload = json!({
"type": "mfa_success",
"preshared_key": &msg,
});
if let Ok(serialized) = serde_json::to_string(&payload) {
let message = Message::Text(serialized.into());
if ws_tx.send(message).await.is_err() {
error!("Failed to send preshared key via ws");
match rx.await {
Ok(Payload::AwaitRemoteMfaFinish(response)) => {
let ws_response = json!({
"type": "mfa_success",
"preshared_key": &response.preshared_key,
});
if let Ok(serialized) = serde_json::to_string(&ws_response) {
let message = Message::Text(serialized.into());
if let Err(err) = ws_tx.send(message).await {
error!("Failed to send preshared key via ws: {err:?}");
}
}
} else {
error!("Failed to serialize remote mfa ws client response message");
}
} else {
error!("Failed to receive preshared key from receiver");
}
Ok(_) => {
error!("Received wrong response type, expected ClientRemoteMfaFinish");
}
Err(err) => {
error!("Failed to receive preshared key from receiver: {err:?}");
}
};

// Close the websocket once we're done.
let _ = ws_tx.close().await;
});

// Another task to monitor the websocket connection in case desktop client disconnects
// or the connection errors-out.
set.spawn(async move {
while let Some(msg_result) = ws_rx.next().await {
match msg_result {
Expand All @@ -129,10 +142,9 @@ async fn handle_remote_auth_socket(socket: WebSocket, state: AppState, token: St
}
});

// Wait for whichever task finishes first and kill the other one.
let _ = set.join_next().await;
set.shutdown().await;
// This will remove token, if it's still there.
state.remote_mfa_sessions.lock().await.remove(&token);
}

#[instrument(level = "debug", skip(state, req))]
Expand All @@ -146,7 +158,7 @@ async fn start_client_mfa(
core_request::Payload::ClientMfaStart(req.clone()),
device_info,
)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;

if let core_response::Payload::ClientMfaStart(response) = payload {
info!("Started desktop client authorization {req:?}");
Expand All @@ -167,7 +179,7 @@ async fn finish_client_mfa(
let rx = state
.grpc_server
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
if let core_response::Payload::ClientMfaFinish(response) = payload {
Ok(Json(response))
} else {
Expand All @@ -186,32 +198,10 @@ async fn finish_remote_mfa(
let rx = state
.grpc_server
.send(core_request::Payload::ClientMfaFinish(req), device_info)?;
let payload = get_core_response(rx).await?;
if let core_response::Payload::ClientMfaFinish(response) = payload {
// Check if this needs to be forwarded.
if let Some(token) = response.token {
let sender_option = {
let mut sessions = state.remote_mfa_sessions.lock().await;
sessions.remove(&token)
};
if let Some(sender) = sender_option {
let _ = sender.send(response.preshared_key);
}
// If desktop stopped listening for the result, there will be no place to send the
// result.
else {
error!("Remote MFA approve finished but session was not found.");
return Err(ApiError::Unexpected(String::new()));
}

info!("Finished desktop client authorization via mobile device");
Ok(Json(json!({})))
} else {
error!("Remote MFA Unexpected core response, token was not returned");
Err(ApiError::Unexpected(String::new()))
}
if let core_response::Payload::ClientMfaFinish(_response) = get_core_response(rx, None).await? {
Ok(Json(json!({})))
} else {
error!("Received invalid gRPC response type");
error!("Received invalid gRPC response type, expected ClientMfaFinish");
Err(ApiError::InvalidResponseType)
}
}
8 changes: 4 additions & 4 deletions src/handlers/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async fn start_enrollment_process(
let rx = state
.grpc_server
.send(core_request::Payload::EnrollmentStart(req), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
debug!("Receving payload from the core service. Try to set private cookie for starting enrollment process.");
if let core_response::Payload::EnrollmentStart(response) = payload {
info!(
Expand Down Expand Up @@ -83,7 +83,7 @@ async fn activate_user(
let rx = state
.grpc_server
.send(core_request::Payload::ActivateUser(req), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
debug!("Receiving payload from the core service. Trying to remove private cookie...");
if let core_response::Payload::Empty(()) = payload {
info!("Activated user - phone number {phone:?}");
Expand Down Expand Up @@ -116,7 +116,7 @@ async fn create_device(
let rx = state
.grpc_server
.send(core_request::Payload::NewDevice(req), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
if let core_response::Payload::DeviceConfig(response) = payload {
info!("Added new device {name} {pubkey}");
Ok(Json(response))
Expand Down Expand Up @@ -144,7 +144,7 @@ async fn get_network_info(
let rx = state
.grpc_server
.send(core_request::Payload::ExistingDevice(req), device_info)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
if let core_response::Payload::DeviceConfig(response) = payload {
info!("Got network info for device {pubkey}");
Ok(Json(response))
Expand Down
2 changes: 1 addition & 1 deletion src/handlers/mobile_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub(crate) async fn register_mobile_auth(
core_request::Payload::RegisterMobileAuth(send_data),
device_info,
)?;
let payload = get_core_response(rx).await?;
let payload = get_core_response(rx, None).await?;
if let core_response::Payload::Empty(()) = payload {
info!("Registered mobile device for auth");
Ok(())
Expand Down
14 changes: 10 additions & 4 deletions src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::time::Duration;
use axum::{extract::FromRequestParts, http::request::Parts};
use axum_client_ip::{InsecureClientIp, LeftmostXForwardedFor};
use axum_extra::{headers::UserAgent, TypedHeader};
use tokio::{sync::oneshot::Receiver, time::timeout};
use tokio::{sync::oneshot::Receiver, time};
use tonic::Code;

use super::proto::DeviceInfo;
Expand Down Expand Up @@ -69,9 +69,12 @@ where
/// Helper which awaits core response
///
/// Waits for core response with a given timeout and returns the response payload.
pub(crate) async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload, ApiError> {
pub(crate) async fn get_core_response(
rx: Receiver<Payload>,
timeout: Option<Duration>,
) -> Result<Payload, ApiError> {
debug!("Fetching core response.");
if let Ok(core_response) = timeout(CORE_RESPONSE_TIMEOUT, rx).await {
if let Ok(core_response) = time::timeout(timeout.unwrap_or(CORE_RESPONSE_TIMEOUT), rx).await {
debug!("Got gRPC response from Defguard Core");
if let Ok(Payload::CoreError(core_error)) = core_response {
if core_error.status_code == Code::FailedPrecondition as i32
Expand All @@ -92,7 +95,10 @@ pub(crate) async fn get_core_response(rx: Receiver<Payload>) -> Result<Payload,
core_response
.map_err(|err| ApiError::Unexpected(format!("Failed to receive core response: {err}")))
} else {
error!("Did not receive response from Core within {CORE_RESPONSE_TIMEOUT:?}");
error!(
"Did not receive response from Core within {:?}",
timeout.unwrap_or(CORE_RESPONSE_TIMEOUT)
);
Err(ApiError::CoreTimeout)
}
}
Expand Down
Loading