From ab10a6598b6a2c52fb1b01cc6806b0af2ec11ee3 Mon Sep 17 00:00:00 2001 From: pasta Date: Wed, 10 Dec 2025 21:25:34 -0600 Subject: [PATCH 1/9] feat(dash-spv): Add BIP324 v2 encrypted P2P transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement BIP324 encrypted transport for dash-spv peer connections: - Add Transport trait abstracting v1/v2 protocol differences - Implement V1Transport (extracted from existing peer.rs code) - Implement V2Transport with ChaCha20-Poly1305 encryption - Add V2 handshake with ElligatorSwift key exchange - Add Dash-specific short message IDs (128-167) - Add TransportPreference config (V2Preferred, V2Only, V1Only) - Default to V2Preferred with automatic fallback to v1 Key implementation details: - Cache decrypted packet length to prevent cipher state desync - Support all Dash short IDs matching Dash Core's V2_DASH_IDS - Proper handling of bip324 crate's packet type header byte 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash-spv/Cargo.toml | 3 + dash-spv/src/client/config.rs | 17 + dash-spv/src/error.rs | 13 + dash-spv/src/network/manager.rs | 10 +- dash-spv/src/network/mod.rs | 2 + dash-spv/src/network/peer.rs | 600 ++++-------- dash-spv/src/network/transport/message_ids.rs | 296 ++++++ dash-spv/src/network/transport/mod.rs | 79 ++ dash-spv/src/network/transport/v1.rs | 478 ++++++++++ dash-spv/src/network/transport/v2.rs | 885 ++++++++++++++++++ .../src/network/transport/v2_handshake.rs | 413 ++++++++ dash-spv/tests/handshake_test.rs | 7 +- 12 files changed, 2386 insertions(+), 417 deletions(-) create mode 100644 dash-spv/src/network/transport/message_ids.rs create mode 100644 dash-spv/src/network/transport/mod.rs create mode 100644 dash-spv/src/network/transport/v1.rs create mode 100644 dash-spv/src/network/transport/v2.rs create mode 100644 dash-spv/src/network/transport/v2_handshake.rs diff --git a/dash-spv/Cargo.toml b/dash-spv/Cargo.toml index 7af44f2fe..848598d93 100644 --- a/dash-spv/Cargo.toml +++ b/dash-spv/Cargo.toml @@ -18,6 +18,9 @@ key-wallet-manager = { path = "../key-wallet-manager" } # BLS signatures blsful = { git = "https://github.com/dashpay/agora-blsful", rev = "0c34a7a488a0bd1c9a9a2196e793b303ad35c900" } +# BIP324 v2 P2P encrypted transport +bip324 = { git = "https://github.com/rust-bitcoin/bip324", rev = "8c469432", features = ["std", "tokio"] } + # CLI clap = { version = "4.0", features = ["derive", "env"] } diff --git a/dash-spv/src/client/config.rs b/dash-spv/src/client/config.rs index 11ccfb568..4ed27cc6f 100644 --- a/dash-spv/src/client/config.rs +++ b/dash-spv/src/client/config.rs @@ -7,6 +7,7 @@ use std::time::Duration; use dashcore::Network; // Serialization removed due to complex Address types +use crate::network::transport::TransportPreference; use crate::types::ValidationMode; /// Strategy for handling mempool (unconfirmed) transactions. @@ -152,6 +153,10 @@ pub struct ClientConfig { /// Timeout for QRInfo requests (default: 30 seconds). pub qr_info_timeout: Duration, + + /// Transport preference for peer connections (V1, V2, or V2 with fallback). + /// Default is V2Preferred: try V2 encrypted transport first, fall back to V1. + pub transport_preference: TransportPreference, } impl Default for ClientConfig { @@ -201,6 +206,8 @@ impl Default for ClientConfig { // QRInfo defaults (simplified per plan) qr_info_extra_share: false, // Matches DMLviewer.patch default qr_info_timeout: Duration::from_secs(30), + // Transport preference (BIP324 v2 encrypted by default with v1 fallback) + transport_preference: TransportPreference::default(), } } } @@ -342,6 +349,16 @@ impl ClientConfig { self } + /// Set transport preference for peer connections. + /// + /// - `V2Preferred` (default): Try BIP324 v2 encrypted transport first, fall back to v1 + /// - `V2Only`: Require BIP324 v2 encrypted transport, fail if peer doesn't support it + /// - `V1Only`: Use traditional unencrypted v1 transport only + pub fn with_transport_preference(mut self, preference: TransportPreference) -> Self { + self.transport_preference = preference; + self + } + /// Validate the configuration. pub fn validate(&self) -> Result<(), String> { // Note: Empty peers list is now valid - DNS discovery will be used automatically diff --git a/dash-spv/src/error.rs b/dash-spv/src/error.rs index 5e411449d..7f745fbde 100644 --- a/dash-spv/src/error.rs +++ b/dash-spv/src/error.rs @@ -104,6 +104,19 @@ pub enum NetworkError { #[error("System time error: {0}")] SystemTime(String), + + // BIP324 V2 transport errors + #[error("V2 handshake failed: {0}")] + V2HandshakeFailed(String), + + #[error("V2 decryption failed: {0}")] + V2DecryptionFailed(String), + + #[error("V2 encryption failed: {0}")] + V2EncryptionFailed(String), + + #[error("V2 not supported by peer")] + V2NotSupported, } /// Storage-related errors. diff --git a/dash-spv/src/network/manager.rs b/dash-spv/src/network/manager.rs index c0dc87ff2..cd76bc4c6 100644 --- a/dash-spv/src/network/manager.rs +++ b/dash-spv/src/network/manager.rs @@ -27,6 +27,7 @@ use crate::network::pool::PeerPool; use crate::network::reputation::{ misbehavior_scores, positive_scores, PeerReputationManager, ReputationAware, }; +use crate::network::transport::TransportPreference; use crate::network::{HandshakeManager, NetworkManager, Peer}; use crate::types::PeerInfo; @@ -71,6 +72,8 @@ pub struct PeerNetworkManager { exclusive_mode: bool, /// Cached count of currently connected peers for fast, non-blocking queries connected_peer_count: Arc, + /// Transport preference for peer connections (V1, V2, or V2 with fallback) + transport_preference: TransportPreference, } impl PeerNetworkManager { @@ -124,6 +127,7 @@ impl PeerNetworkManager { user_agent: config.user_agent.clone(), exclusive_mode, connected_peer_count: Arc::new(AtomicUsize::new(0)), + transport_preference: config.transport_preference, }) } @@ -210,13 +214,16 @@ impl PeerNetworkManager { let mempool_strategy = self.mempool_strategy; let user_agent = self.user_agent.clone(); let connected_peer_count = self.connected_peer_count.clone(); + let transport_preference = self.transport_preference; // Spawn connection task let mut tasks = self.tasks.lock().await; tasks.spawn(async move { log::debug!("Attempting to connect to {}", addr); - match Peer::connect(addr, CONNECTION_TIMEOUT.as_secs(), network).await { + match Peer::connect(addr, CONNECTION_TIMEOUT.as_secs(), network, transport_preference) + .await + { Ok(mut peer) => { // Perform handshake let mut handshake_manager = @@ -1069,6 +1076,7 @@ impl Clone for PeerNetworkManager { user_agent: self.user_agent.clone(), exclusive_mode: self.exclusive_mode, connected_peer_count: self.connected_peer_count.clone(), + transport_preference: self.transport_preference, } } } diff --git a/dash-spv/src/network/mod.rs b/dash-spv/src/network/mod.rs index 89e8bde78..2b358c8e9 100644 --- a/dash-spv/src/network/mod.rs +++ b/dash-spv/src/network/mod.rs @@ -9,6 +9,7 @@ pub mod peer; pub mod persist; pub mod pool; pub mod reputation; +pub mod transport; #[cfg(test)] mod tests; @@ -25,6 +26,7 @@ use dashcore::BlockHash; pub use handshake::{HandshakeManager, HandshakeState}; pub use manager::PeerNetworkManager; pub use peer::Peer; +pub use transport::{Transport, TransportPreference, V1Transport}; /// Network manager trait for abstracting network operations. #[async_trait] diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index 1147a663b..2c0100c07 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -2,36 +2,26 @@ use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; use std::time::{Duration, SystemTime}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; -use tokio::sync::Mutex; -use dashcore::consensus::{encode, Decodable}; -use dashcore::network::message::{NetworkMessage, RawNetworkMessage}; +use dashcore::network::message::NetworkMessage; use dashcore::Network; use crate::error::{NetworkError, NetworkResult}; use crate::network::constants::PING_INTERVAL; +use crate::network::transport::{ + Transport, TransportPreference, V1Transport, V2HandshakeManager, V2HandshakeResult, V2Transport, +}; use crate::types::PeerInfo; -/// Internal state for the TCP connection -struct ConnectionState { - stream: TcpStream, - // Stateful message framing buffer to ensure full frames before decoding - framing_buffer: Vec, -} - /// Dash P2P peer pub struct Peer { address: SocketAddr, - // Use a single mutex to protect both the write stream and read buffer - // This ensures no concurrent access to the underlying socket - state: Option>>, + /// The transport layer (V1 or V2) + transport: Option>, timeout: Duration, connected_at: Option, - bytes_sent: u64, network: Network, // Ping/pong state last_ping_sent: Option, @@ -47,6 +37,8 @@ pub struct Peer { sent_sendheaders2: bool, // Basic telemetry for resync events consecutive_resyncs: u32, + // Transport protocol version used (1 or 2) + transport_version: u8, } impl Peer { @@ -54,14 +46,19 @@ impl Peer { pub fn address(&self) -> SocketAddr { self.address } - /// Create a new peer. + + /// Get the transport protocol version (1 = unencrypted, 2 = BIP324 encrypted). + pub fn transport_version(&self) -> u8 { + self.transport_version + } + + /// Create a new peer (not connected). pub fn new(address: SocketAddr, timeout: Duration, network: Network) -> Self { Self { address, - state: None, + transport: None, timeout, connected_at: None, - bytes_sent: 0, network, last_ping_sent: None, last_pong_received: None, @@ -74,41 +71,64 @@ impl Peer { prefers_headers2: false, sent_sendheaders2: false, consecutive_resyncs: 0, + transport_version: 1, } } - /// Connect to a peer and return a connected instance. + /// Connect to a peer with the specified transport preference. + /// + /// # Arguments + /// * `address` - The peer's socket address + /// * `timeout_secs` - Connection timeout in seconds + /// * `network` - The Dash network (mainnet, testnet, etc.) + /// * `transport_pref` - V1Only, V2Only, or V2Preferred (default) + /// + /// # Returns + /// A connected Peer instance using the appropriate transport. pub async fn connect( address: SocketAddr, timeout_secs: u64, network: Network, + transport_pref: TransportPreference, ) -> NetworkResult { let timeout = Duration::from_secs(timeout_secs); - let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) - .await - .map_err(|_| { - NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) - })? - .map_err(|e| { - NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) - })?; - - stream.set_nodelay(true).map_err(|e| { - NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) - })?; - - let state = ConnectionState { - stream, - framing_buffer: Vec::new(), + let (transport, transport_version): (Box, u8) = match transport_pref { + TransportPreference::V1Only => { + tracing::info!("Connecting to {} using V1 transport (unencrypted)", address); + let transport = Self::establish_v1_transport(address, timeout, network).await?; + (Box::new(transport), 1) + } + TransportPreference::V2Only => { + tracing::info!( + "Connecting to {} using V2 transport (BIP324 encrypted, no fallback)", + address + ); + let transport = Self::establish_v2_transport(address, timeout, network).await?; + (Box::new(transport), 2) + } + TransportPreference::V2Preferred => { + tracing::info!( + "Connecting to {} using V2 transport (BIP324 encrypted, with V1 fallback)", + address + ); + match Self::try_v2_with_fallback(address, timeout, network).await? { + (transport, version) => (transport, version), + } + } }; + tracing::info!( + "Successfully connected to {} using V{} transport", + address, + transport_version + ); + Ok(Self { address, - state: Some(Arc::new(Mutex::new(state))), + transport: Some(transport), timeout, connected_at: Some(SystemTime::now()), - bytes_sent: 0, network, last_ping_sent: None, last_pong_received: None, @@ -121,48 +141,142 @@ impl Peer { prefers_headers2: false, sent_sendheaders2: false, consecutive_resyncs: 0, + transport_version, }) } - /// Connect to the peer (instance method for compatibility). - pub async fn connect_instance(&mut self) -> NetworkResult<()> { - let stream = tokio::time::timeout(self.timeout, TcpStream::connect(self.address)) + /// Establish a V1 (unencrypted) transport connection. + async fn establish_v1_transport( + address: SocketAddr, + timeout: Duration, + network: Network, + ) -> NetworkResult { + let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) + .await + .map_err(|_| { + NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) + })? + .map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) + })?; + + stream.set_nodelay(true).map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) + })?; + + Ok(V1Transport::new(stream, network, address)) + } + + /// Establish a V2 (BIP324 encrypted) transport connection. + /// Fails if peer doesn't support V2. + async fn establish_v2_transport( + address: SocketAddr, + timeout: Duration, + network: Network, + ) -> NetworkResult { + let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) + .await + .map_err(|_| { + NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) + })? + .map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) + })?; + + stream.set_nodelay(true).map_err(|e| { + NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) + })?; + + let handshake_manager = V2HandshakeManager::new_initiator(network, address); + match handshake_manager.perform_handshake(stream).await? { + V2HandshakeResult::Success(session) => { + Ok(V2Transport::new(session.stream, session.cipher, session.session_id, address)) + } + V2HandshakeResult::FallbackToV1 => Err(NetworkError::V2NotSupported), + } + } + + /// Try V2 transport first, fall back to V1 if peer doesn't support V2. + async fn try_v2_with_fallback( + address: SocketAddr, + timeout: Duration, + network: Network, + ) -> NetworkResult<(Box, u8)> { + // First, try to establish TCP connection + let stream = tokio::time::timeout(timeout, TcpStream::connect(address)) .await .map_err(|_| { - NetworkError::ConnectionFailed(format!("Connection to {} timed out", self.address)) + NetworkError::ConnectionFailed(format!("Connection to {} timed out", address)) })? .map_err(|e| { - NetworkError::ConnectionFailed(format!( - "Failed to connect to {}: {}", - self.address, e - )) + NetworkError::ConnectionFailed(format!("Failed to connect to {}: {}", address, e)) })?; - // Disable Nagle's algorithm for lower latency stream.set_nodelay(true).map_err(|e| { NetworkError::ConnectionFailed(format!("Failed to set TCP_NODELAY: {}", e)) })?; - let state = ConnectionState { - stream, - framing_buffer: Vec::new(), + // Try V2 handshake + let handshake_manager = V2HandshakeManager::new_initiator(network, address); + match handshake_manager.perform_handshake(stream).await { + Ok(V2HandshakeResult::Success(session)) => { + tracing::info!("V2 handshake succeeded with {}", address); + let transport = + V2Transport::new(session.stream, session.cipher, session.session_id, address); + Ok((Box::new(transport), 2)) + } + Ok(V2HandshakeResult::FallbackToV1) => { + tracing::info!( + "V2 handshake detected V1-only peer {}, reconnecting with V1 transport", + address + ); + // Need to reconnect since the stream was consumed + let transport = Self::establish_v1_transport(address, timeout, network).await?; + Ok((Box::new(transport), 1)) + } + Err(e) => { + tracing::warn!("V2 handshake failed with {}: {}, falling back to V1", address, e); + // Try V1 as fallback + let transport = Self::establish_v1_transport(address, timeout, network).await?; + Ok((Box::new(transport), 1)) + } + } + } + + /// Connect to the peer (instance method for compatibility). + pub async fn connect_instance( + &mut self, + transport_pref: TransportPreference, + ) -> NetworkResult<()> { + let (transport, transport_version): (Box, u8) = match transport_pref { + TransportPreference::V1Only => { + let t = + Self::establish_v1_transport(self.address, self.timeout, self.network).await?; + (Box::new(t), 1) + } + TransportPreference::V2Only => { + let t = + Self::establish_v2_transport(self.address, self.timeout, self.network).await?; + (Box::new(t), 2) + } + TransportPreference::V2Preferred => { + Self::try_v2_with_fallback(self.address, self.timeout, self.network).await? + } }; - self.state = Some(Arc::new(Mutex::new(state))); + self.transport = Some(transport); + self.transport_version = transport_version; self.connected_at = Some(SystemTime::now()); - tracing::info!("Connected to peer {}", self.address); + tracing::info!("Connected to peer {} using V{} transport", self.address, transport_version); Ok(()) } /// Disconnect from the peer. pub async fn disconnect(&mut self) -> NetworkResult<()> { - if let Some(state_arc) = self.state.take() { - if let Ok(state_mutex) = Arc::try_unwrap(state_arc) { - let mut state = state_mutex.into_inner(); - let _ = state.stream.shutdown().await; - } + if let Some(mut transport) = self.transport.take() { + transport.shutdown().await?; } self.connected_at = None; @@ -272,372 +386,28 @@ impl Peer { ); } - /// Helper function to read some bytes into the framing buffer. - async fn read_some(state: &mut ConnectionState) -> std::io::Result { - let mut tmp = [0u8; 8192]; - match state.stream.read(&mut tmp).await { - Ok(0) => Ok(0), - Ok(n) => { - state.framing_buffer.extend_from_slice(&tmp[..n]); - Ok(n) - } - Err(e) => Err(e), - } - } - /// Send a message to the peer. pub async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> { - let state_arc = self - .state - .as_ref() + let transport = self + .transport + .as_mut() .ok_or_else(|| NetworkError::ConnectionFailed("Not connected".to_string()))?; - let raw_message = RawNetworkMessage { - magic: self.network.magic(), - payload: message, - }; - - let serialized = encode::serialize(&raw_message); - - // Log details for debugging headers2 issues - if matches!( - raw_message.payload, - NetworkMessage::GetHeaders2(_) | NetworkMessage::GetHeaders(_) - ) { - let msg_type = match raw_message.payload { - NetworkMessage::GetHeaders2(_) => "GetHeaders2", - NetworkMessage::GetHeaders(_) => "GetHeaders", - _ => "Unknown", - }; - tracing::debug!( - "Sending {} raw bytes (len={}): {:02x?}", - msg_type, - serialized.len(), - &serialized[..std::cmp::min(100, serialized.len())] - ); - } - - // Lock the state for the entire write operation - let mut state = state_arc.lock().await; - - // Write with error handling - match state.stream.write_all(&serialized).await { - Ok(_) => { - // Flush to ensure data is sent immediately - if let Err(e) = state.stream.flush().await { - tracing::warn!("Failed to flush socket {}: {}", self.address, e); - } - self.bytes_sent += serialized.len() as u64; - tracing::debug!("Sent message to {}: {:?}", self.address, raw_message.payload); - Ok(()) - } - Err(e) => { - tracing::warn!("Disconnecting {} due to write error: {}", self.address, e); - // Drop the lock before clearing connection state - drop(state); - // Clear connection state on write error - self.state = None; - self.connected_at = None; - Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e))) - } - } + transport.send_message(message).await } /// Receive a message from the peer. pub async fn receive_message(&mut self) -> NetworkResult> { - // First check if we have a state - let state_arc = self - .state - .as_ref() + let transport = self + .transport + .as_mut() .ok_or_else(|| NetworkError::ConnectionFailed("Not connected".to_string()))?; - // Lock the state for the entire read operation - // This ensures no concurrent access to the socket - let mut state = state_arc.lock().await; - - // Buffered, stateful framing - const HEADER_LEN: usize = 24; // magic[4] + cmd[12] + length[4] + checksum[4] - const MAX_RESYNC_STEPS_PER_CALL: usize = 64; - - let result = async { - let magic_bytes = self.network.magic().to_le_bytes(); - let mut resync_steps = 0usize; - - loop { - // Ensure header availability - if state.framing_buffer.len() < HEADER_LEN { - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(ref e) - if e.kind() == std::io::ErrorKind::ConnectionAborted - || e.kind() == std::io::ErrorKind::ConnectionReset => - { - tracing::info!("Peer {} connection reset/aborted", self.address); - return Err(NetworkError::PeerDisconnected); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - } - - // Align to magic - if state.framing_buffer.len() >= 4 && state.framing_buffer[..4] != magic_bytes { - if let Some(pos) = - state.framing_buffer.windows(4).position(|w| w == magic_bytes) - { - if pos > 0 { - tracing::warn!( - "{}: stream desync: skipping {} stray bytes before magic", - self.address, - pos - ); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - state.framing_buffer.drain(0..pos); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - continue; - } - } else { - // Keep last 3 bytes of potential magic prefix - if state.framing_buffer.len() > 3 { - let dropped = state.framing_buffer.len() - 3; - tracing::warn!( - "{}: stream desync: dropping {} bytes (no magic found)", - self.address, - dropped - ); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - state.framing_buffer.drain(0..dropped); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - } - // Need more data - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - continue; - } - } - - // Ensure full header - if state.framing_buffer.len() < HEADER_LEN { - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - continue; - } - - // Parse header fields - let length_le = u32::from_le_bytes([ - state.framing_buffer[16], - state.framing_buffer[17], - state.framing_buffer[18], - state.framing_buffer[19], - ]) as usize; - let header_checksum = [ - state.framing_buffer[20], - state.framing_buffer[21], - state.framing_buffer[22], - state.framing_buffer[23], - ]; - // Validate announced length to prevent unbounded accumulation or overflow - if length_le > dashcore::network::message::MAX_MSG_SIZE { - return Err(NetworkError::ProtocolError(format!( - "Declared payload length {} exceeds MAX_MSG_SIZE {}", - length_le, - dashcore::network::message::MAX_MSG_SIZE - ))); - } - let total_len = match HEADER_LEN.checked_add(length_le) { - Some(v) => v, - None => { - return Err(NetworkError::ProtocolError( - "Message length overflow".to_string(), - )); - } - }; - - // Ensure full frame available - if state.framing_buffer.len() < total_len { - match Self::read_some(&mut state).await { - Ok(0) => { - tracing::info!("Peer {} closed connection (EOF)", self.address); - return Err(NetworkError::PeerDisconnected); - } - Ok(_) => {} - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - return Ok(None); - } - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { - return Ok(None); - } - Err(e) => { - return Err(NetworkError::ConnectionFailed(format!( - "Read failed: {}", - e - ))); - } - } - continue; - } - - // Verify checksum - let payload_slice = &state.framing_buffer[HEADER_LEN..total_len]; - let expected = { - let checksum = ::hash( - payload_slice, - ); - [checksum[0], checksum[1], checksum[2], checksum[3]] - }; - if expected != header_checksum { - tracing::warn!( - "Skipping message with invalid checksum from {}: expected {:02x?}, actual {:02x?}", - self.address, - expected, - header_checksum - ); - if header_checksum == [0, 0, 0, 0] { - tracing::warn!( - "All-zeros checksum detected from {}, likely corrupted stream - resyncing", - self.address - ); - } - // Resync by dropping a byte and retrying - state.framing_buffer.drain(0..1); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - continue; - } - - // Decode full RawNetworkMessage from the frame using existing decoder - let mut cursor = std::io::Cursor::new(&state.framing_buffer[..total_len]); - match RawNetworkMessage::consensus_decode(&mut cursor) { - Ok(raw_message) => { - // Consume bytes - state.framing_buffer.drain(0..total_len); - self.consecutive_resyncs = 0; - - // Validate magic matches our network - if raw_message.magic != self.network.magic() { - tracing::warn!( - "Received message with wrong magic bytes: expected {:#x}, got {:#x}", - self.network.magic(), - raw_message.magic - ); - return Err(NetworkError::ProtocolError(format!( - "Wrong magic bytes: expected {:#x}, got {:#x}", - self.network.magic(), - raw_message.magic - ))); - } - - tracing::trace!( - "Successfully decoded message from {}: {:?}", - self.address, - raw_message.payload.cmd() - ); - - if raw_message.payload.cmd() == "headers2" { - tracing::info!("🎉 Received Headers2 message from {}!", self.address); - } - - if let NetworkMessage::Block(ref block) = raw_message.payload { - let block_hash = block.block_hash(); - tracing::info!( - "Successfully decoded block {} from {}", - block_hash, - self.address - ); - } - - if let NetworkMessage::Headers2(ref headers2) = raw_message.payload { - tracing::info!( - "Successfully decoded Headers2 message from {} with {} compressed headers", - self.address, - headers2.headers.len() - ); - } - - return Ok(Some(raw_message.payload)); - } - Err(e) => { - tracing::warn!( - "{}: decode error after framing ({}), attempting resync", - self.address, - e - ); - state.framing_buffer.drain(0..1); - self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); - resync_steps += 1; - if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { - return Ok(None); - } - continue; - } - } - } - } - .await; + let result = transport.receive_message().await; - // Drop the lock before disconnecting - drop(state); - - // Handle disconnection if needed + // Handle disconnection if let Err(NetworkError::PeerDisconnected) = &result { - self.state = None; + self.transport = None; self.connected_at = None; } @@ -646,7 +416,7 @@ impl Peer { /// Check if the connection is active. pub fn is_connected(&self) -> bool { - self.state.is_some() + self.transport.as_ref().map(|t| t.is_connected()).unwrap_or(false) } /// Check if connection appears healthy (not just connected). @@ -701,7 +471,11 @@ impl Peer { /// Get connection statistics. pub fn stats(&self) -> (u64, u64) { - (self.bytes_sent, 0) // TODO: Track bytes received + if let Some(transport) = &self.transport { + (transport.bytes_sent(), transport.bytes_received()) + } else { + (0, 0) + } } /// Send a ping message with a random nonce. diff --git a/dash-spv/src/network/transport/message_ids.rs b/dash-spv/src/network/transport/message_ids.rs new file mode 100644 index 000000000..e5691f415 --- /dev/null +++ b/dash-spv/src/network/transport/message_ids.rs @@ -0,0 +1,296 @@ +//! BIP324 short message IDs for Dash. +//! +//! BIP324 uses 1-byte short IDs for common messages to reduce bandwidth. +//! Less common messages use extended format: 0x00 + 12-byte ASCII command. +//! +//! Dash extends BIP324 with its own message IDs in the 128-255 range: +//! - IDs 0-32: Standard BIP324 (Bitcoin) messages +//! - IDs 128-167: Dash-specific messages +//! +//! ## Design Notes +//! +//! There's intentional asymmetry between the two main functions: +//! - `short_id_to_command`: Handles ALL short IDs (for receiving messages) +//! - `network_message_to_short_id`: Only handles NetworkMessage variants that exist +//! +//! This means we can decode incoming messages with short IDs even if dashcore +//! doesn't have a dedicated NetworkMessage variant for them (they'll be decoded +//! as Unknown messages via the extended format fallback in decode_by_command). + +use dashcore::network::message::NetworkMessage; + +/// Extended message marker (12-byte ASCII command follows). +pub const MSG_ID_EXTENDED: u8 = 0; + +// ============================================================================= +// Standard BIP324 short message IDs (1-28) +// Matches Dash Core's V2_BITCOIN_IDS array +// ============================================================================= +pub const MSG_ID_ADDR: u8 = 1; +pub const MSG_ID_BLOCK: u8 = 2; +pub const MSG_ID_BLOCKTXN: u8 = 3; +pub const MSG_ID_CMPCTBLOCK: u8 = 4; +// ID 5 is reserved for FEEFILTER but not implemented in Dash +pub const MSG_ID_FILTERADD: u8 = 6; +pub const MSG_ID_FILTERCLEAR: u8 = 7; +pub const MSG_ID_FILTERLOAD: u8 = 8; +pub const MSG_ID_GETBLOCKS: u8 = 9; +pub const MSG_ID_GETBLOCKTXN: u8 = 10; +pub const MSG_ID_GETDATA: u8 = 11; +pub const MSG_ID_GETHEADERS: u8 = 12; +pub const MSG_ID_HEADERS: u8 = 13; +pub const MSG_ID_INV: u8 = 14; +pub const MSG_ID_MEMPOOL: u8 = 15; +pub const MSG_ID_MERKLEBLOCK: u8 = 16; +pub const MSG_ID_NOTFOUND: u8 = 17; +pub const MSG_ID_PING: u8 = 18; +pub const MSG_ID_PONG: u8 = 19; +pub const MSG_ID_SENDCMPCT: u8 = 20; +pub const MSG_ID_TX: u8 = 21; +pub const MSG_ID_GETCFILTERS: u8 = 22; +pub const MSG_ID_CFILTER: u8 = 23; +pub const MSG_ID_GETCFHEADERS: u8 = 24; +pub const MSG_ID_CFHEADERS: u8 = 25; +pub const MSG_ID_GETCFCHECKPT: u8 = 26; +pub const MSG_ID_CFCHECKPT: u8 = 27; +pub const MSG_ID_ADDRV2: u8 = 28; +// IDs 29-32 are reserved but unimplemented in BIP324 + +// ============================================================================= +// Dash-specific short message IDs (128-167) +// Matches Dash Core's V2_DASH_IDS array +// ============================================================================= +pub const MSG_ID_SPORK: u8 = 128; +pub const MSG_ID_GETSPORKS: u8 = 129; +pub const MSG_ID_SENDDSQUEUE: u8 = 130; +pub const MSG_ID_DSACCEPT: u8 = 131; +pub const MSG_ID_DSVIN: u8 = 132; +pub const MSG_ID_DSFINALTX: u8 = 133; +pub const MSG_ID_DSSIGNFINALTX: u8 = 134; +pub const MSG_ID_DSCOMPLETE: u8 = 135; +pub const MSG_ID_DSSTATUSUPDATE: u8 = 136; +pub const MSG_ID_DSTX: u8 = 137; +pub const MSG_ID_DSQUEUE: u8 = 138; +pub const MSG_ID_SYNCSTATUSCOUNT: u8 = 139; +pub const MSG_ID_MNGOVERNANCESYNC: u8 = 140; +pub const MSG_ID_MNGOVERNANCEOBJECT: u8 = 141; +pub const MSG_ID_MNGOVERNANCEOBJECTVOTE: u8 = 142; +pub const MSG_ID_GETMNLISTDIFF: u8 = 143; +pub const MSG_ID_MNLISTDIFF: u8 = 144; +pub const MSG_ID_QSENDRECSIGS: u8 = 145; +pub const MSG_ID_QFCOMMITMENT: u8 = 146; +pub const MSG_ID_QCONTRIB: u8 = 147; +pub const MSG_ID_QCOMPLAINT: u8 = 148; +pub const MSG_ID_QJUSTIFICATION: u8 = 149; +pub const MSG_ID_QPCOMMITMENT: u8 = 150; +pub const MSG_ID_QWATCH: u8 = 151; +pub const MSG_ID_QSIGSESANN: u8 = 152; +pub const MSG_ID_QSIGSHARESINV: u8 = 153; +pub const MSG_ID_QGETSIGSHARES: u8 = 154; +pub const MSG_ID_QBSIGSHARES: u8 = 155; +pub const MSG_ID_QSIGREC: u8 = 156; +pub const MSG_ID_QSIGSHARE: u8 = 157; +pub const MSG_ID_QGETDATA: u8 = 158; +pub const MSG_ID_QDATA: u8 = 159; +pub const MSG_ID_CLSIG: u8 = 160; +pub const MSG_ID_ISDLOCK: u8 = 161; +pub const MSG_ID_MNAUTH: u8 = 162; +pub const MSG_ID_GETHEADERS2: u8 = 163; +pub const MSG_ID_SENDHEADERS2: u8 = 164; +pub const MSG_ID_HEADERS2: u8 = 165; +pub const MSG_ID_GETQUORUMROTATIONINFO: u8 = 166; +pub const MSG_ID_QUORUMROTATIONINFO: u8 = 167; + +/// Get the short message ID for a NetworkMessage, if one exists. +/// +/// Returns `Some(id)` for common messages that have short IDs, +/// or `None` for messages that require extended format. +pub fn network_message_to_short_id(msg: &NetworkMessage) -> Option { + match msg { + // Standard BIP324 messages + NetworkMessage::Addr(_) => Some(MSG_ID_ADDR), + NetworkMessage::Block(_) => Some(MSG_ID_BLOCK), + NetworkMessage::BlockTxn(_) => Some(MSG_ID_BLOCKTXN), + NetworkMessage::CmpctBlock(_) => Some(MSG_ID_CMPCTBLOCK), + // Note: FeeFilter is ID 5 in BIP324 but not implemented in Dash + NetworkMessage::FilterAdd(_) => Some(MSG_ID_FILTERADD), + NetworkMessage::FilterClear => Some(MSG_ID_FILTERCLEAR), + NetworkMessage::FilterLoad(_) => Some(MSG_ID_FILTERLOAD), + NetworkMessage::GetBlocks(_) => Some(MSG_ID_GETBLOCKS), + NetworkMessage::GetBlockTxn(_) => Some(MSG_ID_GETBLOCKTXN), + NetworkMessage::GetData(_) => Some(MSG_ID_GETDATA), + NetworkMessage::GetHeaders(_) => Some(MSG_ID_GETHEADERS), + NetworkMessage::Headers(_) => Some(MSG_ID_HEADERS), + NetworkMessage::Inv(_) => Some(MSG_ID_INV), + NetworkMessage::MemPool => Some(MSG_ID_MEMPOOL), + NetworkMessage::MerkleBlock(_) => Some(MSG_ID_MERKLEBLOCK), + NetworkMessage::NotFound(_) => Some(MSG_ID_NOTFOUND), + NetworkMessage::Ping(_) => Some(MSG_ID_PING), + NetworkMessage::Pong(_) => Some(MSG_ID_PONG), + NetworkMessage::SendCmpct(_) => Some(MSG_ID_SENDCMPCT), + NetworkMessage::Tx(_) => Some(MSG_ID_TX), + NetworkMessage::GetCFilters(_) => Some(MSG_ID_GETCFILTERS), + NetworkMessage::CFilter(_) => Some(MSG_ID_CFILTER), + NetworkMessage::GetCFHeaders(_) => Some(MSG_ID_GETCFHEADERS), + NetworkMessage::CFHeaders(_) => Some(MSG_ID_CFHEADERS), + NetworkMessage::GetCFCheckpt(_) => Some(MSG_ID_GETCFCHECKPT), + NetworkMessage::CFCheckpt(_) => Some(MSG_ID_CFCHECKPT), + NetworkMessage::AddrV2(_) => Some(MSG_ID_ADDRV2), + + // Dash-specific messages (only variants that exist in dashcore) + NetworkMessage::SendDsq(_) => Some(MSG_ID_SENDDSQUEUE), + NetworkMessage::GetMnListD(_) => Some(MSG_ID_GETMNLISTDIFF), + NetworkMessage::MnListDiff(_) => Some(MSG_ID_MNLISTDIFF), + NetworkMessage::CLSig(_) => Some(MSG_ID_CLSIG), + NetworkMessage::ISLock(_) => Some(MSG_ID_ISDLOCK), + NetworkMessage::GetHeaders2(_) => Some(MSG_ID_GETHEADERS2), + NetworkMessage::SendHeaders2 => Some(MSG_ID_SENDHEADERS2), + NetworkMessage::Headers2(_) => Some(MSG_ID_HEADERS2), + NetworkMessage::GetQRInfo(_) => Some(MSG_ID_GETQUORUMROTATIONINFO), + NetworkMessage::QRInfo(_) => Some(MSG_ID_QUORUMROTATIONINFO), + + // All other messages use extended format + _ => None, + } +} + +/// Get the command string for a short message ID. +/// +/// Returns `Some(command)` for valid short IDs, +/// or `None` for unknown IDs. +pub fn short_id_to_command(id: u8) -> Option<&'static str> { + match id { + // Standard BIP324 messages + MSG_ID_ADDR => Some("addr"), + MSG_ID_BLOCK => Some("block"), + MSG_ID_BLOCKTXN => Some("blocktxn"), + MSG_ID_CMPCTBLOCK => Some("cmpctblock"), + MSG_ID_FILTERADD => Some("filteradd"), + MSG_ID_FILTERCLEAR => Some("filterclear"), + MSG_ID_FILTERLOAD => Some("filterload"), + MSG_ID_GETBLOCKS => Some("getblocks"), + MSG_ID_GETBLOCKTXN => Some("getblocktxn"), + MSG_ID_GETDATA => Some("getdata"), + MSG_ID_GETHEADERS => Some("getheaders"), + MSG_ID_HEADERS => Some("headers"), + MSG_ID_INV => Some("inv"), + MSG_ID_MEMPOOL => Some("mempool"), + MSG_ID_MERKLEBLOCK => Some("merkleblock"), + MSG_ID_NOTFOUND => Some("notfound"), + MSG_ID_PING => Some("ping"), + MSG_ID_PONG => Some("pong"), + MSG_ID_SENDCMPCT => Some("sendcmpct"), + MSG_ID_TX => Some("tx"), + MSG_ID_GETCFILTERS => Some("getcfilters"), + MSG_ID_CFILTER => Some("cfilter"), + MSG_ID_GETCFHEADERS => Some("getcfheaders"), + MSG_ID_CFHEADERS => Some("cfheaders"), + MSG_ID_GETCFCHECKPT => Some("getcfcheckpt"), + MSG_ID_CFCHECKPT => Some("cfcheckpt"), + MSG_ID_ADDRV2 => Some("addrv2"), + + // Dash-specific messages + MSG_ID_SPORK => Some("spork"), + MSG_ID_GETSPORKS => Some("getsporks"), + MSG_ID_SENDDSQUEUE => Some("senddsq"), + MSG_ID_DSACCEPT => Some("dsa"), + MSG_ID_DSVIN => Some("dsi"), + MSG_ID_DSFINALTX => Some("dsf"), + MSG_ID_DSSIGNFINALTX => Some("dss"), + MSG_ID_DSCOMPLETE => Some("dsc"), + MSG_ID_DSSTATUSUPDATE => Some("dssu"), + MSG_ID_DSTX => Some("dstx"), + MSG_ID_DSQUEUE => Some("dsq"), + MSG_ID_SYNCSTATUSCOUNT => Some("ssc"), + MSG_ID_MNGOVERNANCESYNC => Some("govsync"), + MSG_ID_MNGOVERNANCEOBJECT => Some("govobj"), + MSG_ID_MNGOVERNANCEOBJECTVOTE => Some("govobjvote"), + MSG_ID_GETMNLISTDIFF => Some("getmnlistd"), + MSG_ID_MNLISTDIFF => Some("mnlistdiff"), + MSG_ID_QSENDRECSIGS => Some("qsendrecsigs"), + MSG_ID_QFCOMMITMENT => Some("qfcommit"), + MSG_ID_QCONTRIB => Some("qcontrib"), + MSG_ID_QCOMPLAINT => Some("qcomplaint"), + MSG_ID_QJUSTIFICATION => Some("qjustify"), + MSG_ID_QPCOMMITMENT => Some("qpcommit"), + MSG_ID_QWATCH => Some("qwatch"), + MSG_ID_QSIGSESANN => Some("qsigsesann"), + MSG_ID_QSIGSHARESINV => Some("qsigsinv"), + MSG_ID_QGETSIGSHARES => Some("qgetsigs"), + MSG_ID_QBSIGSHARES => Some("qbsigs"), + MSG_ID_QSIGREC => Some("qsigrec"), + MSG_ID_QSIGSHARE => Some("qsigshare"), + MSG_ID_QGETDATA => Some("qgetdata"), + MSG_ID_QDATA => Some("qdata"), + MSG_ID_CLSIG => Some("clsig"), + MSG_ID_ISDLOCK => Some("isdlock"), + MSG_ID_MNAUTH => Some("mnauth"), + MSG_ID_GETHEADERS2 => Some("getheaders2"), + MSG_ID_SENDHEADERS2 => Some("sendheaders2"), + MSG_ID_HEADERS2 => Some("headers2"), + MSG_ID_GETQUORUMROTATIONINFO => Some("getqrinfo"), + MSG_ID_QUORUMROTATIONINFO => Some("qrinfo"), + + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ping_pong_ids() { + assert_eq!(network_message_to_short_id(&NetworkMessage::Ping(0)), Some(MSG_ID_PING)); + assert_eq!(network_message_to_short_id(&NetworkMessage::Pong(0)), Some(MSG_ID_PONG)); + } + + #[test] + fn test_short_id_to_command() { + assert_eq!(short_id_to_command(MSG_ID_PING), Some("ping")); + assert_eq!(short_id_to_command(MSG_ID_PONG), Some("pong")); + assert_eq!(short_id_to_command(MSG_ID_BLOCK), Some("block")); + assert_eq!(short_id_to_command(255), None); + } + + #[test] + fn test_dash_short_ids() { + // Test Dash-specific short IDs + assert_eq!(short_id_to_command(MSG_ID_SPORK), Some("spork")); + assert_eq!(short_id_to_command(MSG_ID_SENDDSQUEUE), Some("senddsq")); + assert_eq!(short_id_to_command(MSG_ID_CLSIG), Some("clsig")); + assert_eq!(short_id_to_command(MSG_ID_ISDLOCK), Some("isdlock")); + assert_eq!(short_id_to_command(MSG_ID_MNLISTDIFF), Some("mnlistdiff")); + assert_eq!(short_id_to_command(MSG_ID_HEADERS2), Some("headers2")); + } + + #[test] + fn test_extended_format_for_non_short_id_messages() { + // Version is not a short ID message + use dashcore::network::address::Address; + use dashcore::network::constants::ServiceFlags; + use dashcore::network::message_network::VersionMessage; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + let addr = Address::new( + &SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8333), + ServiceFlags::NONE, + ); + + let version = VersionMessage { + version: 70015, + services: ServiceFlags::NONE, + timestamp: 0, + receiver: addr.clone(), + sender: addr, + nonce: 0, + user_agent: "/test/".to_string(), + start_height: 0, + relay: false, + mn_auth_challenge: [0u8; 32], + masternode_connection: false, + }; + + assert!(network_message_to_short_id(&NetworkMessage::Version(version)).is_none()); + } +} diff --git a/dash-spv/src/network/transport/mod.rs b/dash-spv/src/network/transport/mod.rs new file mode 100644 index 000000000..0ee7eea29 --- /dev/null +++ b/dash-spv/src/network/transport/mod.rs @@ -0,0 +1,79 @@ +//! Transport layer abstraction for Dash P2P connections. +//! +//! This module provides a `Transport` trait that abstracts the underlying +//! communication protocol (V1 unencrypted or V2 BIP324 encrypted). + +pub mod message_ids; +pub mod v1; +pub mod v2; +pub mod v2_handshake; + +use async_trait::async_trait; +use dashcore::network::message::NetworkMessage; + +use crate::error::NetworkResult; + +pub use v1::V1Transport; +pub use v2::V2Transport; +pub use v2_handshake::{V2HandshakeManager, V2HandshakeResult, V2Session}; + +/// Transport preference for peer connections. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum TransportPreference { + /// Use V2 encrypted transport only (fail if peer doesn't support). + V2Only, + /// Prefer V2 encrypted transport, fallback to V1 if needed (default). + #[default] + V2Preferred, + /// Use V1 unencrypted transport only (for compatibility testing). + V1Only, +} + +/// Result of establishing a transport connection. +pub enum TransportEstablishResult { + /// Successfully established V1 transport. + V1(V1Transport), + /// Need to fallback to V1 (V2 handshake detected V1-only peer). + FallbackToV1, +} + +/// Abstract transport layer for P2P communication. +/// +/// This trait is implemented by both V1Transport (unencrypted) and +/// V2Transport (BIP324 encrypted) to provide a unified interface +/// for message exchange. +#[async_trait] +pub trait Transport: Send + Sync { + /// Send a network message over the transport. + /// + /// # Arguments + /// * `message` - The network message to send + /// + /// # Returns + /// * `Ok(())` on success + /// * `Err(NetworkError)` on failure + async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()>; + + /// Receive a network message from the transport. + /// + /// # Returns + /// * `Ok(Some(message))` if a complete message was received + /// * `Ok(None)` if no complete message is available yet (non-blocking) + /// * `Err(NetworkError)` on failure or disconnection + async fn receive_message(&mut self) -> NetworkResult>; + + /// Check if the transport is connected. + fn is_connected(&self) -> bool; + + /// Get the transport protocol version (1 or 2). + fn protocol_version(&self) -> u8; + + /// Get the number of bytes sent over this transport. + fn bytes_sent(&self) -> u64; + + /// Get the number of bytes received over this transport. + fn bytes_received(&self) -> u64; + + /// Shutdown the transport connection. + async fn shutdown(&mut self) -> NetworkResult<()>; +} diff --git a/dash-spv/src/network/transport/v1.rs b/dash-spv/src/network/transport/v1.rs new file mode 100644 index 000000000..0d331c462 --- /dev/null +++ b/dash-spv/src/network/transport/v1.rs @@ -0,0 +1,478 @@ +//! V1 Transport - Unencrypted Dash P2P protocol transport. +//! +//! This implements the traditional Bitcoin/Dash P2P message framing: +//! - 4 bytes: Network magic +//! - 12 bytes: Command string +//! - 4 bytes: Payload length (little-endian) +//! - 4 bytes: Checksum (first 4 bytes of SHA256d of payload) +//! - Variable: Payload + +use std::net::SocketAddr; + +use async_trait::async_trait; +use dashcore::consensus::{encode, Decodable}; +use dashcore::network::message::{NetworkMessage, RawNetworkMessage, MAX_MSG_SIZE}; +use dashcore::Network; +use dashcore_hashes::Hash; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use super::Transport; +use crate::error::{NetworkError, NetworkResult}; + +/// Header length for V1 protocol: magic(4) + command(12) + length(4) + checksum(4) +const HEADER_LEN: usize = 24; + +/// Maximum resync steps per receive call to prevent infinite loops. +const MAX_RESYNC_STEPS_PER_CALL: usize = 64; + +/// Read buffer size for TCP reads. +const READ_BUFFER_SIZE: usize = 8192; + +/// V1 Transport implementation for unencrypted P2P communication. +pub struct V1Transport { + /// The underlying TCP stream. + stream: TcpStream, + /// Stateful message framing buffer. + framing_buffer: Vec, + /// Network for magic byte validation. + network: Network, + /// Remote peer address (for logging). + peer_address: SocketAddr, + /// Bytes sent counter. + bytes_sent: u64, + /// Bytes received counter. + bytes_received: u64, + /// Whether the connection is active. + connected: bool, + /// Consecutive resync counter (for telemetry). + consecutive_resyncs: u32, +} + +impl V1Transport { + /// Create a new V1 transport from an established TCP stream. + /// + /// # Arguments + /// * `stream` - An already-connected TCP stream + /// * `network` - The Dash network (for magic byte validation) + /// * `peer_address` - Remote peer address (for logging) + pub fn new(stream: TcpStream, network: Network, peer_address: SocketAddr) -> Self { + Self { + stream, + framing_buffer: Vec::with_capacity(READ_BUFFER_SIZE), + network, + peer_address, + bytes_sent: 0, + bytes_received: 0, + connected: true, + consecutive_resyncs: 0, + } + } + + /// Helper function to read some bytes into the framing buffer. + async fn read_some(&mut self) -> std::io::Result { + let mut tmp = [0u8; READ_BUFFER_SIZE]; + match self.stream.read(&mut tmp).await { + Ok(0) => Ok(0), + Ok(n) => { + self.framing_buffer.extend_from_slice(&tmp[..n]); + self.bytes_received += n as u64; + Ok(n) + } + Err(e) => Err(e), + } + } + + /// Get the consecutive resync count (for telemetry). + pub fn consecutive_resyncs(&self) -> u32 { + self.consecutive_resyncs + } +} + +#[async_trait] +impl Transport for V1Transport { + async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + let raw_message = RawNetworkMessage { + magic: self.network.magic(), + payload: message, + }; + + let serialized = encode::serialize(&raw_message); + + // Log details for debugging headers2 issues + if matches!( + raw_message.payload, + NetworkMessage::GetHeaders2(_) | NetworkMessage::GetHeaders(_) + ) { + let msg_type = match raw_message.payload { + NetworkMessage::GetHeaders2(_) => "GetHeaders2", + NetworkMessage::GetHeaders(_) => "GetHeaders", + _ => "Unknown", + }; + tracing::debug!( + "V1Transport: Sending {} raw bytes (len={}): {:02x?}", + msg_type, + serialized.len(), + &serialized[..std::cmp::min(100, serialized.len())] + ); + } + + // Write with error handling + match self.stream.write_all(&serialized).await { + Ok(_) => { + // Flush to ensure data is sent immediately + if let Err(e) = self.stream.flush().await { + tracing::warn!( + "V1Transport: Failed to flush socket {}: {}", + self.peer_address, + e + ); + } + self.bytes_sent += serialized.len() as u64; + tracing::debug!( + "V1Transport: Sent message to {}: {:?}", + self.peer_address, + raw_message.payload + ); + Ok(()) + } + Err(e) => { + tracing::warn!( + "V1Transport: Disconnecting {} due to write error: {}", + self.peer_address, + e + ); + self.connected = false; + Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e))) + } + } + } + + async fn receive_message(&mut self) -> NetworkResult> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + let magic_bytes = self.network.magic().to_le_bytes(); + let mut resync_steps = 0usize; + + loop { + // Ensure header availability + if self.framing_buffer.len() < HEADER_LEN { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(ref e) + if e.kind() == std::io::ErrorKind::ConnectionAborted + || e.kind() == std::io::ErrorKind::ConnectionReset => + { + tracing::info!( + "V1Transport: Peer {} connection reset/aborted", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + } + + // Align to magic + if self.framing_buffer.len() >= 4 && self.framing_buffer[..4] != magic_bytes { + if let Some(pos) = self.framing_buffer.windows(4).position(|w| w == magic_bytes) { + if pos > 0 { + tracing::warn!( + "V1Transport {}: stream desync: skipping {} stray bytes before magic", + self.peer_address, + pos + ); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + self.framing_buffer.drain(0..pos); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + continue; + } + } else { + // Keep last 3 bytes of potential magic prefix + if self.framing_buffer.len() > 3 { + let dropped = self.framing_buffer.len() - 3; + tracing::warn!( + "V1Transport {}: stream desync: dropping {} bytes (no magic found)", + self.peer_address, + dropped + ); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + self.framing_buffer.drain(0..dropped); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + } + // Need more data + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!( + "Read failed: {}", + e + ))); + } + } + continue; + } + } + + // Ensure full header + if self.framing_buffer.len() < HEADER_LEN { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + continue; + } + + // Parse header fields + let length_le = u32::from_le_bytes([ + self.framing_buffer[16], + self.framing_buffer[17], + self.framing_buffer[18], + self.framing_buffer[19], + ]) as usize; + let header_checksum = [ + self.framing_buffer[20], + self.framing_buffer[21], + self.framing_buffer[22], + self.framing_buffer[23], + ]; + + // Validate announced length to prevent unbounded accumulation or overflow + if length_le > MAX_MSG_SIZE { + return Err(NetworkError::ProtocolError(format!( + "Declared payload length {} exceeds MAX_MSG_SIZE {}", + length_le, MAX_MSG_SIZE + ))); + } + let total_len = match HEADER_LEN.checked_add(length_le) { + Some(v) => v, + None => { + return Err(NetworkError::ProtocolError("Message length overflow".to_string())); + } + }; + + // Ensure full frame available + if self.framing_buffer.len() < total_len { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V1Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + continue; + } + + // Verify checksum + let payload_slice = &self.framing_buffer[HEADER_LEN..total_len]; + let expected = { + let checksum = dashcore_hashes::sha256d::Hash::hash(payload_slice); + [checksum[0], checksum[1], checksum[2], checksum[3]] + }; + if expected != header_checksum { + tracing::warn!( + "V1Transport: Skipping message with invalid checksum from {}: expected {:02x?}, actual {:02x?}", + self.peer_address, + expected, + header_checksum + ); + if header_checksum == [0, 0, 0, 0] { + tracing::warn!( + "V1Transport: All-zeros checksum detected from {}, likely corrupted stream - resyncing", + self.peer_address + ); + } + // Resync by dropping a byte and retrying + self.framing_buffer.drain(0..1); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + continue; + } + + // Decode full RawNetworkMessage from the frame using existing decoder + let mut cursor = std::io::Cursor::new(&self.framing_buffer[..total_len]); + match RawNetworkMessage::consensus_decode(&mut cursor) { + Ok(raw_message) => { + // Consume bytes + self.framing_buffer.drain(0..total_len); + self.consecutive_resyncs = 0; + + // Validate magic matches our network + if raw_message.magic != self.network.magic() { + tracing::warn!( + "V1Transport: Received message with wrong magic bytes: expected {:#x}, got {:#x}", + self.network.magic(), + raw_message.magic + ); + return Err(NetworkError::ProtocolError(format!( + "Wrong magic bytes: expected {:#x}, got {:#x}", + self.network.magic(), + raw_message.magic + ))); + } + + tracing::trace!( + "V1Transport: Successfully decoded message from {}: {:?}", + self.peer_address, + raw_message.payload.cmd() + ); + + if raw_message.payload.cmd() == "headers2" { + tracing::info!( + "V1Transport: Received Headers2 message from {}!", + self.peer_address + ); + } + + if let NetworkMessage::Block(ref block) = raw_message.payload { + let block_hash = block.block_hash(); + tracing::info!( + "V1Transport: Successfully decoded block {} from {}", + block_hash, + self.peer_address + ); + } + + if let NetworkMessage::Headers2(ref headers2) = raw_message.payload { + tracing::info!( + "V1Transport: Successfully decoded Headers2 message from {} with {} compressed headers", + self.peer_address, + headers2.headers.len() + ); + } + + return Ok(Some(raw_message.payload)); + } + Err(e) => { + tracing::warn!( + "V1Transport {}: decode error after framing ({}), attempting resync", + self.peer_address, + e + ); + self.framing_buffer.drain(0..1); + self.consecutive_resyncs = self.consecutive_resyncs.saturating_add(1); + resync_steps += 1; + if resync_steps >= MAX_RESYNC_STEPS_PER_CALL { + return Ok(None); + } + continue; + } + } + } + } + + fn is_connected(&self) -> bool { + self.connected + } + + fn protocol_version(&self) -> u8 { + 1 + } + + fn bytes_sent(&self) -> u64 { + self.bytes_sent + } + + fn bytes_received(&self) -> u64 { + self.bytes_received + } + + async fn shutdown(&mut self) -> NetworkResult<()> { + if self.connected { + let _ = self.stream.shutdown().await; + self.connected = false; + tracing::info!("V1Transport: Shutdown connection to {}", self.peer_address); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_header_len() { + // Verify our header length constant is correct + assert_eq!(HEADER_LEN, 4 + 12 + 4 + 4); // magic + command + length + checksum + } +} diff --git a/dash-spv/src/network/transport/v2.rs b/dash-spv/src/network/transport/v2.rs new file mode 100644 index 000000000..6a88b565f --- /dev/null +++ b/dash-spv/src/network/transport/v2.rs @@ -0,0 +1,885 @@ +//! V2 Transport - BIP324 encrypted Dash P2P protocol transport. +//! +//! This implements the BIP324 encrypted transport protocol: +//! - 3 bytes: Encrypted length +//! - 1 byte: Header (flags, short message ID or 0x00 for extended) +//! - Variable: Contents (for extended format: 12-byte command + payload) +//! - 16 bytes: Authentication tag (ChaCha20-Poly1305) + +use std::net::SocketAddr; + +use async_trait::async_trait; +use bip324::{CipherSession, PacketType, NUM_LENGTH_BYTES}; +use dashcore::consensus::{encode::serialize, Decodable}; +use dashcore::network::message::{CommandString, NetworkMessage, MAX_MSG_SIZE}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use super::message_ids::{network_message_to_short_id, short_id_to_command, MSG_ID_EXTENDED}; +use super::Transport; +use crate::error::{NetworkError, NetworkResult}; + +/// Read buffer size for TCP reads. +const READ_BUFFER_SIZE: usize = 8192; + +/// Extended command length in bytes. +const COMMAND_LEN: usize = 12; + +/// V2 Transport implementation for BIP324 encrypted P2P communication. +pub struct V2Transport { + /// The underlying TCP stream. + stream: TcpStream, + /// The cipher session for encryption/decryption. + cipher: CipherSession, + /// Session ID for optional MitM verification. + session_id: [u8; 32], + /// Stateful receive buffer for partial reads. + receive_buffer: Vec, + /// Remote peer address (for logging). + peer_address: SocketAddr, + /// Bytes sent counter. + bytes_sent: u64, + /// Bytes received counter. + bytes_received: u64, + /// Whether the connection is active. + connected: bool, + /// Cached decrypted packet length (to avoid re-decrypting on partial reads). + /// This is needed because `decrypt_packet_len` advances the cipher state. + pending_packet_len: Option, +} + +impl V2Transport { + /// Create a new V2 transport from a successful handshake. + /// + /// # Arguments + /// * `stream` - The TCP stream (ownership transferred from handshake) + /// * `cipher` - The cipher session for encryption/decryption + /// * `session_id` - Session ID for optional MitM verification + /// * `peer_address` - Remote peer address (for logging) + pub fn new( + stream: TcpStream, + cipher: CipherSession, + session_id: [u8; 32], + peer_address: SocketAddr, + ) -> Self { + Self { + stream, + cipher, + session_id, + receive_buffer: Vec::with_capacity(READ_BUFFER_SIZE), + peer_address, + bytes_sent: 0, + bytes_received: 0, + connected: true, + pending_packet_len: None, + } + } + + /// Get the session ID for optional out-of-band MitM verification. + pub fn session_id(&self) -> &[u8; 32] { + &self.session_id + } + + /// Serialize the payload of a NetworkMessage. + /// + /// This mirrors the payload serialization logic from RawNetworkMessage's Encodable impl. + fn serialize_payload(message: &NetworkMessage) -> Vec { + match message { + NetworkMessage::Version(ref dat) => serialize(dat), + NetworkMessage::Addr(ref dat) => serialize(dat), + NetworkMessage::Inv(ref dat) => serialize(dat), + NetworkMessage::GetData(ref dat) => serialize(dat), + NetworkMessage::NotFound(ref dat) => serialize(dat), + NetworkMessage::GetBlocks(ref dat) => serialize(dat), + NetworkMessage::GetHeaders(ref dat) => serialize(dat), + NetworkMessage::Tx(ref dat) => serialize(dat), + NetworkMessage::Block(ref dat) => serialize(dat), + NetworkMessage::Headers(ref dat) => { + // Headers need special serialization with trailing zero byte per header + Self::serialize_headers(dat) + } + NetworkMessage::GetHeaders2(ref dat) => serialize(dat), + NetworkMessage::Headers2(ref dat) => serialize(dat), + NetworkMessage::Ping(ref dat) => serialize(dat), + NetworkMessage::Pong(ref dat) => serialize(dat), + NetworkMessage::MerkleBlock(ref dat) => serialize(dat), + NetworkMessage::FilterLoad(ref dat) => serialize(dat), + NetworkMessage::FilterAdd(ref dat) => serialize(dat), + NetworkMessage::GetCFilters(ref dat) => serialize(dat), + NetworkMessage::CFilter(ref dat) => serialize(dat), + NetworkMessage::GetCFHeaders(ref dat) => serialize(dat), + NetworkMessage::CFHeaders(ref dat) => serialize(dat), + NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat), + NetworkMessage::CFCheckpt(ref dat) => serialize(dat), + NetworkMessage::SendCmpct(ref dat) => serialize(dat), + NetworkMessage::CmpctBlock(ref dat) => serialize(dat), + NetworkMessage::GetBlockTxn(ref dat) => serialize(dat), + NetworkMessage::BlockTxn(ref dat) => serialize(dat), + NetworkMessage::Alert(ref dat) => serialize(dat), + NetworkMessage::Reject(ref dat) => serialize(dat), + NetworkMessage::FeeFilter(ref dat) => serialize(dat), + NetworkMessage::AddrV2(ref dat) => serialize(dat), + NetworkMessage::GetMnListD(ref dat) => serialize(dat), + NetworkMessage::MnListDiff(ref dat) => serialize(dat), + NetworkMessage::GetQRInfo(ref dat) => serialize(dat), + NetworkMessage::QRInfo(ref dat) => serialize(dat), + NetworkMessage::CLSig(ref dat) => serialize(dat), + NetworkMessage::ISLock(ref dat) => serialize(dat), + NetworkMessage::SendDsq(wants_dsq) => serialize(&(*wants_dsq as u8)), + NetworkMessage::Unknown { + payload: ref data, + .. + } => serialize(data), + NetworkMessage::Verack + | NetworkMessage::SendHeaders + | NetworkMessage::SendHeaders2 + | NetworkMessage::MemPool + | NetworkMessage::GetAddr + | NetworkMessage::WtxidRelay + | NetworkMessage::FilterClear + | NetworkMessage::SendAddrV2 => vec![], + } + } + + /// Serialize headers with trailing zero byte per header (matches HeaderSerializationWrapper). + fn serialize_headers(headers: &[dashcore::block::Header]) -> Vec { + use dashcore::consensus::Encodable; + let mut buf = Vec::new(); + // VarInt for count + let _ = dashcore::VarInt(headers.len() as u64).consensus_encode(&mut buf); + // Each header + trailing zero + for header in headers { + let _ = header.consensus_encode(&mut buf); + buf.push(0u8); + } + buf + } + + /// Deserialize headers with trailing zero byte per header (matches HeaderDeserializationWrapper). + fn deserialize_headers(payload: &[u8]) -> NetworkResult> { + let mut cursor = std::io::Cursor::new(payload); + let count = dashcore::VarInt::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode headers count: {}", e)) + })?; + + let mut headers = Vec::with_capacity(count.0 as usize); + for _ in 0..count.0 { + let header = dashcore::block::Header::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode header: {}", e)) + })?; + headers.push(header); + // Read and discard the trailing zero byte + let _trailing = u8::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode header trailing byte: {}", e)) + })?; + } + Ok(headers) + } + + /// Encode a NetworkMessage into V2 plaintext format. + /// + /// Format: + /// - Short format (common messages): payload bytes (header byte added by cipher) + /// - Extended format (Dash-specific): 12-byte command + payload bytes + fn encode_message(&self, message: &NetworkMessage) -> NetworkResult> { + // Serialize the message payload + let payload = Self::serialize_payload(message); + + // Check for short message ID + if let Some(short_id) = network_message_to_short_id(message) { + // Short format: just the short ID byte followed by payload + // The short ID will be put in the header byte by the cipher + // So we return: [short_id] + payload + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(short_id); + plaintext.extend_from_slice(&payload); + Ok(plaintext) + } else { + // Extended format: 0x00 header + 12-byte command + payload + let cmd = message.cmd(); + let cmd_bytes = cmd.as_bytes(); + + // Create 12-byte null-padded command + let mut command = [0u8; COMMAND_LEN]; + let copy_len = std::cmp::min(cmd_bytes.len(), COMMAND_LEN); + command[..copy_len].copy_from_slice(&cmd_bytes[..copy_len]); + + let mut plaintext = Vec::with_capacity(1 + COMMAND_LEN + payload.len()); + plaintext.push(MSG_ID_EXTENDED); // 0x00 marker for extended format + plaintext.extend_from_slice(&command); + plaintext.extend_from_slice(&payload); + Ok(plaintext) + } + } + + /// Decode a V2 plaintext into a NetworkMessage. + /// + /// # Arguments + /// * `plaintext` - Decrypted plaintext (header byte + optional command + payload) + fn decode_message(&self, plaintext: &[u8]) -> NetworkResult { + // The bip324 crate prepends a "packet type" byte (0 for Genuine, 128 for Decoy) + // Our actual message ID/content starts at byte 1 + if plaintext.len() < 2 { + return Err(NetworkError::ProtocolError("V2 message too short".to_string())); + } + + // Byte 0 is the crate's packet type indicator (always 0 for genuine messages) + // Byte 1 is our actual message ID (short ID or 0 for extended format) + let _crate_header = plaintext[0]; // Should be 0 for genuine, 128 for decoy + let message_id = plaintext[1]; + + // Trace: log first bytes of decrypted plaintext (verbose, for debugging only) + let preview_len = std::cmp::min(20, plaintext.len()); + tracing::trace!( + "V2Transport: Decrypted message preview ({} bytes total): {:02x?}, message_id={}", + plaintext.len(), + &plaintext[..preview_len], + message_id + ); + + if message_id == MSG_ID_EXTENDED { + // Extended format: 12-byte command + payload (starting at byte 2) + if plaintext.len() < 2 + COMMAND_LEN { + return Err(NetworkError::ProtocolError( + "V2 extended message too short".to_string(), + )); + } + + let command_bytes = &plaintext[2..2 + COMMAND_LEN]; + let payload = &plaintext[2 + COMMAND_LEN..]; + + // Find null terminator in command + let cmd_end = command_bytes.iter().position(|&b| b == 0).unwrap_or(COMMAND_LEN); + let cmd = std::str::from_utf8(&command_bytes[..cmd_end]).map_err(|_| { + NetworkError::ProtocolError("Invalid UTF-8 in V2 command".to_string()) + })?; + + tracing::trace!( + "V2Transport: Decoding extended format message '{}' ({} bytes payload) from {}", + cmd, + payload.len(), + self.peer_address + ); + + // Decode the NetworkMessage based on command + self.decode_by_command(cmd, payload) + } else { + // Short format: message_id is the short message ID, payload starts at byte 2 + let payload = &plaintext[2..]; + + let cmd = short_id_to_command(message_id).ok_or_else(|| { + NetworkError::ProtocolError(format!("Unknown V2 short message ID: {}", message_id)) + })?; + + tracing::trace!( + "V2Transport: Decoding short format message '{}' (ID={}, {} bytes payload) from {}", + cmd, + message_id, + payload.len(), + self.peer_address + ); + + self.decode_by_command(cmd, payload) + } + } + + /// Decode a NetworkMessage from command string and payload bytes. + fn decode_by_command(&self, cmd: &str, payload: &[u8]) -> NetworkResult { + // Create a cursor for decoding + let mut cursor = std::io::Cursor::new(payload); + + // Decode based on command + // Note: This mirrors the NetworkMessage variants and their Decodable impls + let message = match cmd { + "addr" => { + let addrs: Vec<(u32, dashcore::network::address::Address)> = + Decodable::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode addr: {}", e)) + })?; + NetworkMessage::Addr(addrs) + } + "block" => { + let block = dashcore::Block::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode block: {}", e)) + })?; + NetworkMessage::Block(block) + } + "blocktxn" => { + let blocktxn = + dashcore::network::message_compact_blocks::BlockTxn::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode blocktxn: {}", e)) + })?; + NetworkMessage::BlockTxn(blocktxn) + } + "cmpctblock" => { + let cmpctblock = + dashcore::network::message_compact_blocks::CmpctBlock::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode cmpctblock: {}", e)) + })?; + NetworkMessage::CmpctBlock(cmpctblock) + } + "feefilter" => { + let fee = i64::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode feefilter: {}", e)) + })?; + NetworkMessage::FeeFilter(fee) + } + "filteradd" => { + let filteradd = + dashcore::network::message_bloom::FilterAdd::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode filteradd: {}", + e + )) + })?; + NetworkMessage::FilterAdd(filteradd) + } + "filterclear" => NetworkMessage::FilterClear, + "filterload" => { + let filterload = + dashcore::network::message_bloom::FilterLoad::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode filterload: {}", + e + )) + })?; + NetworkMessage::FilterLoad(filterload) + } + "getblocks" => { + let getblocks = + dashcore::network::message_blockdata::GetBlocksMessage::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode getblocks: {}", e)) + })?; + NetworkMessage::GetBlocks(getblocks) + } + "getblocktxn" => { + let getblocktxn = + dashcore::network::message_compact_blocks::GetBlockTxn::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode getblocktxn: {}", e)) + })?; + NetworkMessage::GetBlockTxn(getblocktxn) + } + "getdata" => { + let inv: Vec = + Decodable::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode getdata: {}", e)) + })?; + NetworkMessage::GetData(inv) + } + "getheaders" => { + let getheaders = + dashcore::network::message_blockdata::GetHeadersMessage::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode getheaders: {}", e)) + })?; + NetworkMessage::GetHeaders(getheaders) + } + "headers" => { + // Headers have special deserialization (VarInt count + each header + trailing zero) + let headers = Self::deserialize_headers(payload)?; + NetworkMessage::Headers(headers) + } + "inv" => { + let inv: Vec = + Decodable::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode inv: {}", e)) + })?; + NetworkMessage::Inv(inv) + } + "mempool" => NetworkMessage::MemPool, + "merkleblock" => { + let merkleblock = + dashcore::MerkleBlock::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode merkleblock: {}", e)) + })?; + NetworkMessage::MerkleBlock(merkleblock) + } + "notfound" => { + let inv: Vec = + Decodable::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode notfound: {}", e)) + })?; + NetworkMessage::NotFound(inv) + } + "ping" => { + let nonce = u64::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode ping: {}", e)) + })?; + NetworkMessage::Ping(nonce) + } + "pong" => { + let nonce = u64::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode pong: {}", e)) + })?; + NetworkMessage::Pong(nonce) + } + "sendcmpct" => { + let sendcmpct = + dashcore::network::message_compact_blocks::SendCmpct::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode sendcmpct: {}", e)) + })?; + NetworkMessage::SendCmpct(sendcmpct) + } + "tx" => { + let tx = dashcore::Transaction::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode tx: {}", e)) + })?; + NetworkMessage::Tx(tx) + } + "getcfilters" => { + let getcfilters = + dashcore::network::message_filter::GetCFilters::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode getcfilters: {}", e)) + })?; + NetworkMessage::GetCFilters(getcfilters) + } + "cfilter" => { + let cfilter = + dashcore::network::message_filter::CFilter::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode cfilter: {}", e)) + })?; + NetworkMessage::CFilter(cfilter) + } + "getcfheaders" => { + let getcfheaders = + dashcore::network::message_filter::GetCFHeaders::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode getcfheaders: {}", + e + )) + })?; + NetworkMessage::GetCFHeaders(getcfheaders) + } + "cfheaders" => { + let cfheaders = + dashcore::network::message_filter::CFHeaders::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode cfheaders: {}", + e + )) + })?; + NetworkMessage::CFHeaders(cfheaders) + } + "getcfcheckpt" => { + let getcfcheckpt = + dashcore::network::message_filter::GetCFCheckpt::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode getcfcheckpt: {}", + e + )) + })?; + NetworkMessage::GetCFCheckpt(getcfcheckpt) + } + "cfcheckpt" => { + let cfcheckpt = + dashcore::network::message_filter::CFCheckpt::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode cfcheckpt: {}", + e + )) + })?; + NetworkMessage::CFCheckpt(cfcheckpt) + } + // Dash-specific messages (extended format) + "version" => { + let version = dashcore::network::message_network::VersionMessage::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode version: {}", e)) + })?; + NetworkMessage::Version(version) + } + "verack" => NetworkMessage::Verack, + "sendheaders" => NetworkMessage::SendHeaders, + "getaddr" => NetworkMessage::GetAddr, + "wtxidrelay" => NetworkMessage::WtxidRelay, + "sendaddrv2" => NetworkMessage::SendAddrV2, + "addrv2" => { + let addrs: Vec = + Decodable::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode addrv2: {}", e)) + })?; + NetworkMessage::AddrV2(addrs) + } + "reject" => { + let reject = + dashcore::network::message_network::Reject::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode reject: {}", e)) + })?; + NetworkMessage::Reject(reject) + } + // Dash-specific extended messages + "mnlistdiff" => { + let mnlistdiff = + dashcore::network::message_sml::MnListDiff::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode mnlistdiff: {}", + e + )) + })?; + NetworkMessage::MnListDiff(mnlistdiff) + } + "getmnlistd" => { + let getmnlistd = + dashcore::network::message_sml::GetMnListDiff::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode getmnlistd: {}", + e + )) + })?; + NetworkMessage::GetMnListD(getmnlistd) + } + "qrinfo" => { + let qrinfo = + dashcore::network::message_qrinfo::QRInfo::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode qrinfo: {}", e)) + })?; + NetworkMessage::QRInfo(qrinfo) + } + "getqrinfo" => { + let getqrinfo = + dashcore::network::message_qrinfo::GetQRInfo::consensus_decode(&mut cursor) + .map_err(|e| { + NetworkError::ProtocolError(format!( + "Failed to decode getqrinfo: {}", + e + )) + })?; + NetworkMessage::GetQRInfo(getqrinfo) + } + "clsig" => { + let clsig = dashcore::ChainLock::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode clsig: {}", e)) + })?; + NetworkMessage::CLSig(clsig) + } + "isdlock" => { + let islock = dashcore::InstantLock::consensus_decode(&mut cursor).map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode isdlock: {}", e)) + })?; + NetworkMessage::ISLock(islock) + } + "headers2" => { + let headers2 = + dashcore::network::message_headers2::Headers2Message::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode headers2: {}", e)) + })?; + NetworkMessage::Headers2(headers2) + } + "getheaders2" => { + // getheaders2 uses same format as getheaders + let getheaders2 = + dashcore::network::message_blockdata::GetHeadersMessage::consensus_decode( + &mut cursor, + ) + .map_err(|e| { + NetworkError::ProtocolError(format!("Failed to decode getheaders2: {}", e)) + })?; + NetworkMessage::GetHeaders2(getheaders2) + } + "sendheaders2" => NetworkMessage::SendHeaders2, + "senddsq" => { + // SendDsq is a single bool (serialized as u8) + let wants_dsq = if payload.is_empty() { + false + } else { + payload[0] != 0 + }; + NetworkMessage::SendDsq(wants_dsq) + } + // Unknown command - use Unknown variant + _ => { + tracing::warn!( + "V2Transport: Unknown command '{}' from {}, storing as raw bytes", + cmd, + self.peer_address + ); + NetworkMessage::Unknown { + command: CommandString::try_from(cmd.to_string()).unwrap_or_else(|_| { + CommandString::try_from("unknown".to_string()).expect("valid") + }), + payload: payload.to_vec(), + } + } + }; + + Ok(message) + } + + /// Helper function to read some bytes into the receive buffer. + async fn read_some(&mut self) -> std::io::Result { + let mut tmp = [0u8; READ_BUFFER_SIZE]; + match self.stream.read(&mut tmp).await { + Ok(0) => Ok(0), + Ok(n) => { + self.receive_buffer.extend_from_slice(&tmp[..n]); + self.bytes_received += n as u64; + Ok(n) + } + Err(e) => Err(e), + } + } +} + +#[async_trait] +impl Transport for V2Transport { + async fn send_message(&mut self, message: NetworkMessage) -> NetworkResult<()> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + // Encode the message to V2 plaintext format + let plaintext = self.encode_message(&message)?; + + tracing::debug!( + "V2Transport: Encoding message {:?} ({} bytes plaintext) for {}", + message.cmd(), + plaintext.len(), + self.peer_address + ); + + // Encrypt the message + // Note: The bip324 crate handles the header byte internally, but we're + // putting our message type in the plaintext, so we use Genuine packet type + let encrypted = + self.cipher.outbound().encrypt_to_vec(&plaintext, PacketType::Genuine, None); + + // Write the encrypted packet + match self.stream.write_all(&encrypted).await { + Ok(_) => { + // Flush to ensure data is sent immediately + if let Err(e) = self.stream.flush().await { + tracing::warn!( + "V2Transport: Failed to flush socket {}: {}", + self.peer_address, + e + ); + } + self.bytes_sent += encrypted.len() as u64; + tracing::debug!( + "V2Transport: Sent encrypted message to {}: {:?} ({} bytes)", + self.peer_address, + message.cmd(), + encrypted.len() + ); + Ok(()) + } + Err(e) => { + tracing::warn!( + "V2Transport: Disconnecting {} due to write error: {}", + self.peer_address, + e + ); + self.connected = false; + Err(NetworkError::ConnectionFailed(format!("Write failed: {}", e))) + } + } + } + + async fn receive_message(&mut self) -> NetworkResult> { + if !self.connected { + return Err(NetworkError::ConnectionFailed("Not connected".to_string())); + } + + loop { + // Step 1: Ensure we have at least 3 bytes for the length + while self.receive_buffer.len() < NUM_LENGTH_BYTES { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V2Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(ref e) + if e.kind() == std::io::ErrorKind::ConnectionAborted + || e.kind() == std::io::ErrorKind::ConnectionReset => + { + tracing::info!( + "V2Transport: Peer {} connection reset/aborted", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + } + + // Step 2: Decrypt the length (only if we haven't already for this packet) + // IMPORTANT: decrypt_packet_len advances the cipher state, so we must + // cache the result if we don't have enough bytes for the full packet yet. + let packet_len = if let Some(cached_len) = self.pending_packet_len { + cached_len + } else { + let len_bytes: [u8; NUM_LENGTH_BYTES] = + self.receive_buffer[..NUM_LENGTH_BYTES].try_into().expect("3 bytes for length"); + + // Note: decrypt_packet_len returns the length of remaining data to read + // (header + contents + tag), NOT just the contents length + let decrypted_len = self.cipher.inbound().decrypt_packet_len(len_bytes); + + // Validate packet length + if decrypted_len > MAX_MSG_SIZE + 1 + 16 { + // MAX_MSG_SIZE + header + tag + return Err(NetworkError::ProtocolError(format!( + "V2 packet too large: {} bytes", + decrypted_len + ))); + } + + // Cache the length in case we need to return early + self.pending_packet_len = Some(decrypted_len); + decrypted_len + }; + + let total_len = NUM_LENGTH_BYTES + packet_len; + + // Step 3: Ensure we have the complete packet + while self.receive_buffer.len() < total_len { + match self.read_some().await { + Ok(0) => { + tracing::info!( + "V2Transport: Peer {} closed connection (EOF)", + self.peer_address + ); + self.connected = false; + return Err(NetworkError::PeerDisconnected); + } + Ok(_) => {} + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + return Ok(None); + } + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => { + return Ok(None); + } + Err(e) => { + self.connected = false; + return Err(NetworkError::ConnectionFailed(format!("Read failed: {}", e))); + } + } + } + + // Step 4: Extract and decrypt the packet (excluding length bytes which are already consumed) + let ciphertext = &self.receive_buffer[NUM_LENGTH_BYTES..total_len]; + + let (packet_type, plaintext) = + self.cipher.inbound().decrypt_to_vec(ciphertext, None).map_err(|e| { + NetworkError::V2DecryptionFailed(format!("Decryption failed: {}", e)) + })?; + + // Consume the packet from the buffer and clear cached length + self.receive_buffer.drain(0..total_len); + self.pending_packet_len = None; + + // Step 5: Handle decoy packets + if packet_type == PacketType::Decoy { + tracing::debug!( + "V2Transport: Received decoy packet from {}, ignoring", + self.peer_address + ); + continue; // Read next packet + } + + // Step 6: Decode the message + // Note: plaintext includes the header byte at position 0 + let message = self.decode_message(&plaintext)?; + + tracing::trace!( + "V2Transport: Successfully decoded message from {}: {:?}", + self.peer_address, + message.cmd() + ); + + return Ok(Some(message)); + } + } + + fn is_connected(&self) -> bool { + self.connected + } + + fn protocol_version(&self) -> u8 { + 2 + } + + fn bytes_sent(&self) -> u64 { + self.bytes_sent + } + + fn bytes_received(&self) -> u64 { + self.bytes_received + } + + async fn shutdown(&mut self) -> NetworkResult<()> { + if self.connected { + let _ = self.stream.shutdown().await; + self.connected = false; + tracing::info!("V2Transport: Shutdown connection to {}", self.peer_address); + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_command_len() { + // Verify command length constant + assert_eq!(COMMAND_LEN, 12); + } + + #[test] + fn test_short_id_encoding() { + // Verify ping/pong use short IDs + assert!(network_message_to_short_id(&NetworkMessage::Ping(0)).is_some()); + assert!(network_message_to_short_id(&NetworkMessage::Pong(0)).is_some()); + } +} diff --git a/dash-spv/src/network/transport/v2_handshake.rs b/dash-spv/src/network/transport/v2_handshake.rs new file mode 100644 index 000000000..29ef1003c --- /dev/null +++ b/dash-spv/src/network/transport/v2_handshake.rs @@ -0,0 +1,413 @@ +//! V2 Handshake implementation for BIP324 encrypted transport. +//! +//! This module implements the BIP324 handshake protocol: +//! 1. Key Exchange: ElligatorSwift-encoded public keys + garbage data +//! 2. Version Negotiation: Encrypted version packets confirm mutual v2 support +//! +//! The handshake detects v1-only peers by checking if the first bytes +//! received match the network magic (indicating v1 protocol). + +use std::net::SocketAddr; +use std::time::Duration; + +use bip324::{CipherSession, GarbageResult, Handshake, Role, VersionResult}; +use dashcore::Network; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; + +use crate::error::{NetworkError, NetworkResult}; + +/// Maximum garbage data size per BIP324 spec. +const MAX_GARBAGE_LEN: usize = 4095; + +/// Timeout for handshake operations. +const HANDSHAKE_TIMEOUT: Duration = Duration::from_secs(10); + +/// Size of ElligatorSwift public key. +const ELLIGATOR_SWIFT_KEY_SIZE: usize = 64; + +/// Size of garbage terminator. +const GARBAGE_TERMINATOR_SIZE: usize = 16; + +/// Result of the V2 handshake attempt. +pub enum V2HandshakeResult { + /// Successfully completed V2 handshake. + Success(V2Session), + /// Detected V1-only peer (first bytes matched network magic). + FallbackToV1, +} + +/// Session data from a successful V2 handshake. +pub struct V2Session { + /// The TCP stream (ownership transferred from handshake). + pub stream: TcpStream, + /// The cipher session for encryption/decryption. + pub cipher: CipherSession, + /// Session ID for optional out-of-band MitM verification. + pub session_id: [u8; 32], +} + +/// V2 Handshake manager for BIP324 encrypted connections. +pub struct V2HandshakeManager { + /// Network magic bytes for key derivation. + magic: [u8; 4], + /// Our role in the handshake (initiator or responder). + role: Role, + /// Peer address (for logging). + peer_address: SocketAddr, +} + +impl V2HandshakeManager { + /// Create a new handshake manager for initiating connections. + /// + /// The initiator sends the first message (their ElligatorSwift pubkey). + pub fn new_initiator(network: Network, peer_address: SocketAddr) -> Self { + Self { + magic: network.magic().to_le_bytes(), + role: Role::Initiator, + peer_address, + } + } + + /// Create a new handshake manager for responding to connections. + /// + /// The responder waits for the initiator's pubkey first. + pub fn new_responder(network: Network, peer_address: SocketAddr) -> Self { + Self { + magic: network.magic().to_le_bytes(), + role: Role::Responder, + peer_address, + } + } + + /// Perform the V2 handshake on the given TCP stream. + /// + /// # Arguments + /// * `stream` - A connected TCP stream + /// + /// # Returns + /// * `V2HandshakeResult::Success(session)` - Handshake completed successfully + /// * `V2HandshakeResult::FallbackToV1` - Detected v1-only peer + /// + /// # Errors + /// Returns `NetworkError` if the handshake fails (e.g., timeout, protocol error). + pub async fn perform_handshake( + self, + mut stream: TcpStream, + ) -> NetworkResult { + tracing::debug!("V2 handshake: Starting as {:?} with {}", self.role, self.peer_address); + + // Create the handshake state machine + let handshake = Handshake::new(self.magic, self.role).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to create handshake: {}", e)) + })?; + + // Step 1: Send our public key (no garbage for simplicity) + let mut send_key_buffer = vec![0u8; Handshake::send_key_len(None)]; + let handshake = handshake.send_key(None, &mut send_key_buffer).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to prepare key: {}", e)) + })?; + + tracing::debug!( + "V2 handshake: Sending our ElligatorSwift pubkey ({} bytes) to {}", + send_key_buffer.len(), + self.peer_address + ); + + tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.write_all(&send_key_buffer)) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to send pubkey: {}", e)) + })?; + + stream + .flush() + .await + .map_err(|e| NetworkError::V2HandshakeFailed(format!("Failed to flush: {}", e)))?; + + // Step 2: Read the remote's public key (64 bytes) + // First, peek at the initial bytes to detect v1 magic + let mut peek_buf = [0u8; 4]; + let peek_result = tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.peek(&mut peek_buf)).await; + + match peek_result { + Ok(Ok(n)) if n >= 4 => { + if peek_buf == self.magic { + tracing::info!( + "V2 handshake: Detected V1-only peer {} (received magic bytes)", + self.peer_address + ); + return Ok(V2HandshakeResult::FallbackToV1); + } + } + Ok(Ok(_)) => { + // Not enough bytes to determine, continue with v2 + } + Ok(Err(e)) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to peek for v1 detection: {}", + e + ))); + } + Err(_) => { + return Err(NetworkError::Timeout); + } + } + + // Read the full remote pubkey (64 bytes) + let mut remote_pubkey = [0u8; ELLIGATOR_SWIFT_KEY_SIZE]; + tracing::debug!( + "V2 handshake: Reading remote ElligatorSwift pubkey from {}", + self.peer_address + ); + + tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut remote_pubkey)) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to read remote pubkey: {}", e)) + })?; + + // Step 3: Process the remote's public key and derive session keys + let handshake = handshake.receive_key(remote_pubkey).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to process remote pubkey: {}", e)) + })?; + + tracing::debug!("V2 handshake: Derived session keys with {}", self.peer_address); + + // Step 4: Send garbage terminator + version packet + let mut send_version_buffer = vec![0u8; Handshake::send_version_len(None)]; + let handshake = handshake.send_version(&mut send_version_buffer, None).map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to prepare version: {}", e)) + })?; + + tracing::debug!( + "V2 handshake: Sending garbage terminator + version ({} bytes) to {}", + send_version_buffer.len(), + self.peer_address + ); + + tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.write_all(&send_version_buffer)) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!("Failed to send version: {}", e)) + })?; + + stream + .flush() + .await + .map_err(|e| NetworkError::V2HandshakeFailed(format!("Failed to flush: {}", e)))?; + + // Step 5: Receive remote garbage + terminator + // Read up to MAX_GARBAGE_LEN + GARBAGE_TERMINATOR_SIZE bytes + let mut garbage_buffer = Vec::with_capacity(MAX_GARBAGE_LEN + GARBAGE_TERMINATOR_SIZE); + let mut handshake_state = handshake; + + tracing::debug!( + "V2 handshake: Scanning for remote garbage terminator from {}", + self.peer_address + ); + + let scan_start = std::time::Instant::now(); + loop { + // Check timeout + if scan_start.elapsed() > HANDSHAKE_TIMEOUT { + return Err(NetworkError::Timeout); + } + + // Read a chunk + let mut chunk = [0u8; 256]; + let n = match tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut chunk), + ) + .await + { + Ok(Ok(0)) => { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed during garbage scan".to_string(), + )); + } + Ok(Ok(n)) => n, + Ok(Err(e)) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to read garbage: {}", + e + ))); + } + Err(_) => { + return Err(NetworkError::Timeout); + } + }; + + garbage_buffer.extend_from_slice(&chunk[..n]); + + // Try to find the garbage terminator + match handshake_state.receive_garbage(&garbage_buffer) { + Ok(GarbageResult::FoundGarbage { + handshake, + consumed_bytes, + }) => { + tracing::debug!( + "V2 handshake: Found garbage terminator after {} bytes from {}", + consumed_bytes, + self.peer_address + ); + + // Keep any remaining bytes after the garbage + let remaining = garbage_buffer[consumed_bytes..].to_vec(); + + // Step 6: Receive version packet + // The version packet follows the garbage terminator + let mut handshake = handshake; + + // Read version packet: 3-byte length + encrypted content + let mut version_data = remaining; + + // Read at least 3 bytes for the length prefix + while version_data.len() < 3 { + let mut more = [0u8; 64]; + let n = tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut more), + ) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to read version packet length: {}", + e + )) + })?; + if n == 0 { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed before version packet".to_string(), + )); + } + version_data.extend_from_slice(&more[..n]); + } + + // Decrypt the packet length (first 3 bytes) + let length_bytes: [u8; 3] = version_data[..3].try_into().map_err(|_| { + NetworkError::V2HandshakeFailed( + "Failed to extract length bytes".to_string(), + ) + })?; + let packet_len = handshake.decrypt_packet_len(length_bytes).map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to decrypt packet length: {}", + e + )) + })?; + + tracing::debug!( + "V2 handshake: Version packet length is {} bytes from {}", + packet_len, + self.peer_address + ); + + // Read more data if needed to have the full packet + let total_needed = 3 + packet_len; // length prefix + packet content + while version_data.len() < total_needed { + let mut more = [0u8; 64]; + let n = tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut more), + ) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to read version packet content: {}", + e + )) + })?; + if n == 0 { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed before version packet complete".to_string(), + )); + } + version_data.extend_from_slice(&more[..n]); + } + + // Extract just the packet content (excluding the 3-byte length prefix) + let mut packet_content = version_data[3..3 + packet_len].to_vec(); + + // Process version packet + match handshake.receive_version(&mut packet_content) { + Ok(VersionResult::Complete { + cipher, + }) => { + tracing::info!( + "V2 handshake: Completed successfully with {}", + self.peer_address + ); + + return Ok(V2HandshakeResult::Success(V2Session { + stream, + cipher, + session_id: [0u8; 32], // TODO: Get actual session ID + })); + } + Ok(VersionResult::Decoy(_handshake)) => { + // Received a decoy packet, need to continue reading + // For now, treat as error (can be enhanced later) + return Err(NetworkError::V2HandshakeFailed( + "Received decoy packet - not yet supported".to_string(), + )); + } + Err(e) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to process version packet: {}", + e + ))); + } + } + } + Ok(GarbageResult::NeedMoreData(hs)) => { + handshake_state = hs; + // Continue reading more data + if garbage_buffer.len() > MAX_GARBAGE_LEN + GARBAGE_TERMINATOR_SIZE { + return Err(NetworkError::V2HandshakeFailed( + "Garbage terminator not found within limit".to_string(), + )); + } + } + Err(e) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to process garbage: {}", + e + ))); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_handshake_manager_creation() { + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + + let initiator = V2HandshakeManager::new_initiator(Network::Dash, addr); + assert_eq!(initiator.role, Role::Initiator); + // Dash mainnet magic: 0xBD6B0CBF in little-endian + assert_eq!(initiator.magic, [0xbf, 0x0c, 0x6b, 0xbd]); + + let responder = V2HandshakeManager::new_responder(Network::Dash, addr); + assert_eq!(responder.role, Role::Responder); + } + + #[test] + fn test_testnet_magic() { + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let manager = V2HandshakeManager::new_initiator(Network::Testnet, addr); + // Dash testnet magic: 0xFFCAE2CE in little-endian + assert_eq!(manager.magic, [0xce, 0xe2, 0xca, 0xff]); + } +} diff --git a/dash-spv/tests/handshake_test.rs b/dash-spv/tests/handshake_test.rs index d8cb6579f..a25f5dda5 100644 --- a/dash-spv/tests/handshake_test.rs +++ b/dash-spv/tests/handshake_test.rs @@ -4,6 +4,7 @@ use std::net::SocketAddr; use std::time::Duration; use dash_spv::client::config::MempoolStrategy; +use dash_spv::network::transport::TransportPreference; use dash_spv::network::{HandshakeManager, NetworkManager, Peer, PeerNetworkManager}; use dash_spv::{ClientConfig, Network}; @@ -13,7 +14,7 @@ async fn test_handshake_with_mainnet_peer() { let _ = env_logger::builder().filter_level(log::LevelFilter::Debug).is_test(true).try_init(); let peer_addr: SocketAddr = "127.0.0.1:9999".parse().expect("Valid peer address"); - let result = Peer::connect(peer_addr, 10, Network::Dash).await; + let result = Peer::connect(peer_addr, 10, Network::Dash, TransportPreference::V1Only).await; match result { Ok(mut connection) => { @@ -54,7 +55,7 @@ async fn test_handshake_timeout() { // Using a non-routable IP that will cause the connection to hang let peer_addr: SocketAddr = "10.255.255.1:9999".parse().expect("Valid peer address"); let start = std::time::Instant::now(); - let result = Peer::connect(peer_addr, 2, Network::Dash).await; + let result = Peer::connect(peer_addr, 2, Network::Dash, TransportPreference::V1Only).await; let elapsed = start.elapsed(); assert!(result.is_err(), "Connection should fail for non-routable peer"); @@ -92,7 +93,7 @@ async fn test_multiple_connect_disconnect_cycles() { for i in 1..=3 { println!("Attempt {} to connect to {}", i, peer_addr); - match connection.connect_instance().await { + match connection.connect_instance(TransportPreference::V1Only).await { Ok(_) => { assert!(connection.is_connected(), "Should be connected after successful connect"); From df5046c49e27a70be2697d8bebffb627fc1c1b86 Mon Sep 17 00:00:00 2001 From: pasta Date: Thu, 11 Dec 2025 09:34:03 -0600 Subject: [PATCH 2/9] refactor(dashcore): Extract payload serialization for V2 transport reuse MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add consensus_encode_payload() and consensus_decode_payload() methods to NetworkMessage, enabling V2 (BIP324) transport to use dashcore's canonical message serialization instead of duplicating ~400 lines of encoding logic. Changes: - Add NetworkMessage::consensus_encode_payload() for payload-only serialization - Add NetworkMessage::consensus_decode_payload(cmd, payload) for decoding - Refactor RawNetworkMessage encode/decode to use new methods - Remove duplicated serialize_payload, deserialize_headers, decode_by_command from V2Transport (~380 lines removed) This ensures Headers messages are properly serialized with trailing zero byte (transaction count) and validates during deserialization that it equals 0. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash-spv/src/network/transport/v2.rs | 468 +-------------------------- dash/src/network/message.rs | 262 ++++++++------- 2 files changed, 159 insertions(+), 571 deletions(-) diff --git a/dash-spv/src/network/transport/v2.rs b/dash-spv/src/network/transport/v2.rs index 6a88b565f..ef161a783 100644 --- a/dash-spv/src/network/transport/v2.rs +++ b/dash-spv/src/network/transport/v2.rs @@ -10,8 +10,7 @@ use std::net::SocketAddr; use async_trait::async_trait; use bip324::{CipherSession, PacketType, NUM_LENGTH_BYTES}; -use dashcore::consensus::{encode::serialize, Decodable}; -use dashcore::network::message::{CommandString, NetworkMessage, MAX_MSG_SIZE}; +use dashcore::network::message::{NetworkMessage, MAX_MSG_SIZE}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -80,110 +79,14 @@ impl V2Transport { &self.session_id } - /// Serialize the payload of a NetworkMessage. - /// - /// This mirrors the payload serialization logic from RawNetworkMessage's Encodable impl. - fn serialize_payload(message: &NetworkMessage) -> Vec { - match message { - NetworkMessage::Version(ref dat) => serialize(dat), - NetworkMessage::Addr(ref dat) => serialize(dat), - NetworkMessage::Inv(ref dat) => serialize(dat), - NetworkMessage::GetData(ref dat) => serialize(dat), - NetworkMessage::NotFound(ref dat) => serialize(dat), - NetworkMessage::GetBlocks(ref dat) => serialize(dat), - NetworkMessage::GetHeaders(ref dat) => serialize(dat), - NetworkMessage::Tx(ref dat) => serialize(dat), - NetworkMessage::Block(ref dat) => serialize(dat), - NetworkMessage::Headers(ref dat) => { - // Headers need special serialization with trailing zero byte per header - Self::serialize_headers(dat) - } - NetworkMessage::GetHeaders2(ref dat) => serialize(dat), - NetworkMessage::Headers2(ref dat) => serialize(dat), - NetworkMessage::Ping(ref dat) => serialize(dat), - NetworkMessage::Pong(ref dat) => serialize(dat), - NetworkMessage::MerkleBlock(ref dat) => serialize(dat), - NetworkMessage::FilterLoad(ref dat) => serialize(dat), - NetworkMessage::FilterAdd(ref dat) => serialize(dat), - NetworkMessage::GetCFilters(ref dat) => serialize(dat), - NetworkMessage::CFilter(ref dat) => serialize(dat), - NetworkMessage::GetCFHeaders(ref dat) => serialize(dat), - NetworkMessage::CFHeaders(ref dat) => serialize(dat), - NetworkMessage::GetCFCheckpt(ref dat) => serialize(dat), - NetworkMessage::CFCheckpt(ref dat) => serialize(dat), - NetworkMessage::SendCmpct(ref dat) => serialize(dat), - NetworkMessage::CmpctBlock(ref dat) => serialize(dat), - NetworkMessage::GetBlockTxn(ref dat) => serialize(dat), - NetworkMessage::BlockTxn(ref dat) => serialize(dat), - NetworkMessage::Alert(ref dat) => serialize(dat), - NetworkMessage::Reject(ref dat) => serialize(dat), - NetworkMessage::FeeFilter(ref dat) => serialize(dat), - NetworkMessage::AddrV2(ref dat) => serialize(dat), - NetworkMessage::GetMnListD(ref dat) => serialize(dat), - NetworkMessage::MnListDiff(ref dat) => serialize(dat), - NetworkMessage::GetQRInfo(ref dat) => serialize(dat), - NetworkMessage::QRInfo(ref dat) => serialize(dat), - NetworkMessage::CLSig(ref dat) => serialize(dat), - NetworkMessage::ISLock(ref dat) => serialize(dat), - NetworkMessage::SendDsq(wants_dsq) => serialize(&(*wants_dsq as u8)), - NetworkMessage::Unknown { - payload: ref data, - .. - } => serialize(data), - NetworkMessage::Verack - | NetworkMessage::SendHeaders - | NetworkMessage::SendHeaders2 - | NetworkMessage::MemPool - | NetworkMessage::GetAddr - | NetworkMessage::WtxidRelay - | NetworkMessage::FilterClear - | NetworkMessage::SendAddrV2 => vec![], - } - } - - /// Serialize headers with trailing zero byte per header (matches HeaderSerializationWrapper). - fn serialize_headers(headers: &[dashcore::block::Header]) -> Vec { - use dashcore::consensus::Encodable; - let mut buf = Vec::new(); - // VarInt for count - let _ = dashcore::VarInt(headers.len() as u64).consensus_encode(&mut buf); - // Each header + trailing zero - for header in headers { - let _ = header.consensus_encode(&mut buf); - buf.push(0u8); - } - buf - } - - /// Deserialize headers with trailing zero byte per header (matches HeaderDeserializationWrapper). - fn deserialize_headers(payload: &[u8]) -> NetworkResult> { - let mut cursor = std::io::Cursor::new(payload); - let count = dashcore::VarInt::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode headers count: {}", e)) - })?; - - let mut headers = Vec::with_capacity(count.0 as usize); - for _ in 0..count.0 { - let header = dashcore::block::Header::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode header: {}", e)) - })?; - headers.push(header); - // Read and discard the trailing zero byte - let _trailing = u8::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode header trailing byte: {}", e)) - })?; - } - Ok(headers) - } - /// Encode a NetworkMessage into V2 plaintext format. /// /// Format: /// - Short format (common messages): payload bytes (header byte added by cipher) /// - Extended format (Dash-specific): 12-byte command + payload bytes fn encode_message(&self, message: &NetworkMessage) -> NetworkResult> { - // Serialize the message payload - let payload = Self::serialize_payload(message); + // Serialize the message payload using dashcore's canonical serialization + let payload = message.consensus_encode_payload(); // Check for short message ID if let Some(short_id) = network_message_to_short_id(message) { @@ -237,7 +140,7 @@ impl V2Transport { message_id ); - if message_id == MSG_ID_EXTENDED { + let (cmd, payload) = if message_id == MSG_ID_EXTENDED { // Extended format: 12-byte command + payload (starting at byte 2) if plaintext.len() < 2 + COMMAND_LEN { return Err(NetworkError::ProtocolError( @@ -261,8 +164,7 @@ impl V2Transport { self.peer_address ); - // Decode the NetworkMessage based on command - self.decode_by_command(cmd, payload) + (cmd, payload) } else { // Short format: message_id is the short message ID, payload starts at byte 2 let payload = &plaintext[2..]; @@ -279,364 +181,12 @@ impl V2Transport { self.peer_address ); - self.decode_by_command(cmd, payload) - } - } - - /// Decode a NetworkMessage from command string and payload bytes. - fn decode_by_command(&self, cmd: &str, payload: &[u8]) -> NetworkResult { - // Create a cursor for decoding - let mut cursor = std::io::Cursor::new(payload); - - // Decode based on command - // Note: This mirrors the NetworkMessage variants and their Decodable impls - let message = match cmd { - "addr" => { - let addrs: Vec<(u32, dashcore::network::address::Address)> = - Decodable::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode addr: {}", e)) - })?; - NetworkMessage::Addr(addrs) - } - "block" => { - let block = dashcore::Block::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode block: {}", e)) - })?; - NetworkMessage::Block(block) - } - "blocktxn" => { - let blocktxn = - dashcore::network::message_compact_blocks::BlockTxn::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode blocktxn: {}", e)) - })?; - NetworkMessage::BlockTxn(blocktxn) - } - "cmpctblock" => { - let cmpctblock = - dashcore::network::message_compact_blocks::CmpctBlock::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode cmpctblock: {}", e)) - })?; - NetworkMessage::CmpctBlock(cmpctblock) - } - "feefilter" => { - let fee = i64::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode feefilter: {}", e)) - })?; - NetworkMessage::FeeFilter(fee) - } - "filteradd" => { - let filteradd = - dashcore::network::message_bloom::FilterAdd::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode filteradd: {}", - e - )) - })?; - NetworkMessage::FilterAdd(filteradd) - } - "filterclear" => NetworkMessage::FilterClear, - "filterload" => { - let filterload = - dashcore::network::message_bloom::FilterLoad::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode filterload: {}", - e - )) - })?; - NetworkMessage::FilterLoad(filterload) - } - "getblocks" => { - let getblocks = - dashcore::network::message_blockdata::GetBlocksMessage::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode getblocks: {}", e)) - })?; - NetworkMessage::GetBlocks(getblocks) - } - "getblocktxn" => { - let getblocktxn = - dashcore::network::message_compact_blocks::GetBlockTxn::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode getblocktxn: {}", e)) - })?; - NetworkMessage::GetBlockTxn(getblocktxn) - } - "getdata" => { - let inv: Vec = - Decodable::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode getdata: {}", e)) - })?; - NetworkMessage::GetData(inv) - } - "getheaders" => { - let getheaders = - dashcore::network::message_blockdata::GetHeadersMessage::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode getheaders: {}", e)) - })?; - NetworkMessage::GetHeaders(getheaders) - } - "headers" => { - // Headers have special deserialization (VarInt count + each header + trailing zero) - let headers = Self::deserialize_headers(payload)?; - NetworkMessage::Headers(headers) - } - "inv" => { - let inv: Vec = - Decodable::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode inv: {}", e)) - })?; - NetworkMessage::Inv(inv) - } - "mempool" => NetworkMessage::MemPool, - "merkleblock" => { - let merkleblock = - dashcore::MerkleBlock::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode merkleblock: {}", e)) - })?; - NetworkMessage::MerkleBlock(merkleblock) - } - "notfound" => { - let inv: Vec = - Decodable::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode notfound: {}", e)) - })?; - NetworkMessage::NotFound(inv) - } - "ping" => { - let nonce = u64::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode ping: {}", e)) - })?; - NetworkMessage::Ping(nonce) - } - "pong" => { - let nonce = u64::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode pong: {}", e)) - })?; - NetworkMessage::Pong(nonce) - } - "sendcmpct" => { - let sendcmpct = - dashcore::network::message_compact_blocks::SendCmpct::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode sendcmpct: {}", e)) - })?; - NetworkMessage::SendCmpct(sendcmpct) - } - "tx" => { - let tx = dashcore::Transaction::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode tx: {}", e)) - })?; - NetworkMessage::Tx(tx) - } - "getcfilters" => { - let getcfilters = - dashcore::network::message_filter::GetCFilters::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode getcfilters: {}", e)) - })?; - NetworkMessage::GetCFilters(getcfilters) - } - "cfilter" => { - let cfilter = - dashcore::network::message_filter::CFilter::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode cfilter: {}", e)) - })?; - NetworkMessage::CFilter(cfilter) - } - "getcfheaders" => { - let getcfheaders = - dashcore::network::message_filter::GetCFHeaders::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode getcfheaders: {}", - e - )) - })?; - NetworkMessage::GetCFHeaders(getcfheaders) - } - "cfheaders" => { - let cfheaders = - dashcore::network::message_filter::CFHeaders::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode cfheaders: {}", - e - )) - })?; - NetworkMessage::CFHeaders(cfheaders) - } - "getcfcheckpt" => { - let getcfcheckpt = - dashcore::network::message_filter::GetCFCheckpt::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode getcfcheckpt: {}", - e - )) - })?; - NetworkMessage::GetCFCheckpt(getcfcheckpt) - } - "cfcheckpt" => { - let cfcheckpt = - dashcore::network::message_filter::CFCheckpt::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode cfcheckpt: {}", - e - )) - })?; - NetworkMessage::CFCheckpt(cfcheckpt) - } - // Dash-specific messages (extended format) - "version" => { - let version = dashcore::network::message_network::VersionMessage::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode version: {}", e)) - })?; - NetworkMessage::Version(version) - } - "verack" => NetworkMessage::Verack, - "sendheaders" => NetworkMessage::SendHeaders, - "getaddr" => NetworkMessage::GetAddr, - "wtxidrelay" => NetworkMessage::WtxidRelay, - "sendaddrv2" => NetworkMessage::SendAddrV2, - "addrv2" => { - let addrs: Vec = - Decodable::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode addrv2: {}", e)) - })?; - NetworkMessage::AddrV2(addrs) - } - "reject" => { - let reject = - dashcore::network::message_network::Reject::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode reject: {}", e)) - })?; - NetworkMessage::Reject(reject) - } - // Dash-specific extended messages - "mnlistdiff" => { - let mnlistdiff = - dashcore::network::message_sml::MnListDiff::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode mnlistdiff: {}", - e - )) - })?; - NetworkMessage::MnListDiff(mnlistdiff) - } - "getmnlistd" => { - let getmnlistd = - dashcore::network::message_sml::GetMnListDiff::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode getmnlistd: {}", - e - )) - })?; - NetworkMessage::GetMnListD(getmnlistd) - } - "qrinfo" => { - let qrinfo = - dashcore::network::message_qrinfo::QRInfo::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode qrinfo: {}", e)) - })?; - NetworkMessage::QRInfo(qrinfo) - } - "getqrinfo" => { - let getqrinfo = - dashcore::network::message_qrinfo::GetQRInfo::consensus_decode(&mut cursor) - .map_err(|e| { - NetworkError::ProtocolError(format!( - "Failed to decode getqrinfo: {}", - e - )) - })?; - NetworkMessage::GetQRInfo(getqrinfo) - } - "clsig" => { - let clsig = dashcore::ChainLock::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode clsig: {}", e)) - })?; - NetworkMessage::CLSig(clsig) - } - "isdlock" => { - let islock = dashcore::InstantLock::consensus_decode(&mut cursor).map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode isdlock: {}", e)) - })?; - NetworkMessage::ISLock(islock) - } - "headers2" => { - let headers2 = - dashcore::network::message_headers2::Headers2Message::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode headers2: {}", e)) - })?; - NetworkMessage::Headers2(headers2) - } - "getheaders2" => { - // getheaders2 uses same format as getheaders - let getheaders2 = - dashcore::network::message_blockdata::GetHeadersMessage::consensus_decode( - &mut cursor, - ) - .map_err(|e| { - NetworkError::ProtocolError(format!("Failed to decode getheaders2: {}", e)) - })?; - NetworkMessage::GetHeaders2(getheaders2) - } - "sendheaders2" => NetworkMessage::SendHeaders2, - "senddsq" => { - // SendDsq is a single bool (serialized as u8) - let wants_dsq = if payload.is_empty() { - false - } else { - payload[0] != 0 - }; - NetworkMessage::SendDsq(wants_dsq) - } - // Unknown command - use Unknown variant - _ => { - tracing::warn!( - "V2Transport: Unknown command '{}' from {}, storing as raw bytes", - cmd, - self.peer_address - ); - NetworkMessage::Unknown { - command: CommandString::try_from(cmd.to_string()).unwrap_or_else(|_| { - CommandString::try_from("unknown".to_string()).expect("valid") - }), - payload: payload.to_vec(), - } - } + (cmd, payload) }; - Ok(message) + // Decode the NetworkMessage using dashcore's canonical decoder + NetworkMessage::consensus_decode_payload(cmd, payload) + .map_err(|e| NetworkError::ProtocolError(format!("Failed to decode '{}': {}", cmd, e))) } /// Helper function to read some bytes into the receive buffer. diff --git a/dash/src/network/message.rs b/dash/src/network/message.rs index 786b664fa..e16498693 100644 --- a/dash/src/network/message.rs +++ b/dash/src/network/message.rs @@ -352,45 +352,21 @@ impl NetworkMessage { _ => CommandString::try_from_static(self.cmd()).expect("cmd returns valid commands"), } } -} -impl RawNetworkMessage { - /// Return the message command as a static string reference. + /// Serialize the message payload without V1 framing (magic/command/checksum). /// - /// This returns `"unknown"` for [NetworkMessage::Unknown], - /// regardless of the actual command in the unknown message. - /// Use the [Self::command] method to get the command for unknown messages. - pub fn cmd(&self) -> &'static str { - self.payload.cmd() - } - - /// Return the CommandString for the message command. - pub fn command(&self) -> CommandString { - self.payload.command() - } -} - -struct HeaderSerializationWrapper<'a>(&'a Vec); - -impl<'a> Encodable for HeaderSerializationWrapper<'a> { - #[inline] - fn consensus_encode(&self, w: &mut W) -> Result { - let mut len = 0; - len += VarInt(self.0.len() as u64).consensus_encode(w)?; - for header in self.0.iter() { - len += header.consensus_encode(w)?; - len += 0u8.consensus_encode(w)?; - } - Ok(len) - } -} - -impl Encodable for RawNetworkMessage { - fn consensus_encode(&self, w: &mut W) -> Result { - let mut len = 0; - len += self.magic.consensus_encode(w)?; - len += self.command().consensus_encode(w)?; - len += CheckedData(match self.payload { + /// This method returns the raw serialized bytes of the message payload, + /// suitable for use with V2 (BIP324) transport or other protocols that + /// handle framing separately. + /// + /// # Note on Headers serialization + /// + /// The `Headers` message is serialized with a trailing zero byte after each + /// header, representing an empty transaction count (VarInt). This matches + /// the Bitcoin/Dash protocol where headers messages reuse the block + /// serialization format but with no transactions. + pub fn consensus_encode_payload(&self) -> Vec { + match *self { NetworkMessage::Version(ref dat) => serialize(dat), NetworkMessage::Addr(ref dat) => serialize(dat), NetworkMessage::Inv(ref dat) => serialize(dat), @@ -420,8 +396,19 @@ impl Encodable for RawNetworkMessage { NetworkMessage::BlockTxn(ref dat) => serialize(dat), NetworkMessage::Alert(ref dat) => serialize(dat), NetworkMessage::Reject(ref dat) => serialize(dat), - NetworkMessage::FeeFilter(ref data) => serialize(data), + NetworkMessage::FeeFilter(ref dat) => serialize(dat), NetworkMessage::AddrV2(ref dat) => serialize(dat), + NetworkMessage::GetMnListD(ref dat) => serialize(dat), + NetworkMessage::MnListDiff(ref dat) => serialize(dat), + NetworkMessage::GetQRInfo(ref dat) => serialize(dat), + NetworkMessage::QRInfo(ref dat) => serialize(dat), + NetworkMessage::CLSig(ref dat) => serialize(dat), + NetworkMessage::ISLock(ref dat) => serialize(dat), + NetworkMessage::SendDsq(wants_dsq) => serialize(&(wants_dsq as u8)), + NetworkMessage::Unknown { + payload: ref data, + .. + } => serialize(data), NetworkMessage::Verack | NetworkMessage::SendHeaders | NetworkMessage::SendHeaders2 @@ -430,81 +417,29 @@ impl Encodable for RawNetworkMessage { | NetworkMessage::WtxidRelay | NetworkMessage::FilterClear | NetworkMessage::SendAddrV2 => vec![], - NetworkMessage::Unknown { - payload: ref data, - .. - } => serialize(data), - NetworkMessage::GetMnListD(ref dat) => serialize(dat), - NetworkMessage::MnListDiff(ref dat) => serialize(dat), - NetworkMessage::GetQRInfo(ref dat) => serialize(dat), - NetworkMessage::QRInfo(ref dat) => serialize(dat), - NetworkMessage::CLSig(ref dat) => serialize(dat), - NetworkMessage::ISLock(ref dat) => serialize(dat), - NetworkMessage::SendDsq(wants_dsq) => serialize(&(wants_dsq as u8)), - }) - .consensus_encode(w)?; - Ok(len) - } -} - -struct HeaderDeserializationWrapper(Vec); - -impl Decodable for HeaderDeserializationWrapper { - #[inline] - fn consensus_decode_from_finite_reader( - r: &mut R, - ) -> Result { - let len = VarInt::consensus_decode(r)?.0; - // should be above usual number of items to avoid - // allocation - let mut ret = Vec::with_capacity(core::cmp::min(1024 * 16, len as usize)); - for _ in 0..len { - ret.push(Decodable::consensus_decode(r)?); - if u8::consensus_decode(r)? != 0u8 { - return Err(encode::Error::ParseFailed( - "Headers message should not contain transactions", - )); - } } - Ok(HeaderDeserializationWrapper(ret)) } - #[inline] - fn consensus_decode(r: &mut R) -> Result { - Self::consensus_decode_from_finite_reader(r.take(MAX_MSG_SIZE as u64).by_ref()) - } -} - -impl Decodable for RawNetworkMessage { - fn consensus_decode_from_finite_reader( - r: &mut R, - ) -> Result { - let magic = Decodable::consensus_decode_from_finite_reader(r)?; - let cmd = CommandString::consensus_decode_from_finite_reader(r)?; - let raw_payload = match CheckedData::consensus_decode_from_finite_reader(r) { - Ok(cd) => cd.0, - Err(encode::Error::InvalidChecksum { - expected, - actual, - }) => { - // Include message command and magic in logging to aid diagnostics - log::warn!( - "Invalid payload checksum for network message '{}' (magic {:#x}): expected {:02x?}, actual {:02x?}", - cmd.0, - magic, - expected, - actual - ); - return Err(encode::Error::InvalidChecksum { - expected, - actual, - }); - } - Err(e) => return Err(e), - }; - - let mut mem_d = io::Cursor::new(raw_payload); - let payload = match &cmd.0[..] { + /// Decode a message payload from raw bytes given a command string. + /// + /// This method decodes the raw payload bytes into a `NetworkMessage` variant + /// based on the command string. It handles all standard Bitcoin and Dash-specific + /// message types, including special cases like `Headers` which has trailing + /// transaction count bytes. + /// + /// This is the inverse of [`consensus_encode_payload`], suitable for use with + /// V2 (BIP324) transport or other protocols that handle framing separately. + /// + /// # Arguments + /// * `cmd` - The command string identifying the message type (e.g., "version", "headers") + /// * `payload` - The raw payload bytes to decode + /// + /// # Returns + /// * `Ok(NetworkMessage)` - Successfully decoded message + /// * `Err(encode::Error)` - Decoding failed + pub fn consensus_decode_payload(cmd: &str, payload: &[u8]) -> Result { + let mut mem_d = io::Cursor::new(payload); + let message = match cmd { "version" => { NetworkMessage::Version(Decodable::consensus_decode_from_finite_reader(&mut mem_d)?) } @@ -650,10 +585,113 @@ impl Decodable for RawNetworkMessage { NetworkMessage::SendDsq(byte != 0) } _ => NetworkMessage::Unknown { - command: cmd, - payload: mem_d.into_inner(), + command: CommandString::try_from(cmd.to_string()) + .map_err(|_| encode::Error::ParseFailed("Invalid command string"))?, + payload: payload.to_vec(), }, }; + Ok(message) + } +} + +impl RawNetworkMessage { + /// Return the message command as a static string reference. + /// + /// This returns `"unknown"` for [NetworkMessage::Unknown], + /// regardless of the actual command in the unknown message. + /// Use the [Self::command] method to get the command for unknown messages. + pub fn cmd(&self) -> &'static str { + self.payload.cmd() + } + + /// Return the CommandString for the message command. + pub fn command(&self) -> CommandString { + self.payload.command() + } +} + +struct HeaderSerializationWrapper<'a>(&'a Vec); + +impl<'a> Encodable for HeaderSerializationWrapper<'a> { + #[inline] + fn consensus_encode(&self, w: &mut W) -> Result { + let mut len = 0; + len += VarInt(self.0.len() as u64).consensus_encode(w)?; + for header in self.0.iter() { + len += header.consensus_encode(w)?; + len += 0u8.consensus_encode(w)?; + } + Ok(len) + } +} + +impl Encodable for RawNetworkMessage { + fn consensus_encode(&self, w: &mut W) -> Result { + let mut len = 0; + len += self.magic.consensus_encode(w)?; + len += self.command().consensus_encode(w)?; + len += CheckedData(self.payload.consensus_encode_payload()).consensus_encode(w)?; + Ok(len) + } +} + +struct HeaderDeserializationWrapper(Vec); + +impl Decodable for HeaderDeserializationWrapper { + #[inline] + fn consensus_decode_from_finite_reader( + r: &mut R, + ) -> Result { + let len = VarInt::consensus_decode(r)?.0; + // should be above usual number of items to avoid + // allocation + let mut ret = Vec::with_capacity(core::cmp::min(1024 * 16, len as usize)); + for _ in 0..len { + ret.push(Decodable::consensus_decode(r)?); + if u8::consensus_decode(r)? != 0u8 { + return Err(encode::Error::ParseFailed( + "Headers message should not contain transactions", + )); + } + } + Ok(HeaderDeserializationWrapper(ret)) + } + + #[inline] + fn consensus_decode(r: &mut R) -> Result { + Self::consensus_decode_from_finite_reader(r.take(MAX_MSG_SIZE as u64).by_ref()) + } +} + +impl Decodable for RawNetworkMessage { + fn consensus_decode_from_finite_reader( + r: &mut R, + ) -> Result { + let magic = Decodable::consensus_decode_from_finite_reader(r)?; + let cmd = CommandString::consensus_decode_from_finite_reader(r)?; + let raw_payload = match CheckedData::consensus_decode_from_finite_reader(r) { + Ok(cd) => cd.0, + Err(encode::Error::InvalidChecksum { + expected, + actual, + }) => { + // Include message command and magic in logging to aid diagnostics + log::warn!( + "Invalid payload checksum for network message '{}' (magic {:#x}): expected {:02x?}, actual {:02x?}", + cmd.0, + magic, + expected, + actual + ); + return Err(encode::Error::InvalidChecksum { + expected, + actual, + }); + } + Err(e) => return Err(e), + }; + + let payload = NetworkMessage::consensus_decode_payload(&cmd.0, &raw_payload)?; Ok(RawNetworkMessage { magic, payload, From c4c99ec2e585060ea45d67d934de78dd6c12fbf7 Mon Sep 17 00:00:00 2001 From: pasta Date: Thu, 11 Dec 2025 09:45:00 -0600 Subject: [PATCH 3/9] chore(dash-spv): Remove unused TransportEstablishResult enum MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dead code that was never integrated - V2HandshakeResult already serves this purpose in the actual transport establishment flow. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash-spv/src/network/transport/mod.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/dash-spv/src/network/transport/mod.rs b/dash-spv/src/network/transport/mod.rs index 0ee7eea29..9af05d2fc 100644 --- a/dash-spv/src/network/transport/mod.rs +++ b/dash-spv/src/network/transport/mod.rs @@ -29,14 +29,6 @@ pub enum TransportPreference { V1Only, } -/// Result of establishing a transport connection. -pub enum TransportEstablishResult { - /// Successfully established V1 transport. - V1(V1Transport), - /// Need to fallback to V1 (V2 handshake detected V1-only peer). - FallbackToV1, -} - /// Abstract transport layer for P2P communication. /// /// This trait is implemented by both V1Transport (unencrypted) and From f20538e4a4edfe67864cd0efd08101fe88b22e22 Mon Sep 17 00:00:00 2001 From: pasta Date: Thu, 11 Dec 2025 10:29:29 -0600 Subject: [PATCH 4/9] fix(dash-spv): Improve V2 handshake session_id and decoy handling MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use cipher.id() for actual BIP324 session ID instead of zeros - Handle decoy packets by looping until genuine version packet received - Add documentation explaining why we use bip324's low-level API 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- .../src/network/transport/v2_handshake.rs | 222 ++++++++++-------- 1 file changed, 125 insertions(+), 97 deletions(-) diff --git a/dash-spv/src/network/transport/v2_handshake.rs b/dash-spv/src/network/transport/v2_handshake.rs index 29ef1003c..9c1fd3636 100644 --- a/dash-spv/src/network/transport/v2_handshake.rs +++ b/dash-spv/src/network/transport/v2_handshake.rs @@ -6,6 +6,26 @@ //! //! The handshake detects v1-only peers by checking if the first bytes //! received match the network magic (indicating v1 protocol). +//! +//! ## Why Not Use `bip324::futures::handshake()`? +//! +//! The bip324 crate provides a high-level `futures::handshake()` function, but +//! it doesn't meet dash-spv's requirements: +//! +//! 1. **V1 Detection Strategy**: bip324 detects V1-only peers *after* reading the +//! 64-byte remote key (consuming the bytes). dash-spv uses `stream.peek()` to +//! detect V1 magic *without* consuming bytes, allowing the same TCP connection +//! to be reused for V1 fallback. +//! +//! 2. **Return Type Mismatch**: bip324 returns split ciphers and a wrapped +//! `ProtocolSessionReader`. dash-spv needs the original `TcpStream` back +//! plus a `CipherSession` for the transport layer. +//! +//! 3. **Timeout Handling**: bip324's async handshake has no built-in timeouts. +//! dash-spv needs per-operation and cumulative timeout handling. +//! +//! Therefore, we use bip324's low-level `Handshake` state machine with custom +//! async I/O wrappers that provide the control we need. use std::net::SocketAddr; use std::time::Duration; @@ -259,110 +279,118 @@ impl V2HandshakeManager { // Keep any remaining bytes after the garbage let remaining = garbage_buffer[consumed_bytes..].to_vec(); - // Step 6: Receive version packet - // The version packet follows the garbage terminator + // Step 6: Receive version packet (may be preceded by decoy packets) + // Loop until we receive the genuine version packet let mut handshake = handshake; - - // Read version packet: 3-byte length + encrypted content - let mut version_data = remaining; - - // Read at least 3 bytes for the length prefix - while version_data.len() < 3 { - let mut more = [0u8; 64]; - let n = tokio::time::timeout( - HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), - stream.read(&mut more), - ) - .await - .map_err(|_| NetworkError::Timeout)? - .map_err(|e| { - NetworkError::V2HandshakeFailed(format!( - "Failed to read version packet length: {}", - e - )) - })?; - if n == 0 { - return Err(NetworkError::V2HandshakeFailed( - "Connection closed before version packet".to_string(), - )); + let mut leftover_data = remaining; + + loop { + // Read at least 3 bytes for the length prefix + while leftover_data.len() < 3 { + let mut more = [0u8; 64]; + let n = tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut more), + ) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to read packet length: {}", + e + )) + })?; + if n == 0 { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed before version packet".to_string(), + )); + } + leftover_data.extend_from_slice(&more[..n]); } - version_data.extend_from_slice(&more[..n]); - } - // Decrypt the packet length (first 3 bytes) - let length_bytes: [u8; 3] = version_data[..3].try_into().map_err(|_| { - NetworkError::V2HandshakeFailed( - "Failed to extract length bytes".to_string(), - ) - })?; - let packet_len = handshake.decrypt_packet_len(length_bytes).map_err(|e| { - NetworkError::V2HandshakeFailed(format!( - "Failed to decrypt packet length: {}", - e - )) - })?; + // Decrypt the packet length (first 3 bytes) + let length_bytes: [u8; 3] = + leftover_data[..3].try_into().map_err(|_| { + NetworkError::V2HandshakeFailed( + "Failed to extract length bytes".to_string(), + ) + })?; + let packet_len = + handshake.decrypt_packet_len(length_bytes).map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to decrypt packet length: {}", + e + )) + })?; + + tracing::debug!( + "V2 handshake: Packet length is {} bytes from {}", + packet_len, + self.peer_address + ); + + // Read more data if needed to have the full packet + let total_needed = 3 + packet_len; // length prefix + packet content + while leftover_data.len() < total_needed { + let mut more = [0u8; 64]; + let n = tokio::time::timeout( + HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), + stream.read(&mut more), + ) + .await + .map_err(|_| NetworkError::Timeout)? + .map_err(|e| { + NetworkError::V2HandshakeFailed(format!( + "Failed to read packet content: {}", + e + )) + })?; + if n == 0 { + return Err(NetworkError::V2HandshakeFailed( + "Connection closed before packet complete".to_string(), + )); + } + leftover_data.extend_from_slice(&more[..n]); + } - tracing::debug!( - "V2 handshake: Version packet length is {} bytes from {}", - packet_len, - self.peer_address - ); + // Extract just the packet content (excluding the 3-byte length prefix) + let mut packet_content = leftover_data[3..3 + packet_len].to_vec(); - // Read more data if needed to have the full packet - let total_needed = 3 + packet_len; // length prefix + packet content - while version_data.len() < total_needed { - let mut more = [0u8; 64]; - let n = tokio::time::timeout( - HANDSHAKE_TIMEOUT.saturating_sub(scan_start.elapsed()), - stream.read(&mut more), - ) - .await - .map_err(|_| NetworkError::Timeout)? - .map_err(|e| { - NetworkError::V2HandshakeFailed(format!( - "Failed to read version packet content: {}", - e - )) - })?; - if n == 0 { - return Err(NetworkError::V2HandshakeFailed( - "Connection closed before version packet complete".to_string(), - )); - } - version_data.extend_from_slice(&more[..n]); - } + // Keep any data after this packet for the next iteration + leftover_data = leftover_data[3 + packet_len..].to_vec(); - // Extract just the packet content (excluding the 3-byte length prefix) - let mut packet_content = version_data[3..3 + packet_len].to_vec(); - - // Process version packet - match handshake.receive_version(&mut packet_content) { - Ok(VersionResult::Complete { - cipher, - }) => { - tracing::info!( - "V2 handshake: Completed successfully with {}", - self.peer_address - ); - - return Ok(V2HandshakeResult::Success(V2Session { - stream, + // Process packet + match handshake.receive_version(&mut packet_content) { + Ok(VersionResult::Complete { cipher, - session_id: [0u8; 32], // TODO: Get actual session ID - })); - } - Ok(VersionResult::Decoy(_handshake)) => { - // Received a decoy packet, need to continue reading - // For now, treat as error (can be enhanced later) - return Err(NetworkError::V2HandshakeFailed( - "Received decoy packet - not yet supported".to_string(), - )); - } - Err(e) => { - return Err(NetworkError::V2HandshakeFailed(format!( - "Failed to process version packet: {}", - e - ))); + }) => { + tracing::info!( + "V2 handshake: Completed successfully with {}", + self.peer_address + ); + + let session_id = *cipher.id(); + return Ok(V2HandshakeResult::Success(V2Session { + stream, + cipher, + session_id, + })); + } + Ok(VersionResult::Decoy(next_handshake)) => { + // Received a decoy packet, continue reading for version packet + tracing::debug!( + "V2 handshake: Received decoy packet from {}, continuing", + self.peer_address + ); + handshake = next_handshake; + // Continue loop to read next packet + } + Err(e) => { + return Err(NetworkError::V2HandshakeFailed(format!( + "Failed to process packet: {}", + e + ))); + } } } } From dba0deae09486fe9a68f3276c0292f2272558c54 Mon Sep 17 00:00:00 2001 From: pasta Date: Thu, 11 Dec 2025 10:59:16 -0600 Subject: [PATCH 5/9] refactor(dash-spv): Remove redundant garbage buffer size check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit bip324 crate already enforces the max garbage limit internally and returns NoGarbageTerminator error if exceeded. No need to duplicate the check in dash-spv. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash-spv/src/network/transport/v2_handshake.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/dash-spv/src/network/transport/v2_handshake.rs b/dash-spv/src/network/transport/v2_handshake.rs index 9c1fd3636..66590b98b 100644 --- a/dash-spv/src/network/transport/v2_handshake.rs +++ b/dash-spv/src/network/transport/v2_handshake.rs @@ -395,13 +395,9 @@ impl V2HandshakeManager { } } Ok(GarbageResult::NeedMoreData(hs)) => { + // Continue reading more data; bip324 enforces the max garbage limit + // internally and will return NoGarbageTerminator if exceeded handshake_state = hs; - // Continue reading more data - if garbage_buffer.len() > MAX_GARBAGE_LEN + GARBAGE_TERMINATOR_SIZE { - return Err(NetworkError::V2HandshakeFailed( - "Garbage terminator not found within limit".to_string(), - )); - } } Err(e) => { return Err(NetworkError::V2HandshakeFailed(format!( From d0842f0314075f6576edf0e978eea47d35ba6e6b Mon Sep 17 00:00:00 2001 From: pasta Date: Thu, 11 Dec 2025 11:58:50 -0600 Subject: [PATCH 6/9] fix(dash-spv): Resolve clippy warnings in network module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove unused consecutive_resyncs field from Peer (V1Transport has its own) - Simplify match_single_binding to direct let binding - Box V2Session in V2HandshakeResult to fix large_enum_variant warning 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash-spv/src/network/peer.rs | 8 +------- dash-spv/src/network/transport/v2_handshake.rs | 6 +++--- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/dash-spv/src/network/peer.rs b/dash-spv/src/network/peer.rs index 2c0100c07..cb0107c4c 100644 --- a/dash-spv/src/network/peer.rs +++ b/dash-spv/src/network/peer.rs @@ -35,8 +35,6 @@ pub struct Peer { relay: Option, prefers_headers2: bool, sent_sendheaders2: bool, - // Basic telemetry for resync events - consecutive_resyncs: u32, // Transport protocol version used (1 or 2) transport_version: u8, } @@ -70,7 +68,6 @@ impl Peer { relay: None, prefers_headers2: false, sent_sendheaders2: false, - consecutive_resyncs: 0, transport_version: 1, } } @@ -112,9 +109,7 @@ impl Peer { "Connecting to {} using V2 transport (BIP324 encrypted, with V1 fallback)", address ); - match Self::try_v2_with_fallback(address, timeout, network).await? { - (transport, version) => (transport, version), - } + Self::try_v2_with_fallback(address, timeout, network).await? } }; @@ -140,7 +135,6 @@ impl Peer { relay: None, prefers_headers2: false, sent_sendheaders2: false, - consecutive_resyncs: 0, transport_version, }) } diff --git a/dash-spv/src/network/transport/v2_handshake.rs b/dash-spv/src/network/transport/v2_handshake.rs index 66590b98b..e9cf2aae2 100644 --- a/dash-spv/src/network/transport/v2_handshake.rs +++ b/dash-spv/src/network/transport/v2_handshake.rs @@ -52,7 +52,7 @@ const GARBAGE_TERMINATOR_SIZE: usize = 16; /// Result of the V2 handshake attempt. pub enum V2HandshakeResult { /// Successfully completed V2 handshake. - Success(V2Session), + Success(Box), /// Detected V1-only peer (first bytes matched network magic). FallbackToV1, } @@ -370,11 +370,11 @@ impl V2HandshakeManager { ); let session_id = *cipher.id(); - return Ok(V2HandshakeResult::Success(V2Session { + return Ok(V2HandshakeResult::Success(Box::new(V2Session { stream, cipher, session_id, - })); + }))); } Ok(VersionResult::Decoy(next_handshake)) => { // Received a decoy packet, continue reading for version packet From 1fcf235eeddfed2e9c825128375a27aef6afbaee Mon Sep 17 00:00:00 2001 From: pasta Date: Thu, 11 Dec 2025 17:51:04 -0600 Subject: [PATCH 7/9] test: Add BIP324 V2 transport tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add comprehensive tests for the BIP324 V2 encrypted transport: - message.rs: 4 tests for consensus_encode_payload/consensus_decode_payload round-trips covering standard Bitcoin, Dash-specific, empty payload, and Headers special encoding - message_ids.rs: 1 test for bidirectional short ID mapping consistency - v2.rs: 2 tests for V2 message framing (short format and extended format) - handshake_test.rs: 1 integration test for V2Preferred fallback behavior 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash-spv/src/network/transport/message_ids.rs | 38 ++++ dash-spv/src/network/transport/v2.rs | 142 ++++++++++++++ dash-spv/tests/handshake_test.rs | 58 ++++++ dash/src/network/message.rs | 182 ++++++++++++++++++ 4 files changed, 420 insertions(+) diff --git a/dash-spv/src/network/transport/message_ids.rs b/dash-spv/src/network/transport/message_ids.rs index e5691f415..874238fa0 100644 --- a/dash-spv/src/network/transport/message_ids.rs +++ b/dash-spv/src/network/transport/message_ids.rs @@ -293,4 +293,42 @@ mod tests { assert!(network_message_to_short_id(&NetworkMessage::Version(version)).is_none()); } + + #[test] + fn test_short_id_to_command_bidirectional_consistency() { + // For messages that have short IDs, verify the command string matches + // what dashcore returns via cmd() + let test_cases: Vec<(NetworkMessage, u8)> = vec![ + (NetworkMessage::Ping(0), MSG_ID_PING), + (NetworkMessage::Pong(0), MSG_ID_PONG), + (NetworkMessage::Inv(vec![]), MSG_ID_INV), + (NetworkMessage::GetData(vec![]), MSG_ID_GETDATA), + (NetworkMessage::MemPool, MSG_ID_MEMPOOL), + (NetworkMessage::FilterClear, MSG_ID_FILTERCLEAR), + (NetworkMessage::SendHeaders2, MSG_ID_SENDHEADERS2), + (NetworkMessage::SendDsq(false), MSG_ID_SENDDSQUEUE), + ]; + + for (msg, expected_id) in test_cases { + // Verify network_message_to_short_id returns the expected ID + let short_id = network_message_to_short_id(&msg); + assert_eq!( + short_id, + Some(expected_id), + "Message {} should have short ID {}", + msg.cmd(), + expected_id + ); + + // Verify short_id_to_command returns the correct command + let cmd = short_id_to_command(expected_id); + assert_eq!( + cmd, + Some(msg.cmd()), + "Short ID {} should map to command '{}'", + expected_id, + msg.cmd() + ); + } + } } diff --git a/dash-spv/src/network/transport/v2.rs b/dash-spv/src/network/transport/v2.rs index ef161a783..0018af23c 100644 --- a/dash-spv/src/network/transport/v2.rs +++ b/dash-spv/src/network/transport/v2.rs @@ -432,4 +432,146 @@ mod tests { assert!(network_message_to_short_id(&NetworkMessage::Ping(0)).is_some()); assert!(network_message_to_short_id(&NetworkMessage::Pong(0)).is_some()); } + + /// Helper: Encode a message the same way V2Transport::encode_message does + fn test_encode_v2_message(message: &NetworkMessage) -> Vec { + let payload = message.consensus_encode_payload(); + + if let Some(short_id) = network_message_to_short_id(message) { + // Short format: [short_id] + payload + let mut plaintext = Vec::with_capacity(1 + payload.len()); + plaintext.push(short_id); + plaintext.extend_from_slice(&payload); + plaintext + } else { + // Extended format: [0x00] + [12-byte command] + payload + let cmd = message.cmd(); + let cmd_bytes = cmd.as_bytes(); + let mut command = [0u8; COMMAND_LEN]; + let copy_len = std::cmp::min(cmd_bytes.len(), COMMAND_LEN); + command[..copy_len].copy_from_slice(&cmd_bytes[..copy_len]); + + let mut plaintext = Vec::with_capacity(1 + COMMAND_LEN + payload.len()); + plaintext.push(MSG_ID_EXTENDED); + plaintext.extend_from_slice(&command); + plaintext.extend_from_slice(&payload); + plaintext + } + } + + /// Helper: Decode a V2 message with simulated cipher header byte + fn test_decode_v2_message(plaintext: &[u8]) -> Result { + // Simulate: prepend a packet type byte (0 for genuine) like the cipher does + let mut with_header = vec![0u8]; // Packet type = genuine + with_header.extend_from_slice(plaintext); + + if with_header.len() < 2 { + return Err(NetworkError::ProtocolError("V2 message too short".to_string())); + } + + let message_id = with_header[1]; + + if message_id == MSG_ID_EXTENDED { + // Extended format + if with_header.len() < 2 + COMMAND_LEN { + return Err(NetworkError::ProtocolError( + "Extended format message too short".to_string(), + )); + } + + let command_bytes = &with_header[2..2 + COMMAND_LEN]; + let cmd = std::str::from_utf8(command_bytes) + .map_err(|_| NetworkError::ProtocolError("Invalid UTF-8 in command".to_string()))? + .trim_end_matches('\0'); + + let payload = &with_header[2 + COMMAND_LEN..]; + NetworkMessage::consensus_decode_payload(cmd, payload) + .map_err(|e| NetworkError::ProtocolError(format!("Failed to decode: {}", e))) + } else { + // Short format + let cmd = short_id_to_command(message_id).ok_or_else(|| { + NetworkError::ProtocolError(format!("Unknown short ID: {}", message_id)) + })?; + + let payload = &with_header[2..]; + NetworkMessage::consensus_decode_payload(cmd, payload) + .map_err(|e| NetworkError::ProtocolError(format!("Failed to decode: {}", e))) + } + } + + #[test] + fn test_short_id_round_trip_common_messages() { + // Messages that should use short format (1 byte ID) + let short_format_messages: Vec = vec![ + NetworkMessage::Ping(0x1234567890abcdef), + NetworkMessage::Pong(0xfedcba0987654321), + NetworkMessage::Inv(vec![]), + NetworkMessage::GetData(vec![]), + NetworkMessage::NotFound(vec![]), + NetworkMessage::MemPool, + NetworkMessage::FilterClear, + NetworkMessage::SendHeaders2, + NetworkMessage::SendDsq(true), + ]; + + for original in &short_format_messages { + // Verify it uses short format (first byte is the short ID, not 0x00) + let encoded = test_encode_v2_message(original); + assert_ne!( + encoded[0], + MSG_ID_EXTENDED, + "{} should use short format, not extended", + original.cmd() + ); + + // Verify round-trip + let decoded = test_decode_v2_message(&encoded) + .expect(&format!("Failed to decode {}", original.cmd())); + assert_eq!(original, &decoded, "Round-trip failed for {} message", original.cmd()); + } + } + + #[test] + fn test_extended_format_round_trip() { + use dashcore::network::address::Address; + use dashcore::network::constants::ServiceFlags; + use dashcore::network::message_network::VersionMessage; + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + let addr = Address::new( + &SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8333), + ServiceFlags::NONE, + ); + + let version = VersionMessage { + version: 70015, + services: ServiceFlags::NONE, + timestamp: 0, + receiver: addr.clone(), + sender: addr, + nonce: 0, + user_agent: "/test/".to_string(), + start_height: 0, + relay: false, + mn_auth_challenge: [0u8; 32], + masternode_connection: false, + }; + + // Version message should use extended format (no short ID) + let original = NetworkMessage::Version(version); + + let encoded = test_encode_v2_message(&original); + + // Verify extended format: first byte should be 0x00 + assert_eq!(encoded[0], MSG_ID_EXTENDED, "Version message should use extended format"); + + // Verify command is in bytes 1-12 + let cmd_bytes = &encoded[1..1 + COMMAND_LEN]; + let cmd = std::str::from_utf8(cmd_bytes).unwrap().trim_end_matches('\0'); + assert_eq!(cmd, "version", "Command should be 'version'"); + + // Verify round-trip + let decoded = test_decode_v2_message(&encoded).expect("Failed to decode version message"); + assert_eq!(original, decoded, "Version round-trip failed"); + } } diff --git a/dash-spv/tests/handshake_test.rs b/dash-spv/tests/handshake_test.rs index a25f5dda5..e23ce8d2c 100644 --- a/dash-spv/tests/handshake_test.rs +++ b/dash-spv/tests/handshake_test.rs @@ -115,3 +115,61 @@ async fn test_multiple_connect_disconnect_cycles() { } } } + +// ============================================================================= +// BIP324 V2 Transport Integration Tests +// ============================================================================= + +/// Test V2Preferred mode which tries V2 first then falls back to V1. +/// This test verifies the fallback mechanism works correctly. +#[tokio::test] +async fn test_v2preferred_fallback_to_v1() { + let _ = env_logger::builder().filter_level(log::LevelFilter::Debug).is_test(true).try_init(); + + let peer_addr: SocketAddr = "127.0.0.1:9999".parse().expect("Valid peer address"); + let result = + Peer::connect(peer_addr, 10, Network::Dash, TransportPreference::V2Preferred).await; + + match result { + Ok(mut connection) => { + let transport_version = connection.transport_version(); + println!( + "✓ Connected to {} using V{} transport (V2Preferred mode)", + peer_addr, transport_version + ); + + // V2Preferred should use V2 if supported, V1 otherwise + // Most current nodes are V1-only, so we typically expect V1 + assert!( + transport_version == 1 || transport_version == 2, + "Transport version should be 1 or 2" + ); + + // Perform application-level handshake to verify transport works + let mut handshake_manager = HandshakeManager::new( + Network::Dash, + MempoolStrategy::BloomFilter, + Some("v2pref_test".parse().unwrap()), + ); + handshake_manager + .perform_handshake(&mut connection) + .await + .expect("Application handshake failed"); + + assert!(connection.is_connected(), "Should be connected after handshake"); + + // Verify peer info is populated + let peer_info = connection.peer_info(); + assert_eq!(peer_info.address, peer_addr); + assert!(peer_info.connected); + + connection.disconnect().await.expect("Failed to disconnect"); + println!("✓ V2Preferred test passed (used V{} transport)", transport_version); + } + Err(e) => { + println!("✗ Connection failed: {}", e); + println!("Note: This test requires a Dash Core node running at 127.0.0.1:9999"); + // Don't fail - node might not be available in CI + } + } +} diff --git a/dash/src/network/message.rs b/dash/src/network/message.rs index e16498693..11c7a3ea1 100644 --- a/dash/src/network/message.rs +++ b/dash/src/network/message.rs @@ -1084,4 +1084,186 @@ mod test { let msg = NetworkMessage::SendDsq(true); assert_eq!(msg.cmd(), "senddsq"); } + + // ========================================================================= + // V2 Transport Payload Encode/Decode Tests + // These tests verify consensus_encode_payload and consensus_decode_payload + // for use with BIP324 V2 encrypted transport. + // ========================================================================= + + /// Helper to test round-trip encoding/decoding for a message + fn test_payload_round_trip(msg: &NetworkMessage) { + let encoded = msg.consensus_encode_payload(); + let decoded = NetworkMessage::consensus_decode_payload(msg.cmd(), &encoded) + .expect(&format!("Failed to decode {} message", msg.cmd())); + assert_eq!(msg, &decoded, "Round-trip failed for {} message", msg.cmd()); + } + + #[test] + #[cfg(feature = "core-block-hash-use-x11")] + fn test_encode_decode_standard_bitcoin_messages() { + // Use deserialized test data to avoid complex construction + let tx: Transaction = deserialize(&hex!("0100000001a15d57094aa7a21a28cb20b59aab8fc7d1149a3bdbcddba9c622e4f5f6a99ece010000006c493046022100f93bb0e7d8db7bd46e40132d1f8242026e045f03a0efe71bbb8e3f475e970d790221009337cd7f1f929f00cc6ff01f03729b069a7c21b59b1736ddfee5db5946c5da8c0121033b9b137ee87d5a812d6f506efdd37f0affa7ffc310711c06c7f3e097c9447c52ffffffff0100e1f505000000001976a9140389035a9225b3839e2bbf32d826a1e222031fd888ac00000000")).unwrap(); + let block: Block = deserialize(&include_bytes!("../../tests/data/testnet_block_000000000000045e0b1660b6445b5e5c5ab63c9a4f956be7e1e69be04fa4497b.raw")[..]).unwrap(); + let header: block::Header = deserialize(&hex!("010000004ddccd549d28f385ab457e98d1b11ce80bfea2c5ab93015ade4973e400000000bf4473e53794beae34e64fccc471dace6ae544180816f89591894e0f417a914cd74d6e49ffff001d323b3a7b")).unwrap(); + + let inv = vec![Inventory::Transaction(hash([3u8; 32]).into())]; + + let messages: Vec = vec![ + NetworkMessage::Ping(0x1234567890abcdef), + NetworkMessage::Pong(0xfedcba0987654321), + NetworkMessage::Inv(inv.clone()), + NetworkMessage::GetData(inv.clone()), + NetworkMessage::NotFound(inv), + NetworkMessage::GetBlocks(GetBlocksMessage { + version: 70015, + locator_hashes: vec![hash_x11([4u8; 32]).into()], + stop_hash: hash_x11([5u8; 32]).into(), + }), + NetworkMessage::GetHeaders(GetHeadersMessage { + version: 70015, + locator_hashes: vec![hash_x11([6u8; 32]).into()], + stop_hash: hash_x11([7u8; 32]).into(), + }), + NetworkMessage::Headers(vec![header]), + NetworkMessage::Tx(tx), + NetworkMessage::Block(block), + NetworkMessage::FilterLoad(FilterLoad { + filter: vec![0x01, 0x02, 0x03], + hash_funcs: 11, + tweak: 0x12345678, + flags: BloomFlags::All, + }), + NetworkMessage::FilterAdd(FilterAdd { + data: vec![0xaa, 0xbb, 0xcc], + }), + NetworkMessage::SendCmpct(SendCmpct { + send_compact: true, + version: 1, + }), + NetworkMessage::GetCFilters(GetCFilters { + filter_type: 0, + start_height: 100, + stop_hash: hash_x11([8u8; 32]).into(), + }), + NetworkMessage::GetCFHeaders(GetCFHeaders { + filter_type: 0, + start_height: 100, + stop_hash: hash_x11([9u8; 32]).into(), + }), + NetworkMessage::GetCFCheckpt(GetCFCheckpt { + filter_type: 0, + stop_hash: hash_x11([10u8; 32]).into(), + }), + ]; + + for msg in &messages { + test_payload_round_trip(msg); + } + } + + #[test] + #[cfg(feature = "core-block-hash-use-x11")] + fn test_encode_decode_dash_specific_messages() { + use crate::bls_sig_utils::BLSSignature; + use crate::hash_types::CycleHash; + use crate::network::message_sml::GetMnListDiff; + use crate::{ChainLock, InstantLock}; + + let messages: Vec = vec![ + NetworkMessage::SendDsq(true), + NetworkMessage::SendDsq(false), + NetworkMessage::GetMnListD(GetMnListDiff { + base_block_hash: hash_x11([1u8; 32]).into(), + block_hash: hash_x11([2u8; 32]).into(), + }), + NetworkMessage::CLSig(ChainLock { + block_height: 123456, + block_hash: hash_x11([3u8; 32]).into(), + signature: BLSSignature::from([0u8; 96]), + }), + NetworkMessage::ISLock(InstantLock { + version: 1, + inputs: vec![], + txid: hash([4u8; 32]).into(), + cyclehash: CycleHash::from([5u8; 32]), + signature: BLSSignature::from([0u8; 96]), + }), + NetworkMessage::GetHeaders2(GetHeadersMessage { + version: 70015, + locator_hashes: vec![hash_x11([6u8; 32]).into()], + stop_hash: hash_x11([7u8; 32]).into(), + }), + NetworkMessage::SendHeaders2, + ]; + + for msg in &messages { + test_payload_round_trip(msg); + } + } + + #[test] + fn test_encode_decode_empty_payload_messages() { + let empty_payload_messages: Vec = vec![ + NetworkMessage::Verack, + NetworkMessage::SendHeaders, + NetworkMessage::SendHeaders2, + NetworkMessage::MemPool, + NetworkMessage::GetAddr, + NetworkMessage::WtxidRelay, + NetworkMessage::FilterClear, + NetworkMessage::SendAddrV2, + ]; + + for msg in &empty_payload_messages { + // Verify encoding produces empty payload + let encoded = msg.consensus_encode_payload(); + assert!( + encoded.is_empty(), + "{} should have empty payload, got {} bytes", + msg.cmd(), + encoded.len() + ); + + // Verify decoding works with empty payload + let decoded = NetworkMessage::consensus_decode_payload(msg.cmd(), &[]) + .expect(&format!("Failed to decode empty {} message", msg.cmd())); + assert_eq!(msg, &decoded, "Empty payload round-trip failed for {}", msg.cmd()); + } + } + + #[test] + #[cfg(feature = "core-block-hash-use-x11")] + fn test_headers_message_special_encoding() { + let header = block::Header { + version: block::Version::from_consensus(1), + prev_blockhash: hash_x11([1u8; 32]).into(), + merkle_root: hash([2u8; 32]).into(), + time: 1234567890, + bits: crate::pow::CompactTarget::from_consensus(0x1d00ffff), + nonce: 42, + }; + + // Test empty headers + let empty_headers = NetworkMessage::Headers(vec![]); + let encoded = empty_headers.consensus_encode_payload(); + assert_eq!(encoded, vec![0x00], "Empty headers should encode to single 0x00 varint"); + test_payload_round_trip(&empty_headers); + + // Test single header + let single_header = NetworkMessage::Headers(vec![header.clone()]); + let encoded = single_header.consensus_encode_payload(); + // Should be: varint(1) + header_bytes + 0x00 (tx count) + // Header is 80 bytes, so total should be 1 + 80 + 1 = 82 bytes + assert_eq!(encoded.len(), 82, "Single header should be 82 bytes"); + assert_eq!(encoded[81], 0x00, "Header should have trailing zero tx count"); + test_payload_round_trip(&single_header); + + // Test multiple headers + let multi_headers = NetworkMessage::Headers(vec![header.clone(), header.clone(), header]); + let encoded = multi_headers.consensus_encode_payload(); + // Should be: varint(3) + 3 * (header + 0x00) = 1 + 3*81 = 244 bytes + assert_eq!(encoded.len(), 244, "Three headers should be 244 bytes"); + test_payload_round_trip(&multi_headers); + } } From 3c0b6e69dd6e1f2bf3e05c77211fc147e0d72f47 Mon Sep 17 00:00:00 2001 From: pasta Date: Fri, 12 Dec 2025 11:35:27 -0600 Subject: [PATCH 8/9] test: Add regression test for Unknown message payload encoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test that exposes a bug in consensus_encode_payload for Unknown message variants - serialize(data) adds a VarInt length prefix but consensus_decode_payload expects raw bytes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash/src/network/message.rs | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/dash/src/network/message.rs b/dash/src/network/message.rs index 11c7a3ea1..0b9d8907b 100644 --- a/dash/src/network/message.rs +++ b/dash/src/network/message.rs @@ -1266,4 +1266,29 @@ mod test { assert_eq!(encoded.len(), 244, "Three headers should be 244 bytes"); test_payload_round_trip(&multi_headers); } + + #[test] + fn test_encode_decode_unknown_message() { + // Create an Unknown message with a custom command and payload + let unknown_msg = NetworkMessage::Unknown { + command: CommandString::try_from_static("custom").unwrap(), + payload: vec![0xaa, 0xbb, 0xcc, 0xdd], + }; + + // Test encoding - should return raw payload bytes without length prefix + let encoded = unknown_msg.consensus_encode_payload(); + assert_eq!( + encoded, + vec![0xaa, 0xbb, 0xcc, 0xdd], + "Unknown message should encode to raw payload bytes without length prefix" + ); + + // Test decoding with the actual command (not "unknown") + // Note: We must use command() not cmd() because cmd() returns "unknown" for Unknown variants + let cmd = unknown_msg.command(); + let decoded = NetworkMessage::consensus_decode_payload(cmd.as_ref(), &encoded) + .expect("Failed to decode unknown message"); + + assert_eq!(unknown_msg, decoded, "Round-trip failed for unknown message"); + } } From 7bac099105204b2325ed37c7b664619f10bb8807 Mon Sep 17 00:00:00 2001 From: pasta Date: Fri, 12 Dec 2025 11:37:01 -0600 Subject: [PATCH 9/9] fix: Return raw bytes for Unknown message payload encoding MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix consensus_encode_payload for Unknown message variants to return the raw payload bytes directly instead of using serialize() which adds a VarInt length prefix. This matches consensus_decode_payload which expects raw bytes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- dash/src/network/message.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/src/network/message.rs b/dash/src/network/message.rs index 0b9d8907b..8c9c7efa5 100644 --- a/dash/src/network/message.rs +++ b/dash/src/network/message.rs @@ -408,7 +408,7 @@ impl NetworkMessage { NetworkMessage::Unknown { payload: ref data, .. - } => serialize(data), + } => data.clone(), NetworkMessage::Verack | NetworkMessage::SendHeaders | NetworkMessage::SendHeaders2