diff --git a/CHANGELOG.md b/CHANGELOG.md index d65cc8cb..50715a0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,17 @@ All notable changes to this project will be documented in this file. +## Unreleased + +### New features + +* **RPKI RTR Protocol Support**: Add full support for the RPKI-to-Router (RTR) protocol + - New `models::rpki::rtr` module with all PDU types: SerialNotify, SerialQuery, ResetQuery, CacheResponse, IPv4Prefix, IPv6Prefix, EndOfData, CacheReset, RouterKey, ErrorReport + - New `parser::rpki::rtr` module with parsing (`parse_rtr_pdu`, `read_rtr_pdu`) and encoding (`RtrEncode` trait) + - Support for both RTR v0 ([RFC 6810](https://datatracker.ietf.org/doc/html/rfc6810)) and v1 ([RFC 8210](https://datatracker.ietf.org/doc/html/rfc8210)) + - Comprehensive error handling with `RtrError` enum + - New example `rtr_client.rs` demonstrating RTR client implementation with ROA fetching and route validation + ## v0.13.0 - 2025-12-07 ### Breaking changes diff --git a/README.md b/README.md index bb12c7cc..90b7aadc 100644 --- a/README.md +++ b/README.md @@ -554,6 +554,65 @@ See the [MrtRecord] documentation for the complete structure definition. - **Message-level analysis**: Work with UPDATE messages or RIB entries as units - **Memory efficiency**: Shared attributes aren't cloned for each prefix +## RPKI RTR Protocol Support + +BGPKIT Parser includes support for the RPKI-to-Router (RTR) protocol, enabling downstream +clients to communicate with RTR cache servers and fetch Route Origin Authorizations (ROAs). + +### Overview + +The RTR protocol is used to deliver validated RPKI data from a cache server to a router. +BGPKIT Parser provides: +- **PDU definitions**: All RTR protocol data structures for both v0 (RFC 6810) and v1 (RFC 8210) +- **Parsing**: Decode binary RTR PDUs into structured Rust types +- **Encoding**: Serialize RTR PDUs to binary format for sending to servers + +**Note**: This library provides PDU parsing/encoding only. Transport (TCP, SSH, TLS) and +RPKI validation logic are out of scope and should be handled by downstream clients. + +### Quick Example + +```rust +use bgpkit_parser::models::rpki::rtr::*; +use bgpkit_parser::parser::rpki::rtr::{parse_rtr_pdu, RtrEncode}; + +// Create a Reset Query to request the full ROA database +let query = RtrResetQuery::new_v1(); +let bytes = query.encode(); + +// Parse a PDU from bytes +let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); +assert!(matches!(pdu, RtrPdu::ResetQuery(_))); +``` + +### Available PDU Types + +| PDU Type | Direction | Description | +|----------|-----------|-------------| +| Serial Notify | Server → Client | Notifies client of new data | +| Serial Query | Client → Server | Requests incremental update | +| Reset Query | Client → Server | Requests full database | +| Cache Response | Server → Client | Begins data transfer | +| IPv4 Prefix | Server → Client | ROA for IPv4 prefix | +| IPv6 Prefix | Server → Client | ROA for IPv6 prefix | +| End of Data | Server → Client | Ends data transfer | +| Cache Reset | Server → Client | Cannot provide incremental update | +| Router Key | Server → Client | BGPsec key (v1 only) | +| Error Report | Bidirectional | Error notification | + +### Building an RTR Client + +See the [`rtr_client` example](https://github.com/bgpkit/bgpkit-parser/blob/main/examples/rtr_client.rs) +for a complete working example that: +1. Connects to an RTR server +2. Sends a Reset Query +3. Collects ROAs +4. Validates a route announcement (1.1.1.0/24 → AS13335) + +```bash +cargo run --example rtr_client -- rtr.rpki.cloudflare.com 8282 +``` + **Supported message types** (via enum variants): - `Bgp4MpUpdate`: BGP UPDATE messages from UPDATES files - `TableDumpV2Entry`: RIB entries from TableDumpV2 RIB dumps @@ -696,6 +755,11 @@ Full support for standard, extended, and large communities: - [RFC 8097](https://datatracker.ietf.org/doc/html/rfc8097): BGP Prefix Origin Validation State Extended Community - [RFC 8092](https://datatracker.ietf.org/doc/html/rfc8092): BGP Large Communities +### RPKI-to-Router (RTR) Protocol + +- [RFC 6810](https://datatracker.ietf.org/doc/html/rfc6810): The Resource Public Key Infrastructure (RPKI) to Router Protocol +- [RFC 8210](https://datatracker.ietf.org/doc/html/rfc8210): The Resource Public Key Infrastructure (RPKI) to Router Protocol, Version 1 + ### Advanced Features **FlowSpec**: diff --git a/examples/README.md b/examples/README.md index d421fe26..b4392b50 100644 --- a/examples/README.md +++ b/examples/README.md @@ -42,5 +42,8 @@ This directory contains runnable examples for bgpkit_parser. They demonstrate ba - [mrt_debug.rs](mrt_debug.rs) — Demonstrate MRT debugging features: debug display for MRT records, raw byte export, and the new `Display` implementation. - [extract_problematic_records.rs](extract_problematic_records.rs) — Find and export MRT records that fail to parse for further analysis with other tools. +## RPKI RTR Protocol +- [rtr_client.rs](rtr_client.rs) — Connect to an RTR server (RFC 6810/8210), fetch ROAs, and validate a route announcement (1.1.1.0/24 -> AS13335). Demonstrates RTR PDU parsing and encoding. + ## Local-only and Misc - [local_only/src/main.rs](local_only/src/main.rs) — Minimal example that reads a local updates.bz2 file; intended for local experimentation (not network fetching). diff --git a/examples/rtr_client.rs b/examples/rtr_client.rs new file mode 100644 index 00000000..9f491d7d --- /dev/null +++ b/examples/rtr_client.rs @@ -0,0 +1,251 @@ +//! Example RTR client that fetches ROAs and validates 1.1.1.0/24 -> AS13335 +//! +//! This example demonstrates how to use the RTR protocol support in bgpkit-parser +//! to build a simple RTR client that: +//! 1. Connects to an RTR server +//! 2. Sends a Reset Query to get the full ROA database +//! 3. Collects IPv4 ROAs +//! 4. Validates a specific route announcement (1.1.1.0/24 -> AS13335) +//! +//! You can start a fully-functional RTR server with the `stayrtr` Docker image: +//! ```bash +//! docker run -it --rm -p 8282:8282 rpki/stayrtr -cache https://rpki.cloudflare.com/rpki.json +//! ``` +//! +//! Usage: +//! cargo run --example rtr_client -- +//! +//! Example: +//! cargo run --example rtr_client -- localhost 8282 +//! +//! Note: This is a simple example for demonstration purposes. A production +//! RTR client would need proper error handling, reconnection logic, and +//! session management. + +use bgpkit_parser::models::rpki::rtr::*; +use bgpkit_parser::parser::rpki::rtr::{read_rtr_pdu, RtrEncode, RtrError}; +use std::io::Write; +use std::net::{Ipv4Addr, TcpStream}; + +/// Simple ROA entry for validation +#[derive(Debug, Clone)] +struct RoaEntry { + prefix: Ipv4Addr, + prefix_len: u8, + max_len: u8, + asn: u32, +} + +/// Validation result per RFC 6811 +#[derive(Debug, PartialEq)] +enum ValidationState { + /// At least one VRP matches the route announcement + Valid, + /// At least one VRP covers the prefix, but none match the AS + Invalid, + /// No VRP covers the prefix + NotFound, +} + +fn main() -> Result<(), Box> { + let args: Vec = std::env::args().collect(); + if args.len() != 3 { + eprintln!("Usage: {} ", args[0]); + eprintln!(); + eprintln!("Example:"); + eprintln!(" {} rtr.rpki.cloudflare.com 8282", args[0]); + std::process::exit(1); + } + + let host = &args[1]; + let port: u16 = args[2].parse()?; + + // Connect to RTR server + println!("Connecting to {}:{}...", host, port); + let mut stream = TcpStream::connect((host.as_str(), port))?; + stream.set_read_timeout(Some(std::time::Duration::from_secs(60)))?; + + // Send Reset Query to get full database (start with v1) + let reset_query = RtrResetQuery::new_v1(); + stream.write_all(&reset_query.encode())?; + println!("Sent Reset Query (v1)"); + + // Collect ROAs + let mut ipv4_roas: Vec = Vec::new(); + let mut ipv6_count = 0usize; + let mut session_id: Option = None; + let mut serial: Option = None; + + // Read PDUs until End of Data + loop { + match read_rtr_pdu(&mut stream) { + Ok(pdu) => match pdu { + RtrPdu::CacheResponse(resp) => { + println!("Cache Response: session_id={}", resp.session_id); + session_id = Some(resp.session_id); + } + + RtrPdu::IPv4Prefix(p) => { + if p.is_announcement() { + ipv4_roas.push(RoaEntry { + prefix: p.prefix, + prefix_len: p.prefix_length, + max_len: p.max_length, + asn: p.asn.into(), + }); + } + } + + RtrPdu::IPv6Prefix(p) => { + if p.is_announcement() { + ipv6_count += 1; + } + } + + RtrPdu::RouterKey(_) => { + // BGPsec router keys - skip for this example + } + + RtrPdu::EndOfData(eod) => { + serial = Some(eod.serial_number); + println!("End of Data: serial={}", eod.serial_number); + if let (Some(refresh), Some(retry), Some(expire)) = ( + eod.refresh_interval, + eod.retry_interval, + eod.expire_interval, + ) { + println!( + " Timing: refresh={}s, retry={}s, expire={}s", + refresh, retry, expire + ); + } + break; + } + + RtrPdu::CacheReset(_) => { + println!("Received Cache Reset - server has no data"); + break; + } + + RtrPdu::ErrorReport(err) => { + eprintln!("Server error: {:?} - {}", err.error_code, err.error_text); + // Try downgrade to v0 if version not supported + if err.error_code == RtrErrorCode::UnsupportedProtocolVersion { + println!("Retrying with v0..."); + let reset_v0 = RtrResetQuery::new_v0(); + stream.write_all(&reset_v0.encode())?; + continue; + } + break; + } + + other => { + println!("Unexpected PDU: {:?}", other); + } + }, + Err(RtrError::IoError(ref e)) if e.kind() == std::io::ErrorKind::UnexpectedEof => { + println!("Connection closed"); + break; + } + Err(e) => { + eprintln!("Error reading PDU: {:?}", e); + break; + } + } + } + + println!(); + println!("Session Summary:"); + println!(" Session ID: {:?}", session_id); + println!(" Serial: {:?}", serial); + println!(" IPv4 ROAs: {}", ipv4_roas.len()); + println!(" IPv6 ROAs: {}", ipv6_count); + + // Validate 1.1.1.0/24 -> AS13335 (Cloudflare) + let test_prefix = Ipv4Addr::new(1, 1, 1, 0); + let test_prefix_len = 24u8; + let test_asn = 13335u32; + + let result = validate_route(&ipv4_roas, test_prefix, test_prefix_len, test_asn); + + println!(); + println!( + "Route Validation: {}/{} -> AS{}", + test_prefix, test_prefix_len, test_asn + ); + println!(" Result: {:?}", result); + + // Show matching/covering ROAs + let covering: Vec<_> = ipv4_roas + .iter() + .filter(|roa| covers(roa, test_prefix, test_prefix_len)) + .collect(); + + if !covering.is_empty() { + println!(); + println!("Covering ROAs:"); + for roa in covering { + let status = if test_prefix_len <= roa.max_len && test_asn == roa.asn { + "VALID" + } else { + "covers but doesn't match" + }; + println!( + " {}/{}-{} -> AS{} [{}]", + roa.prefix, roa.prefix_len, roa.max_len, roa.asn, status + ); + } + } + + Ok(()) +} + +/// Check if a ROA covers a given prefix +fn covers(roa: &RoaEntry, prefix: Ipv4Addr, prefix_len: u8) -> bool { + // The announced prefix must be at least as specific as the ROA prefix + if prefix_len < roa.prefix_len { + return false; + } + + // Check if the ROA prefix is a prefix of the announced prefix + let roa_bits: u32 = roa.prefix.into(); + let prefix_bits: u32 = prefix.into(); + let mask = if roa.prefix_len == 0 { + 0 + } else { + !0u32 << (32 - roa.prefix_len) + }; + + (roa_bits & mask) == (prefix_bits & mask) +} + +/// Validate a route announcement per RFC 6811 +fn validate_route( + roas: &[RoaEntry], + prefix: Ipv4Addr, + prefix_len: u8, + asn: u32, +) -> ValidationState { + let mut found_covering = false; + + for roa in roas { + if !covers(roa, prefix, prefix_len) { + continue; + } + + found_covering = true; + + // Check if this ROA validates the announcement + // The announced prefix length must be <= max_length + // The origin AS must match + if prefix_len <= roa.max_len && asn == roa.asn { + return ValidationState::Valid; + } + } + + if found_covering { + ValidationState::Invalid + } else { + ValidationState::NotFound + } +} diff --git a/src/lib.rs b/src/lib.rs index 4164a31c..4753265d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -550,6 +550,65 @@ See the [MrtRecord] documentation for the complete structure definition. - **Message-level analysis**: Work with UPDATE messages or RIB entries as units - **Memory efficiency**: Shared attributes aren't cloned for each prefix +# RPKI RTR Protocol Support + +BGPKIT Parser includes support for the RPKI-to-Router (RTR) protocol, enabling downstream +clients to communicate with RTR cache servers and fetch Route Origin Authorizations (ROAs). + +## Overview + +The RTR protocol is used to deliver validated RPKI data from a cache server to a router. +BGPKIT Parser provides: +- **PDU definitions**: All RTR protocol data structures for both v0 (RFC 6810) and v1 (RFC 8210) +- **Parsing**: Decode binary RTR PDUs into structured Rust types +- **Encoding**: Serialize RTR PDUs to binary format for sending to servers + +**Note**: This library provides PDU parsing/encoding only. Transport (TCP, SSH, TLS) and +RPKI validation logic are out of scope and should be handled by downstream clients. + +## Quick Example + +```rust +use bgpkit_parser::models::rpki::rtr::*; +use bgpkit_parser::parser::rpki::rtr::{parse_rtr_pdu, RtrEncode}; + +// Create a Reset Query to request the full ROA database +let query = RtrResetQuery::new_v1(); +let bytes = query.encode(); + +// Parse a PDU from bytes +let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); +assert!(matches!(pdu, RtrPdu::ResetQuery(_))); +``` + +## Available PDU Types + +| PDU Type | Direction | Description | +|----------|-----------|-------------| +| Serial Notify | Server → Client | Notifies client of new data | +| Serial Query | Client → Server | Requests incremental update | +| Reset Query | Client → Server | Requests full database | +| Cache Response | Server → Client | Begins data transfer | +| IPv4 Prefix | Server → Client | ROA for IPv4 prefix | +| IPv6 Prefix | Server → Client | ROA for IPv6 prefix | +| End of Data | Server → Client | Ends data transfer | +| Cache Reset | Server → Client | Cannot provide incremental update | +| Router Key | Server → Client | BGPsec key (v1 only) | +| Error Report | Bidirectional | Error notification | + +## Building an RTR Client + +See the [`rtr_client` example](https://github.com/bgpkit/bgpkit-parser/blob/main/examples/rtr_client.rs) +for a complete working example that: +1. Connects to an RTR server +2. Sends a Reset Query +3. Collects ROAs +4. Validates a route announcement (1.1.1.0/24 → AS13335) + +```bash +cargo run --example rtr_client -- rtr.rpki.cloudflare.com 8282 +``` + **Supported message types** (via enum variants): - `Bgp4MpUpdate`: BGP UPDATE messages from UPDATES files - `TableDumpV2Entry`: RIB entries from TableDumpV2 RIB dumps @@ -692,6 +751,11 @@ Full support for standard, extended, and large communities: - [RFC 8097](https://datatracker.ietf.org/doc/html/rfc8097): BGP Prefix Origin Validation State Extended Community - [RFC 8092](https://datatracker.ietf.org/doc/html/rfc8092): BGP Large Communities +## RPKI-to-Router (RTR) Protocol + +- [RFC 6810](https://datatracker.ietf.org/doc/html/rfc6810): The Resource Public Key Infrastructure (RPKI) to Router Protocol +- [RFC 8210](https://datatracker.ietf.org/doc/html/rfc8210): The Resource Public Key Infrastructure (RPKI) to Router Protocol, Version 1 + ## Advanced Features **FlowSpec**: diff --git a/src/models/mod.rs b/src/models/mod.rs index 628fc99b..7d26fd4a 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -55,6 +55,7 @@ mod bgp; mod err; mod mrt; mod network; +pub mod rpki; pub use bgp::*; pub use err::BgpModelsError; diff --git a/src/models/rpki/mod.rs b/src/models/rpki/mod.rs new file mode 100644 index 00000000..d5f66a27 --- /dev/null +++ b/src/models/rpki/mod.rs @@ -0,0 +1,9 @@ +//! RPKI (Resource Public Key Infrastructure) related data structures. +//! +//! This module provides data structures for RPKI-related protocols: +//! +//! - [`rtr`]: RPKI-to-Router (RTR) Protocol PDU definitions (RFC 6810, RFC 8210) + +pub mod rtr; + +pub use rtr::*; diff --git a/src/models/rpki/rtr.rs b/src/models/rpki/rtr.rs new file mode 100644 index 00000000..b122036c --- /dev/null +++ b/src/models/rpki/rtr.rs @@ -0,0 +1,1234 @@ +//! RPKI-to-Router (RTR) Protocol Data Structures +//! +//! This module defines the data structures for the RTR protocol as specified in: +//! - RTR v0: [RFC 6810](https://www.rfc-editor.org/rfc/rfc6810.txt) +//! - RTR v1: [RFC 8210](https://www.rfc-editor.org/rfc/rfc8210.txt) +//! +//! The RTR protocol is used to deliver validated RPKI data from a cache server +//! to a router. This module provides PDU definitions that can be used by +//! downstream clients to implement RTR protocol communication. +//! +//! # Example +//! +//! ```rust +//! use bgpkit_parser::models::rpki::rtr::*; +//! +//! // Create a reset query to request the full database +//! let query = RtrResetQuery { +//! version: RtrProtocolVersion::V1, +//! }; +//! +//! // Check if a prefix PDU is an announcement or withdrawal +//! let prefix = RtrIPv4Prefix { +//! version: RtrProtocolVersion::V1, +//! flags: 1, // announcement +//! prefix_length: 24, +//! max_length: 24, +//! prefix: std::net::Ipv4Addr::new(192, 0, 2, 0), +//! asn: 65001.into(), +//! }; +//! assert!(prefix.is_announcement()); +//! ``` + +use crate::models::Asn; +use std::net::{Ipv4Addr, Ipv6Addr}; + +// ============================================================================= +// Core Enums +// ============================================================================= + +/// RTR Protocol Version +/// +/// The RTR protocol has two versions: +/// - V0 (RFC 6810): Original protocol specification +/// - V1 (RFC 8210): Adds Router Key PDU and timing parameters in End of Data +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[repr(u8)] +pub enum RtrProtocolVersion { + /// RTR Protocol Version 0 (RFC 6810) + V0 = 0, + /// RTR Protocol Version 1 (RFC 8210) + #[default] + V1 = 1, +} + +impl RtrProtocolVersion { + /// Create from a raw byte value + pub fn from_u8(value: u8) -> Option { + match value { + 0 => Some(RtrProtocolVersion::V0), + 1 => Some(RtrProtocolVersion::V1), + _ => None, + } + } + + /// Convert to raw byte value + pub fn to_u8(self) -> u8 { + self as u8 + } +} + +impl From for u8 { + fn from(v: RtrProtocolVersion) -> Self { + v as u8 + } +} + +impl TryFrom for RtrProtocolVersion { + type Error = u8; + + fn try_from(value: u8) -> Result { + RtrProtocolVersion::from_u8(value).ok_or(value) + } +} + +/// RTR PDU Type +/// +/// Identifies the type of RTR Protocol Data Unit. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[repr(u8)] +pub enum RtrPduType { + /// Serial Notify - Server notifies client of new data available + SerialNotify = 0, + /// Serial Query - Client requests incremental update + SerialQuery = 1, + /// Reset Query - Client requests full database + ResetQuery = 2, + /// Cache Response - Server begins sending data + CacheResponse = 3, + /// IPv4 Prefix - ROA for IPv4 + IPv4Prefix = 4, + /// IPv6 Prefix - ROA for IPv6 + IPv6Prefix = 6, + /// End of Data - Server finished sending data + EndOfData = 7, + /// Cache Reset - Server cannot provide incremental update + CacheReset = 8, + /// Router Key - BGPsec router key (v1 only) + RouterKey = 9, + /// Error Report - Error notification + ErrorReport = 10, +} + +impl RtrPduType { + /// Create from a raw byte value + pub fn from_u8(value: u8) -> Option { + match value { + 0 => Some(RtrPduType::SerialNotify), + 1 => Some(RtrPduType::SerialQuery), + 2 => Some(RtrPduType::ResetQuery), + 3 => Some(RtrPduType::CacheResponse), + 4 => Some(RtrPduType::IPv4Prefix), + 6 => Some(RtrPduType::IPv6Prefix), + 7 => Some(RtrPduType::EndOfData), + 8 => Some(RtrPduType::CacheReset), + 9 => Some(RtrPduType::RouterKey), + 10 => Some(RtrPduType::ErrorReport), + _ => None, + } + } + + /// Convert to raw byte value + pub fn to_u8(self) -> u8 { + self as u8 + } +} + +impl From for u8 { + fn from(v: RtrPduType) -> Self { + v as u8 + } +} + +impl TryFrom for RtrPduType { + type Error = u8; + + fn try_from(value: u8) -> Result { + RtrPduType::from_u8(value).ok_or(value) + } +} + +/// RTR Error Code +/// +/// Error codes used in Error Report PDUs. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[repr(u16)] +pub enum RtrErrorCode { + /// Corrupt Data - PDU could not be parsed + CorruptData = 0, + /// Internal Error - Cache experienced an internal error + InternalError = 1, + /// No Data Available - Cache has no data yet + NoDataAvailable = 2, + /// Invalid Request - Request was invalid + InvalidRequest = 3, + /// Unsupported Protocol Version - Protocol version not supported + UnsupportedProtocolVersion = 4, + /// Unsupported PDU Type - PDU type not supported + UnsupportedPduType = 5, + /// Withdrawal of Unknown Record - Tried to withdraw non-existent record + WithdrawalOfUnknownRecord = 6, + /// Duplicate Announcement Received - Same record announced twice + DuplicateAnnouncementReceived = 7, + /// Unexpected Protocol Version - Version mismatch mid-session (v1 only) + UnexpectedProtocolVersion = 8, +} + +impl RtrErrorCode { + /// Create from a raw u16 value + pub fn from_u16(value: u16) -> Option { + match value { + 0 => Some(RtrErrorCode::CorruptData), + 1 => Some(RtrErrorCode::InternalError), + 2 => Some(RtrErrorCode::NoDataAvailable), + 3 => Some(RtrErrorCode::InvalidRequest), + 4 => Some(RtrErrorCode::UnsupportedProtocolVersion), + 5 => Some(RtrErrorCode::UnsupportedPduType), + 6 => Some(RtrErrorCode::WithdrawalOfUnknownRecord), + 7 => Some(RtrErrorCode::DuplicateAnnouncementReceived), + 8 => Some(RtrErrorCode::UnexpectedProtocolVersion), + _ => None, + } + } + + /// Convert to raw u16 value + pub fn to_u16(self) -> u16 { + self as u16 + } +} + +impl From for u16 { + fn from(v: RtrErrorCode) -> Self { + v as u16 + } +} + +impl TryFrom for RtrErrorCode { + type Error = u16; + + fn try_from(value: u16) -> Result { + RtrErrorCode::from_u16(value).ok_or(value) + } +} + +// ============================================================================= +// PDU Structs +// ============================================================================= + +/// Serial Notify PDU (Type 0) +/// +/// Sent by server to notify client that new data is available. +/// This is a hint that the client should send a Serial Query. +/// +/// Direction: Server → Client +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrSerialNotify { + /// Protocol version + pub version: RtrProtocolVersion, + /// Session identifier + pub session_id: u16, + /// Current serial number + pub serial_number: u32, +} + +/// Serial Query PDU (Type 1) +/// +/// Sent by client to request incremental update from a known serial number. +/// +/// Direction: Client → Server +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrSerialQuery { + /// Protocol version + pub version: RtrProtocolVersion, + /// Session identifier from previous session + pub session_id: u16, + /// Last known serial number + pub serial_number: u32, +} + +impl RtrSerialQuery { + /// Create a new Serial Query PDU + pub fn new(version: RtrProtocolVersion, session_id: u16, serial_number: u32) -> Self { + Self { + version, + session_id, + serial_number, + } + } +} + +/// Reset Query PDU (Type 2) +/// +/// Sent by client to request the full database from the server. +/// +/// Direction: Client → Server +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrResetQuery { + /// Protocol version + pub version: RtrProtocolVersion, +} + +impl RtrResetQuery { + /// Create a new Reset Query PDU with the specified version + pub fn new(version: RtrProtocolVersion) -> Self { + Self { version } + } + + /// Create a new Reset Query PDU with version 1 + pub fn new_v1() -> Self { + Self::new(RtrProtocolVersion::V1) + } + + /// Create a new Reset Query PDU with version 0 + pub fn new_v0() -> Self { + Self::new(RtrProtocolVersion::V0) + } +} + +impl Default for RtrResetQuery { + fn default() -> Self { + Self::new_v1() + } +} + +/// Cache Response PDU (Type 3) +/// +/// Sent by server to indicate the start of a data transfer. +/// Followed by zero or more IPv4/IPv6 Prefix and Router Key PDUs, +/// and terminated by an End of Data PDU. +/// +/// Direction: Server → Client +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrCacheResponse { + /// Protocol version + pub version: RtrProtocolVersion, + /// Session identifier + pub session_id: u16, +} + +/// IPv4 Prefix PDU (Type 4) +/// +/// Contains a single ROA for an IPv4 prefix. +/// +/// Direction: Server → Client +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrIPv4Prefix { + /// Protocol version + pub version: RtrProtocolVersion, + /// Flags (bit 0: 1=announcement, 0=withdrawal) + pub flags: u8, + /// Prefix length in bits + pub prefix_length: u8, + /// Maximum prefix length for this ROA + pub max_length: u8, + /// IPv4 prefix + pub prefix: Ipv4Addr, + /// Origin AS number + pub asn: Asn, +} + +impl RtrIPv4Prefix { + /// Check if this is an announcement (not a withdrawal) + #[inline] + pub fn is_announcement(&self) -> bool { + self.flags & 0x01 != 0 + } + + /// Check if this is a withdrawal + #[inline] + pub fn is_withdrawal(&self) -> bool { + !self.is_announcement() + } +} + +/// IPv6 Prefix PDU (Type 6) +/// +/// Contains a single ROA for an IPv6 prefix. +/// +/// Direction: Server → Client +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrIPv6Prefix { + /// Protocol version + pub version: RtrProtocolVersion, + /// Flags (bit 0: 1=announcement, 0=withdrawal) + pub flags: u8, + /// Prefix length in bits + pub prefix_length: u8, + /// Maximum prefix length for this ROA + pub max_length: u8, + /// IPv6 prefix + pub prefix: Ipv6Addr, + /// Origin AS number + pub asn: Asn, +} + +impl RtrIPv6Prefix { + /// Check if this is an announcement (not a withdrawal) + #[inline] + pub fn is_announcement(&self) -> bool { + self.flags & 0x01 != 0 + } + + /// Check if this is a withdrawal + #[inline] + pub fn is_withdrawal(&self) -> bool { + !self.is_announcement() + } +} + +/// End of Data PDU (Type 7) +/// +/// Sent by server to indicate the end of a data transfer. +/// +/// Direction: Server → Client +/// +/// Note: In v1, this PDU includes timing parameters. In v0, these are absent. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrEndOfData { + /// Protocol version + pub version: RtrProtocolVersion, + /// Session identifier + pub session_id: u16, + /// Current serial number + pub serial_number: u32, + /// Refresh interval in seconds (v1 only) + pub refresh_interval: Option, + /// Retry interval in seconds (v1 only) + pub retry_interval: Option, + /// Expire interval in seconds (v1 only) + pub expire_interval: Option, +} + +impl RtrEndOfData { + /// Default refresh interval (1 hour) as recommended by RFC 8210 + pub const DEFAULT_REFRESH: u32 = 3600; + /// Default retry interval (10 minutes) as recommended by RFC 8210 + pub const DEFAULT_RETRY: u32 = 600; + /// Default expire interval (2 hours) as recommended by RFC 8210 + pub const DEFAULT_EXPIRE: u32 = 7200; + + /// Get the refresh interval, using the default if not specified + pub fn refresh_interval_or_default(&self) -> u32 { + self.refresh_interval.unwrap_or(Self::DEFAULT_REFRESH) + } + + /// Get the retry interval, using the default if not specified + pub fn retry_interval_or_default(&self) -> u32 { + self.retry_interval.unwrap_or(Self::DEFAULT_RETRY) + } + + /// Get the expire interval, using the default if not specified + pub fn expire_interval_or_default(&self) -> u32 { + self.expire_interval.unwrap_or(Self::DEFAULT_EXPIRE) + } +} + +/// Cache Reset PDU (Type 8) +/// +/// Sent by server in response to a Serial Query when the server +/// cannot provide an incremental update (e.g., serial is too old). +/// The client should send a Reset Query. +/// +/// Direction: Server → Client +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrCacheReset { + /// Protocol version + pub version: RtrProtocolVersion, +} + +/// Router Key PDU (Type 9, v1 only) +/// +/// Contains a BGPsec router key. +/// +/// Direction: Server → Client +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrRouterKey { + /// Protocol version (always V1) + pub version: RtrProtocolVersion, + /// Flags (bit 0: 1=announcement, 0=withdrawal) + pub flags: u8, + /// Subject Key Identifier (SKI) - 20 bytes + pub subject_key_identifier: [u8; 20], + /// AS number + pub asn: Asn, + /// Subject Public Key Info (SPKI) - variable length + pub subject_public_key_info: Vec, +} + +impl RtrRouterKey { + /// Check if this is an announcement (not a withdrawal) + #[inline] + pub fn is_announcement(&self) -> bool { + self.flags & 0x01 != 0 + } + + /// Check if this is a withdrawal + #[inline] + pub fn is_withdrawal(&self) -> bool { + !self.is_announcement() + } +} + +/// Error Report PDU (Type 10) +/// +/// Sent by either client or server to report an error. +/// +/// Direction: Bidirectional (Client ↔ Server) +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct RtrErrorReport { + /// Protocol version + pub version: RtrProtocolVersion, + /// Error code + pub error_code: RtrErrorCode, + /// The erroneous PDU that caused the error (may be empty) + pub erroneous_pdu: Vec, + /// Human-readable error text (UTF-8) + pub error_text: String, +} + +impl RtrErrorReport { + /// Create a new Error Report PDU + pub fn new( + version: RtrProtocolVersion, + error_code: RtrErrorCode, + erroneous_pdu: Vec, + error_text: String, + ) -> Self { + Self { + version, + error_code, + erroneous_pdu, + error_text, + } + } + + /// Create an error report for unsupported protocol version + pub fn unsupported_version(version: RtrProtocolVersion, erroneous_pdu: Vec) -> Self { + Self::new( + version, + RtrErrorCode::UnsupportedProtocolVersion, + erroneous_pdu, + "Unsupported protocol version".to_string(), + ) + } + + /// Create an error report for unsupported PDU type + pub fn unsupported_pdu_type(version: RtrProtocolVersion, erroneous_pdu: Vec) -> Self { + Self::new( + version, + RtrErrorCode::UnsupportedPduType, + erroneous_pdu, + "Unsupported PDU type".to_string(), + ) + } + + /// Create an error report for corrupt data + pub fn corrupt_data( + version: RtrProtocolVersion, + erroneous_pdu: Vec, + message: &str, + ) -> Self { + Self::new( + version, + RtrErrorCode::CorruptData, + erroneous_pdu, + message.to_string(), + ) + } +} + +// ============================================================================= +// Unified PDU Enum +// ============================================================================= + +/// Unified RTR PDU Enum +/// +/// This enum represents any RTR Protocol Data Unit and is useful for +/// generic PDU handling when reading from a stream. +#[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub enum RtrPdu { + /// Serial Notify (Type 0) + SerialNotify(RtrSerialNotify), + /// Serial Query (Type 1) + SerialQuery(RtrSerialQuery), + /// Reset Query (Type 2) + ResetQuery(RtrResetQuery), + /// Cache Response (Type 3) + CacheResponse(RtrCacheResponse), + /// IPv4 Prefix (Type 4) + IPv4Prefix(RtrIPv4Prefix), + /// IPv6 Prefix (Type 6) + IPv6Prefix(RtrIPv6Prefix), + /// End of Data (Type 7) + EndOfData(RtrEndOfData), + /// Cache Reset (Type 8) + CacheReset(RtrCacheReset), + /// Router Key (Type 9, v1 only) + RouterKey(RtrRouterKey), + /// Error Report (Type 10) + ErrorReport(RtrErrorReport), +} + +impl RtrPdu { + /// Get the PDU type + pub fn pdu_type(&self) -> RtrPduType { + match self { + RtrPdu::SerialNotify(_) => RtrPduType::SerialNotify, + RtrPdu::SerialQuery(_) => RtrPduType::SerialQuery, + RtrPdu::ResetQuery(_) => RtrPduType::ResetQuery, + RtrPdu::CacheResponse(_) => RtrPduType::CacheResponse, + RtrPdu::IPv4Prefix(_) => RtrPduType::IPv4Prefix, + RtrPdu::IPv6Prefix(_) => RtrPduType::IPv6Prefix, + RtrPdu::EndOfData(_) => RtrPduType::EndOfData, + RtrPdu::CacheReset(_) => RtrPduType::CacheReset, + RtrPdu::RouterKey(_) => RtrPduType::RouterKey, + RtrPdu::ErrorReport(_) => RtrPduType::ErrorReport, + } + } + + /// Get the protocol version + pub fn version(&self) -> RtrProtocolVersion { + match self { + RtrPdu::SerialNotify(p) => p.version, + RtrPdu::SerialQuery(p) => p.version, + RtrPdu::ResetQuery(p) => p.version, + RtrPdu::CacheResponse(p) => p.version, + RtrPdu::IPv4Prefix(p) => p.version, + RtrPdu::IPv6Prefix(p) => p.version, + RtrPdu::EndOfData(p) => p.version, + RtrPdu::CacheReset(p) => p.version, + RtrPdu::RouterKey(p) => p.version, + RtrPdu::ErrorReport(p) => p.version, + } + } +} + +impl From for RtrPdu { + fn from(pdu: RtrSerialNotify) -> Self { + RtrPdu::SerialNotify(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrSerialQuery) -> Self { + RtrPdu::SerialQuery(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrResetQuery) -> Self { + RtrPdu::ResetQuery(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrCacheResponse) -> Self { + RtrPdu::CacheResponse(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrIPv4Prefix) -> Self { + RtrPdu::IPv4Prefix(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrIPv6Prefix) -> Self { + RtrPdu::IPv6Prefix(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrEndOfData) -> Self { + RtrPdu::EndOfData(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrCacheReset) -> Self { + RtrPdu::CacheReset(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrRouterKey) -> Self { + RtrPdu::RouterKey(pdu) + } +} + +impl From for RtrPdu { + fn from(pdu: RtrErrorReport) -> Self { + RtrPdu::ErrorReport(pdu) + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_protocol_version_conversion() { + assert_eq!(RtrProtocolVersion::from_u8(0), Some(RtrProtocolVersion::V0)); + assert_eq!(RtrProtocolVersion::from_u8(1), Some(RtrProtocolVersion::V1)); + assert_eq!(RtrProtocolVersion::from_u8(2), None); + + assert_eq!(RtrProtocolVersion::V0.to_u8(), 0); + assert_eq!(RtrProtocolVersion::V1.to_u8(), 1); + } + + #[test] + fn test_pdu_type_conversion() { + assert_eq!(RtrPduType::from_u8(0), Some(RtrPduType::SerialNotify)); + assert_eq!(RtrPduType::from_u8(1), Some(RtrPduType::SerialQuery)); + assert_eq!(RtrPduType::from_u8(4), Some(RtrPduType::IPv4Prefix)); + assert_eq!(RtrPduType::from_u8(5), None); // No type 5 + assert_eq!(RtrPduType::from_u8(6), Some(RtrPduType::IPv6Prefix)); + assert_eq!(RtrPduType::from_u8(9), Some(RtrPduType::RouterKey)); + + assert_eq!(RtrPduType::SerialNotify.to_u8(), 0); + assert_eq!(RtrPduType::IPv6Prefix.to_u8(), 6); + } + + #[test] + fn test_error_code_conversion() { + assert_eq!(RtrErrorCode::from_u16(0), Some(RtrErrorCode::CorruptData)); + assert_eq!( + RtrErrorCode::from_u16(4), + Some(RtrErrorCode::UnsupportedProtocolVersion) + ); + assert_eq!( + RtrErrorCode::from_u16(8), + Some(RtrErrorCode::UnexpectedProtocolVersion) + ); + assert_eq!(RtrErrorCode::from_u16(9), None); + + assert_eq!(RtrErrorCode::CorruptData.to_u16(), 0); + assert_eq!(RtrErrorCode::UnsupportedProtocolVersion.to_u16(), 4); + } + + #[test] + fn test_ipv4_prefix_flags() { + let announcement = RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(192, 0, 2, 0), + asn: 65001.into(), + }; + assert!(announcement.is_announcement()); + assert!(!announcement.is_withdrawal()); + + let withdrawal = RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 0, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(192, 0, 2, 0), + asn: 65001.into(), + }; + assert!(!withdrawal.is_announcement()); + assert!(withdrawal.is_withdrawal()); + } + + #[test] + fn test_ipv6_prefix_flags() { + let announcement = RtrIPv6Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 48, + max_length: 48, + prefix: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + asn: 65001.into(), + }; + assert!(announcement.is_announcement()); + assert!(!announcement.is_withdrawal()); + } + + #[test] + fn test_router_key_flags() { + let announcement = RtrRouterKey { + version: RtrProtocolVersion::V1, + flags: 1, + subject_key_identifier: [0; 20], + asn: 65001.into(), + subject_public_key_info: vec![1, 2, 3], + }; + assert!(announcement.is_announcement()); + assert!(!announcement.is_withdrawal()); + } + + #[test] + fn test_end_of_data_defaults() { + assert_eq!(RtrEndOfData::DEFAULT_REFRESH, 3600); + assert_eq!(RtrEndOfData::DEFAULT_RETRY, 600); + assert_eq!(RtrEndOfData::DEFAULT_EXPIRE, 7200); + + let eod = RtrEndOfData { + version: RtrProtocolVersion::V0, + session_id: 1, + serial_number: 100, + refresh_interval: None, + retry_interval: None, + expire_interval: None, + }; + assert_eq!(eod.refresh_interval_or_default(), 3600); + assert_eq!(eod.retry_interval_or_default(), 600); + assert_eq!(eod.expire_interval_or_default(), 7200); + + let eod_v1 = RtrEndOfData { + version: RtrProtocolVersion::V1, + session_id: 1, + serial_number: 100, + refresh_interval: Some(1800), + retry_interval: Some(300), + expire_interval: Some(3600), + }; + assert_eq!(eod_v1.refresh_interval_or_default(), 1800); + assert_eq!(eod_v1.retry_interval_or_default(), 300); + assert_eq!(eod_v1.expire_interval_or_default(), 3600); + } + + #[test] + fn test_reset_query_constructors() { + let v1 = RtrResetQuery::new_v1(); + assert_eq!(v1.version, RtrProtocolVersion::V1); + + let v0 = RtrResetQuery::new_v0(); + assert_eq!(v0.version, RtrProtocolVersion::V0); + + let default = RtrResetQuery::default(); + assert_eq!(default.version, RtrProtocolVersion::V1); + } + + #[test] + fn test_pdu_enum_type() { + let pdu = RtrPdu::SerialNotify(RtrSerialNotify { + version: RtrProtocolVersion::V1, + session_id: 1, + serial_number: 100, + }); + assert_eq!(pdu.pdu_type(), RtrPduType::SerialNotify); + assert_eq!(pdu.version(), RtrProtocolVersion::V1); + } + + #[test] + fn test_pdu_enum_all_types_and_versions() { + // Test pdu_type() and version() for all PDU variants + let pdus = vec![ + ( + RtrPdu::SerialNotify(RtrSerialNotify { + version: RtrProtocolVersion::V0, + session_id: 1, + serial_number: 100, + }), + RtrPduType::SerialNotify, + RtrProtocolVersion::V0, + ), + ( + RtrPdu::SerialQuery(RtrSerialQuery { + version: RtrProtocolVersion::V1, + session_id: 2, + serial_number: 200, + }), + RtrPduType::SerialQuery, + RtrProtocolVersion::V1, + ), + ( + RtrPdu::ResetQuery(RtrResetQuery { + version: RtrProtocolVersion::V0, + }), + RtrPduType::ResetQuery, + RtrProtocolVersion::V0, + ), + ( + RtrPdu::CacheResponse(RtrCacheResponse { + version: RtrProtocolVersion::V1, + session_id: 3, + }), + RtrPduType::CacheResponse, + RtrProtocolVersion::V1, + ), + ( + RtrPdu::IPv4Prefix(RtrIPv4Prefix { + version: RtrProtocolVersion::V0, + flags: 1, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(10, 0, 0, 0), + asn: 65000.into(), + }), + RtrPduType::IPv4Prefix, + RtrProtocolVersion::V0, + ), + ( + RtrPdu::IPv6Prefix(RtrIPv6Prefix { + version: RtrProtocolVersion::V1, + flags: 0, + prefix_length: 48, + max_length: 64, + prefix: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + asn: 65001.into(), + }), + RtrPduType::IPv6Prefix, + RtrProtocolVersion::V1, + ), + ( + RtrPdu::EndOfData(RtrEndOfData { + version: RtrProtocolVersion::V0, + session_id: 4, + serial_number: 300, + refresh_interval: None, + retry_interval: None, + expire_interval: None, + }), + RtrPduType::EndOfData, + RtrProtocolVersion::V0, + ), + ( + RtrPdu::CacheReset(RtrCacheReset { + version: RtrProtocolVersion::V1, + }), + RtrPduType::CacheReset, + RtrProtocolVersion::V1, + ), + ( + RtrPdu::RouterKey(RtrRouterKey { + version: RtrProtocolVersion::V1, + flags: 1, + subject_key_identifier: [0; 20], + asn: 65002.into(), + subject_public_key_info: vec![], + }), + RtrPduType::RouterKey, + RtrProtocolVersion::V1, + ), + ( + RtrPdu::ErrorReport(RtrErrorReport { + version: RtrProtocolVersion::V0, + error_code: RtrErrorCode::InternalError, + erroneous_pdu: vec![], + error_text: String::new(), + }), + RtrPduType::ErrorReport, + RtrProtocolVersion::V0, + ), + ]; + + for (pdu, expected_type, expected_version) in pdus { + assert_eq!(pdu.pdu_type(), expected_type); + assert_eq!(pdu.version(), expected_version); + } + } + + #[test] + fn test_pdu_from_impls() { + let query = RtrResetQuery::new_v1(); + let pdu: RtrPdu = query.into(); + assert_eq!(pdu.pdu_type(), RtrPduType::ResetQuery); + } + + #[test] + fn test_all_pdu_from_impls() { + // Test From impl for all PDU types + let notify = RtrSerialNotify { + version: RtrProtocolVersion::V1, + session_id: 1, + serial_number: 100, + }; + let pdu: RtrPdu = notify.into(); + assert!(matches!(pdu, RtrPdu::SerialNotify(_))); + + let query = RtrSerialQuery::new(RtrProtocolVersion::V1, 1, 100); + let pdu: RtrPdu = query.into(); + assert!(matches!(pdu, RtrPdu::SerialQuery(_))); + + let response = RtrCacheResponse { + version: RtrProtocolVersion::V1, + session_id: 1, + }; + let pdu: RtrPdu = response.into(); + assert!(matches!(pdu, RtrPdu::CacheResponse(_))); + + let prefix4 = RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(10, 0, 0, 0), + asn: 65000.into(), + }; + let pdu: RtrPdu = prefix4.into(); + assert!(matches!(pdu, RtrPdu::IPv4Prefix(_))); + + let prefix6 = RtrIPv6Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 48, + max_length: 48, + prefix: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + asn: 65000.into(), + }; + let pdu: RtrPdu = prefix6.into(); + assert!(matches!(pdu, RtrPdu::IPv6Prefix(_))); + + let eod = RtrEndOfData { + version: RtrProtocolVersion::V1, + session_id: 1, + serial_number: 100, + refresh_interval: Some(3600), + retry_interval: Some(600), + expire_interval: Some(7200), + }; + let pdu: RtrPdu = eod.into(); + assert!(matches!(pdu, RtrPdu::EndOfData(_))); + + let reset = RtrCacheReset { + version: RtrProtocolVersion::V1, + }; + let pdu: RtrPdu = reset.into(); + assert!(matches!(pdu, RtrPdu::CacheReset(_))); + + let key = RtrRouterKey { + version: RtrProtocolVersion::V1, + flags: 1, + subject_key_identifier: [0; 20], + asn: 65000.into(), + subject_public_key_info: vec![], + }; + let pdu: RtrPdu = key.into(); + assert!(matches!(pdu, RtrPdu::RouterKey(_))); + + let error = RtrErrorReport::new( + RtrProtocolVersion::V1, + RtrErrorCode::InternalError, + vec![], + String::new(), + ); + let pdu: RtrPdu = error.into(); + assert!(matches!(pdu, RtrPdu::ErrorReport(_))); + } + + #[test] + fn test_error_report_constructors() { + let err = RtrErrorReport::unsupported_version(RtrProtocolVersion::V0, vec![1, 2, 3]); + assert_eq!(err.error_code, RtrErrorCode::UnsupportedProtocolVersion); + assert_eq!(err.error_text, "Unsupported protocol version"); + + let err = RtrErrorReport::unsupported_pdu_type(RtrProtocolVersion::V1, vec![4, 5, 6]); + assert_eq!(err.error_code, RtrErrorCode::UnsupportedPduType); + assert_eq!(err.error_text, "Unsupported PDU type"); + + let err = RtrErrorReport::corrupt_data(RtrProtocolVersion::V1, vec![], "test error"); + assert_eq!(err.error_code, RtrErrorCode::CorruptData); + assert_eq!(err.error_text, "test error"); + + let err = RtrErrorReport::new( + RtrProtocolVersion::V0, + RtrErrorCode::NoDataAvailable, + vec![7, 8, 9], + "Custom error".to_string(), + ); + assert_eq!(err.version, RtrProtocolVersion::V0); + assert_eq!(err.error_code, RtrErrorCode::NoDataAvailable); + assert_eq!(err.erroneous_pdu, vec![7, 8, 9]); + assert_eq!(err.error_text, "Custom error"); + } + + #[test] + #[cfg(feature = "serde")] + fn test_serde_roundtrip() { + let prefix = RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(192, 0, 2, 0), + asn: 65001.into(), + }; + + let json = serde_json::to_string(&prefix).unwrap(); + let decoded: RtrIPv4Prefix = serde_json::from_str(&json).unwrap(); + assert_eq!(prefix, decoded); + } + + #[test] + fn test_try_from_protocol_version() { + // Test TryFrom for RtrProtocolVersion + let v0: Result = 0u8.try_into(); + assert_eq!(v0, Ok(RtrProtocolVersion::V0)); + + let v1: Result = 1u8.try_into(); + assert_eq!(v1, Ok(RtrProtocolVersion::V1)); + + let invalid: Result = 99u8.try_into(); + assert_eq!(invalid, Err(99u8)); + + // Test From for u8 + let byte: u8 = RtrProtocolVersion::V0.into(); + assert_eq!(byte, 0); + let byte: u8 = RtrProtocolVersion::V1.into(); + assert_eq!(byte, 1); + } + + #[test] + fn test_try_from_pdu_type() { + // Test TryFrom for RtrPduType + let serial_notify: Result = 0u8.try_into(); + assert_eq!(serial_notify, Ok(RtrPduType::SerialNotify)); + + let error_report: Result = 10u8.try_into(); + assert_eq!(error_report, Ok(RtrPduType::ErrorReport)); + + let invalid: Result = 5u8.try_into(); // Type 5 doesn't exist + assert_eq!(invalid, Err(5u8)); + + let invalid: Result = 255u8.try_into(); + assert_eq!(invalid, Err(255u8)); + + // Test From for u8 + let byte: u8 = RtrPduType::SerialNotify.into(); + assert_eq!(byte, 0); + let byte: u8 = RtrPduType::IPv6Prefix.into(); + assert_eq!(byte, 6); + let byte: u8 = RtrPduType::ErrorReport.into(); + assert_eq!(byte, 10); + } + + #[test] + fn test_try_from_error_code() { + // Test TryFrom for RtrErrorCode + let corrupt: Result = 0u16.try_into(); + assert_eq!(corrupt, Ok(RtrErrorCode::CorruptData)); + + let unexpected: Result = 8u16.try_into(); + assert_eq!(unexpected, Ok(RtrErrorCode::UnexpectedProtocolVersion)); + + let invalid: Result = 9u16.try_into(); + assert_eq!(invalid, Err(9u16)); + + let invalid: Result = 1000u16.try_into(); + assert_eq!(invalid, Err(1000u16)); + + // Test From for u16 + let code: u16 = RtrErrorCode::CorruptData.into(); + assert_eq!(code, 0); + let code: u16 = RtrErrorCode::UnexpectedProtocolVersion.into(); + assert_eq!(code, 8); + } + + #[test] + fn test_ipv6_prefix_withdrawal() { + let withdrawal = RtrIPv6Prefix { + version: RtrProtocolVersion::V1, + flags: 0, // withdrawal + prefix_length: 48, + max_length: 64, + prefix: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + asn: 65001.into(), + }; + assert!(!withdrawal.is_announcement()); + assert!(withdrawal.is_withdrawal()); + } + + #[test] + fn test_router_key_withdrawal() { + let withdrawal = RtrRouterKey { + version: RtrProtocolVersion::V1, + flags: 0, // withdrawal + subject_key_identifier: [1; 20], + asn: 65001.into(), + subject_public_key_info: vec![0xAB, 0xCD], + }; + assert!(!withdrawal.is_announcement()); + assert!(withdrawal.is_withdrawal()); + } + + #[test] + fn test_serial_query_new() { + let query = RtrSerialQuery::new(RtrProtocolVersion::V0, 12345, 67890); + assert_eq!(query.version, RtrProtocolVersion::V0); + assert_eq!(query.session_id, 12345); + assert_eq!(query.serial_number, 67890); + } + + #[test] + fn test_reset_query_new() { + let v0 = RtrResetQuery::new(RtrProtocolVersion::V0); + assert_eq!(v0.version, RtrProtocolVersion::V0); + + let v1 = RtrResetQuery::new(RtrProtocolVersion::V1); + assert_eq!(v1.version, RtrProtocolVersion::V1); + } + + #[test] + fn test_protocol_version_default() { + let default = RtrProtocolVersion::default(); + assert_eq!(default, RtrProtocolVersion::V1); + } + + #[test] + fn test_all_error_codes() { + // Test all error code conversions + let codes = [ + (0u16, RtrErrorCode::CorruptData), + (1u16, RtrErrorCode::InternalError), + (2u16, RtrErrorCode::NoDataAvailable), + (3u16, RtrErrorCode::InvalidRequest), + (4u16, RtrErrorCode::UnsupportedProtocolVersion), + (5u16, RtrErrorCode::UnsupportedPduType), + (6u16, RtrErrorCode::WithdrawalOfUnknownRecord), + (7u16, RtrErrorCode::DuplicateAnnouncementReceived), + (8u16, RtrErrorCode::UnexpectedProtocolVersion), + ]; + + for (value, expected) in codes { + assert_eq!(RtrErrorCode::from_u16(value), Some(expected)); + assert_eq!(expected.to_u16(), value); + } + } + + #[test] + fn test_all_pdu_types() { + // Test all PDU type conversions + let types = [ + (0u8, RtrPduType::SerialNotify), + (1u8, RtrPduType::SerialQuery), + (2u8, RtrPduType::ResetQuery), + (3u8, RtrPduType::CacheResponse), + (4u8, RtrPduType::IPv4Prefix), + (6u8, RtrPduType::IPv6Prefix), + (7u8, RtrPduType::EndOfData), + (8u8, RtrPduType::CacheReset), + (9u8, RtrPduType::RouterKey), + (10u8, RtrPduType::ErrorReport), + ]; + + for (value, expected) in types { + assert_eq!(RtrPduType::from_u8(value), Some(expected)); + assert_eq!(expected.to_u8(), value); + } + + // Test that type 5 doesn't exist + assert_eq!(RtrPduType::from_u8(5), None); + } +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 872b9b7f..553b2411 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -10,6 +10,7 @@ pub mod bmp; pub mod filter; pub mod iters; pub mod mrt; +pub mod rpki; #[cfg(feature = "rislive")] pub mod rislive; diff --git a/src/parser/rpki/mod.rs b/src/parser/rpki/mod.rs new file mode 100644 index 00000000..fcfd3904 --- /dev/null +++ b/src/parser/rpki/mod.rs @@ -0,0 +1,24 @@ +//! RPKI (Resource Public Key Infrastructure) protocol parsers. +//! +//! This module provides parsing and encoding functions for RPKI-related protocols: +//! +//! - [`rtr`]: RPKI-to-Router (RTR) Protocol parsing and encoding (RFC 6810, RFC 8210) +//! +//! # Example +//! +//! ```rust +//! use bgpkit_parser::parser::rpki::rtr::{parse_rtr_pdu, RtrEncode}; +//! use bgpkit_parser::models::rpki::rtr::*; +//! +//! // Create and encode a Reset Query +//! let query = RtrResetQuery::new_v1(); +//! let bytes = query.encode(); +//! +//! // Parse the bytes back +//! let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); +//! assert!(matches!(pdu, RtrPdu::ResetQuery(_))); +//! ``` + +pub mod rtr; + +pub use rtr::*; diff --git a/src/parser/rpki/rtr.rs b/src/parser/rpki/rtr.rs new file mode 100644 index 00000000..b7731c10 --- /dev/null +++ b/src/parser/rpki/rtr.rs @@ -0,0 +1,1511 @@ +//! RPKI-to-Router (RTR) Protocol Parser +//! +//! This module provides parsing and encoding functions for RTR protocol PDUs +//! as defined in RFC 6810 (v0) and RFC 8210 (v1). +//! +//! # Parsing +//! +//! ```rust +//! use bgpkit_parser::parser::rpki::rtr::{parse_rtr_pdu, read_rtr_pdu}; +//! use bgpkit_parser::models::rpki::rtr::*; +//! +//! // Parse from a byte slice +//! let bytes = [1, 2, 0, 0, 0, 0, 0, 8]; // Reset Query v1 +//! let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); +//! assert_eq!(consumed, 8); +//! ``` +//! +//! # Encoding +//! +//! ```rust +//! use bgpkit_parser::parser::rpki::rtr::RtrEncode; +//! use bgpkit_parser::models::rpki::rtr::*; +//! +//! let query = RtrResetQuery::new_v1(); +//! let bytes = query.encode(); +//! assert_eq!(bytes.len(), 8); +//! ``` + +use crate::models::rpki::rtr::*; +use crate::models::Asn; +use std::fmt; +use std::io::{self, Read}; +use std::net::{Ipv4Addr, Ipv6Addr}; + +// ============================================================================= +// Error Types +// ============================================================================= + +/// Errors that can occur during RTR PDU parsing or encoding +#[derive(Debug)] +pub enum RtrError { + /// I/O error during reading + IoError(io::Error), + /// PDU is incomplete (need more data) + IncompletePdu { + /// Number of bytes available + available: usize, + /// Number of bytes needed + needed: usize, + }, + /// Invalid PDU type + InvalidPduType(u8), + /// Invalid protocol version + InvalidProtocolVersion(u8), + /// Invalid error code + InvalidErrorCode(u16), + /// Invalid PDU length + InvalidLength { + /// Expected length + expected: u32, + /// Actual length in header + actual: u32, + /// PDU type + pdu_type: u8, + }, + /// Invalid prefix length + InvalidPrefixLength { + /// Prefix length + prefix_len: u8, + /// Maximum length + max_len: u8, + /// Maximum allowed for address family (32 for IPv4, 128 for IPv6) + max_allowed: u8, + }, + /// Invalid UTF-8 in error text + InvalidUtf8, + /// Router Key PDU in v0 (not supported) + RouterKeyInV0, +} + +impl fmt::Display for RtrError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RtrError::IoError(e) => write!(f, "I/O error: {}", e), + RtrError::IncompletePdu { available, needed } => { + write!( + f, + "Incomplete PDU: have {} bytes, need {} bytes", + available, needed + ) + } + RtrError::InvalidPduType(t) => write!(f, "Invalid PDU type: {}", t), + RtrError::InvalidProtocolVersion(v) => write!(f, "Invalid protocol version: {}", v), + RtrError::InvalidErrorCode(c) => write!(f, "Invalid error code: {}", c), + RtrError::InvalidLength { + expected, + actual, + pdu_type, + } => { + write!( + f, + "Invalid length for PDU type {}: expected {}, got {}", + pdu_type, expected, actual + ) + } + RtrError::InvalidPrefixLength { + prefix_len, + max_len, + max_allowed, + } => { + write!( + f, + "Invalid prefix length: prefix_len={}, max_len={}, max_allowed={}", + prefix_len, max_len, max_allowed + ) + } + RtrError::InvalidUtf8 => write!(f, "Invalid UTF-8 in error text"), + RtrError::RouterKeyInV0 => write!(f, "Router Key PDU is not valid in RTR v0"), + } + } +} + +impl std::error::Error for RtrError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RtrError::IoError(e) => Some(e), + _ => None, + } + } +} + +impl From for RtrError { + fn from(e: io::Error) -> Self { + RtrError::IoError(e) + } +} + +// ============================================================================= +// PDU Length Constants +// ============================================================================= + +/// RTR PDU header length (common to all PDUs) +pub const RTR_HEADER_LEN: usize = 8; + +/// Serial Notify PDU length +pub const RTR_SERIAL_NOTIFY_LEN: u32 = 12; + +/// Serial Query PDU length +pub const RTR_SERIAL_QUERY_LEN: u32 = 12; + +/// Reset Query PDU length +pub const RTR_RESET_QUERY_LEN: u32 = 8; + +/// Cache Response PDU length +pub const RTR_CACHE_RESPONSE_LEN: u32 = 8; + +/// IPv4 Prefix PDU length +pub const RTR_IPV4_PREFIX_LEN: u32 = 20; + +/// IPv6 Prefix PDU length +pub const RTR_IPV6_PREFIX_LEN: u32 = 32; + +/// End of Data PDU length (v0) +pub const RTR_END_OF_DATA_V0_LEN: u32 = 12; + +/// End of Data PDU length (v1) +pub const RTR_END_OF_DATA_V1_LEN: u32 = 24; + +/// Cache Reset PDU length +pub const RTR_CACHE_RESET_LEN: u32 = 8; + +/// Router Key PDU minimum length (header(8) + flags(1) + zero(1) + SKI(20) + ASN(4) = 34) +pub const RTR_ROUTER_KEY_MIN_LEN: u32 = 34; + +// ============================================================================= +// Parsing Functions +// ============================================================================= + +/// Parse a single RTR PDU from a byte slice. +/// +/// Returns the parsed PDU and the number of bytes consumed. +/// +/// # Errors +/// +/// Returns an error if the input is too short, contains invalid data, +/// or references an unknown PDU type. +/// +/// # Example +/// +/// ```rust +/// use bgpkit_parser::parser::rpki::rtr::parse_rtr_pdu; +/// use bgpkit_parser::models::rpki::rtr::*; +/// +/// // Reset Query PDU (v1) +/// let bytes = [1, 2, 0, 0, 0, 0, 0, 8]; +/// let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); +/// assert!(matches!(pdu, RtrPdu::ResetQuery(_))); +/// assert_eq!(consumed, 8); +/// ``` +pub fn parse_rtr_pdu(input: &[u8]) -> Result<(RtrPdu, usize), RtrError> { + // Need at least the header + if input.len() < RTR_HEADER_LEN { + return Err(RtrError::IncompletePdu { + available: input.len(), + needed: RTR_HEADER_LEN, + }); + } + + // Parse header + let version_byte = input[0]; + let pdu_type_byte = input[1]; + let session_or_error = u16::from_be_bytes([input[2], input[3]]); + let length = u32::from_be_bytes([input[4], input[5], input[6], input[7]]); + + // Validate we have enough data + let length_usize = length as usize; + if input.len() < length_usize { + return Err(RtrError::IncompletePdu { + available: input.len(), + needed: length_usize, + }); + } + + // Parse version + let version = RtrProtocolVersion::from_u8(version_byte) + .ok_or(RtrError::InvalidProtocolVersion(version_byte))?; + + // Parse PDU type + let pdu_type = + RtrPduType::from_u8(pdu_type_byte).ok_or(RtrError::InvalidPduType(pdu_type_byte))?; + + // Parse based on PDU type + let pdu = match pdu_type { + RtrPduType::SerialNotify => { + validate_length(length, RTR_SERIAL_NOTIFY_LEN, pdu_type_byte)?; + let serial_number = u32::from_be_bytes([input[8], input[9], input[10], input[11]]); + RtrPdu::SerialNotify(RtrSerialNotify { + version, + session_id: session_or_error, + serial_number, + }) + } + + RtrPduType::SerialQuery => { + validate_length(length, RTR_SERIAL_QUERY_LEN, pdu_type_byte)?; + let serial_number = u32::from_be_bytes([input[8], input[9], input[10], input[11]]); + RtrPdu::SerialQuery(RtrSerialQuery { + version, + session_id: session_or_error, + serial_number, + }) + } + + RtrPduType::ResetQuery => { + validate_length(length, RTR_RESET_QUERY_LEN, pdu_type_byte)?; + RtrPdu::ResetQuery(RtrResetQuery { version }) + } + + RtrPduType::CacheResponse => { + validate_length(length, RTR_CACHE_RESPONSE_LEN, pdu_type_byte)?; + RtrPdu::CacheResponse(RtrCacheResponse { + version, + session_id: session_or_error, + }) + } + + RtrPduType::IPv4Prefix => { + validate_length(length, RTR_IPV4_PREFIX_LEN, pdu_type_byte)?; + let flags = input[8]; + let prefix_length = input[9]; + let max_length = input[10]; + // input[11] is reserved/zero + + validate_prefix_length(prefix_length, max_length, 32)?; + + let prefix = Ipv4Addr::new(input[12], input[13], input[14], input[15]); + let asn = u32::from_be_bytes([input[16], input[17], input[18], input[19]]); + + RtrPdu::IPv4Prefix(RtrIPv4Prefix { + version, + flags, + prefix_length, + max_length, + prefix, + asn: Asn::from(asn), + }) + } + + RtrPduType::IPv6Prefix => { + validate_length(length, RTR_IPV6_PREFIX_LEN, pdu_type_byte)?; + let flags = input[8]; + let prefix_length = input[9]; + let max_length = input[10]; + // input[11] is reserved/zero + + validate_prefix_length(prefix_length, max_length, 128)?; + + let prefix = Ipv6Addr::from([ + input[12], input[13], input[14], input[15], input[16], input[17], input[18], + input[19], input[20], input[21], input[22], input[23], input[24], input[25], + input[26], input[27], + ]); + let asn = u32::from_be_bytes([input[28], input[29], input[30], input[31]]); + + RtrPdu::IPv6Prefix(RtrIPv6Prefix { + version, + flags, + prefix_length, + max_length, + prefix, + asn: Asn::from(asn), + }) + } + + RtrPduType::EndOfData => { + let expected_len = match version { + RtrProtocolVersion::V0 => RTR_END_OF_DATA_V0_LEN, + RtrProtocolVersion::V1 => RTR_END_OF_DATA_V1_LEN, + }; + validate_length(length, expected_len, pdu_type_byte)?; + + let serial_number = u32::from_be_bytes([input[8], input[9], input[10], input[11]]); + + let (refresh_interval, retry_interval, expire_interval) = match version { + RtrProtocolVersion::V0 => (None, None, None), + RtrProtocolVersion::V1 => { + let refresh = u32::from_be_bytes([input[12], input[13], input[14], input[15]]); + let retry = u32::from_be_bytes([input[16], input[17], input[18], input[19]]); + let expire = u32::from_be_bytes([input[20], input[21], input[22], input[23]]); + (Some(refresh), Some(retry), Some(expire)) + } + }; + + RtrPdu::EndOfData(RtrEndOfData { + version, + session_id: session_or_error, + serial_number, + refresh_interval, + retry_interval, + expire_interval, + }) + } + + RtrPduType::CacheReset => { + validate_length(length, RTR_CACHE_RESET_LEN, pdu_type_byte)?; + RtrPdu::CacheReset(RtrCacheReset { version }) + } + + RtrPduType::RouterKey => { + // Router Key is v1 only + if version == RtrProtocolVersion::V0 { + return Err(RtrError::RouterKeyInV0); + } + + if length < RTR_ROUTER_KEY_MIN_LEN { + return Err(RtrError::InvalidLength { + expected: RTR_ROUTER_KEY_MIN_LEN, + actual: length, + pdu_type: pdu_type_byte, + }); + } + + let flags = input[8]; + // input[9] is zero + let mut ski = [0u8; 20]; + ski.copy_from_slice(&input[10..30]); + let asn = u32::from_be_bytes([input[30], input[31], input[32], input[33]]); + + // SPKI is the rest of the PDU (34 bytes of header + fixed fields already parsed) + let spki_len = (length as usize) - 34; + let spki = if spki_len > 0 { + input[34..34 + spki_len].to_vec() + } else { + Vec::new() + }; + + RtrPdu::RouterKey(RtrRouterKey { + version, + flags, + subject_key_identifier: ski, + asn: Asn::from(asn), + subject_public_key_info: spki, + }) + } + + RtrPduType::ErrorReport => { + // Error Report has variable length + // Minimum: header (8) + length of encapsulated PDU (4) + length of error text (4) = 16 + if length < 16 { + return Err(RtrError::InvalidLength { + expected: 16, + actual: length, + pdu_type: pdu_type_byte, + }); + } + + let error_code = RtrErrorCode::from_u16(session_or_error) + .ok_or(RtrError::InvalidErrorCode(session_or_error))?; + + let encap_pdu_len = + u32::from_be_bytes([input[8], input[9], input[10], input[11]]) as usize; + + // Validate encapsulated PDU fits + if 12 + encap_pdu_len + 4 > length_usize { + return Err(RtrError::InvalidLength { + expected: (12 + encap_pdu_len + 4) as u32, + actual: length, + pdu_type: pdu_type_byte, + }); + } + + let erroneous_pdu = if encap_pdu_len > 0 { + input[12..12 + encap_pdu_len].to_vec() + } else { + Vec::new() + }; + + let error_text_len_offset = 12 + encap_pdu_len; + let error_text_len = u32::from_be_bytes([ + input[error_text_len_offset], + input[error_text_len_offset + 1], + input[error_text_len_offset + 2], + input[error_text_len_offset + 3], + ]) as usize; + + let error_text_offset = error_text_len_offset + 4; + let error_text = if error_text_len > 0 { + std::str::from_utf8(&input[error_text_offset..error_text_offset + error_text_len]) + .map_err(|_| RtrError::InvalidUtf8)? + .to_string() + } else { + String::new() + }; + + RtrPdu::ErrorReport(RtrErrorReport { + version, + error_code, + erroneous_pdu, + error_text, + }) + } + }; + + Ok((pdu, length_usize)) +} + +/// Read a single RTR PDU from a reader. +/// +/// This function reads exactly one complete PDU from the reader. +/// +/// # Errors +/// +/// Returns an error if reading fails or the PDU is invalid. +/// +/// # Example +/// +/// ```rust,no_run +/// use std::net::TcpStream; +/// use bgpkit_parser::parser::rpki::rtr::read_rtr_pdu; +/// +/// let mut stream = TcpStream::connect("rtr.example.com:8282").unwrap(); +/// let pdu = read_rtr_pdu(&mut stream).unwrap(); +/// ``` +pub fn read_rtr_pdu(reader: &mut R) -> Result { + // Read header first + let mut header = [0u8; RTR_HEADER_LEN]; + reader.read_exact(&mut header)?; + + // Get length from header + let length = u32::from_be_bytes([header[4], header[5], header[6], header[7]]) as usize; + + if length < RTR_HEADER_LEN { + return Err(RtrError::InvalidLength { + expected: RTR_HEADER_LEN as u32, + actual: length as u32, + pdu_type: header[1], + }); + } + + // Allocate buffer for full PDU + let mut buffer = vec![0u8; length]; + buffer[..RTR_HEADER_LEN].copy_from_slice(&header); + + // Read remaining bytes + if length > RTR_HEADER_LEN { + reader.read_exact(&mut buffer[RTR_HEADER_LEN..])?; + } + + // Parse the complete PDU + let (pdu, _) = parse_rtr_pdu(&buffer)?; + Ok(pdu) +} + +fn validate_length(actual: u32, expected: u32, pdu_type: u8) -> Result<(), RtrError> { + if actual != expected { + Err(RtrError::InvalidLength { + expected, + actual, + pdu_type, + }) + } else { + Ok(()) + } +} + +fn validate_prefix_length(prefix_len: u8, max_len: u8, max_allowed: u8) -> Result<(), RtrError> { + if prefix_len > max_len || max_len > max_allowed { + Err(RtrError::InvalidPrefixLength { + prefix_len, + max_len, + max_allowed, + }) + } else { + Ok(()) + } +} + +// ============================================================================= +// Encoding Trait and Implementations +// ============================================================================= + +/// Trait for encoding RTR PDUs to bytes +pub trait RtrEncode { + /// Encode this PDU to a byte vector + fn encode(&self) -> Vec; +} + +impl RtrEncode for RtrSerialNotify { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_SERIAL_NOTIFY_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::SerialNotify.to_u8()); + buf.extend_from_slice(&self.session_id.to_be_bytes()); + buf.extend_from_slice(&RTR_SERIAL_NOTIFY_LEN.to_be_bytes()); + buf.extend_from_slice(&self.serial_number.to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrSerialQuery { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_SERIAL_QUERY_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::SerialQuery.to_u8()); + buf.extend_from_slice(&self.session_id.to_be_bytes()); + buf.extend_from_slice(&RTR_SERIAL_QUERY_LEN.to_be_bytes()); + buf.extend_from_slice(&self.serial_number.to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrResetQuery { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_RESET_QUERY_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::ResetQuery.to_u8()); + buf.extend_from_slice(&[0, 0]); // zero + buf.extend_from_slice(&RTR_RESET_QUERY_LEN.to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrCacheResponse { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_CACHE_RESPONSE_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::CacheResponse.to_u8()); + buf.extend_from_slice(&self.session_id.to_be_bytes()); + buf.extend_from_slice(&RTR_CACHE_RESPONSE_LEN.to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrIPv4Prefix { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_IPV4_PREFIX_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::IPv4Prefix.to_u8()); + buf.extend_from_slice(&[0, 0]); // zero + buf.extend_from_slice(&RTR_IPV4_PREFIX_LEN.to_be_bytes()); + buf.push(self.flags); + buf.push(self.prefix_length); + buf.push(self.max_length); + buf.push(0); // zero + buf.extend_from_slice(&self.prefix.octets()); + buf.extend_from_slice(&self.asn.to_u32().to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrIPv6Prefix { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_IPV6_PREFIX_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::IPv6Prefix.to_u8()); + buf.extend_from_slice(&[0, 0]); // zero + buf.extend_from_slice(&RTR_IPV6_PREFIX_LEN.to_be_bytes()); + buf.push(self.flags); + buf.push(self.prefix_length); + buf.push(self.max_length); + buf.push(0); // zero + buf.extend_from_slice(&self.prefix.octets()); + buf.extend_from_slice(&self.asn.to_u32().to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrEndOfData { + fn encode(&self) -> Vec { + let length = match self.version { + RtrProtocolVersion::V0 => RTR_END_OF_DATA_V0_LEN, + RtrProtocolVersion::V1 => RTR_END_OF_DATA_V1_LEN, + }; + let mut buf = Vec::with_capacity(length as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::EndOfData.to_u8()); + buf.extend_from_slice(&self.session_id.to_be_bytes()); + buf.extend_from_slice(&length.to_be_bytes()); + buf.extend_from_slice(&self.serial_number.to_be_bytes()); + + if self.version == RtrProtocolVersion::V1 { + buf.extend_from_slice( + &self + .refresh_interval + .unwrap_or(RtrEndOfData::DEFAULT_REFRESH) + .to_be_bytes(), + ); + buf.extend_from_slice( + &self + .retry_interval + .unwrap_or(RtrEndOfData::DEFAULT_RETRY) + .to_be_bytes(), + ); + buf.extend_from_slice( + &self + .expire_interval + .unwrap_or(RtrEndOfData::DEFAULT_EXPIRE) + .to_be_bytes(), + ); + } + buf + } +} + +impl RtrEncode for RtrCacheReset { + fn encode(&self) -> Vec { + let mut buf = Vec::with_capacity(RTR_CACHE_RESET_LEN as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::CacheReset.to_u8()); + buf.extend_from_slice(&[0, 0]); // zero + buf.extend_from_slice(&RTR_CACHE_RESET_LEN.to_be_bytes()); + buf + } +} + +impl RtrEncode for RtrRouterKey { + fn encode(&self) -> Vec { + let length = RTR_ROUTER_KEY_MIN_LEN + self.subject_public_key_info.len() as u32; + let mut buf = Vec::with_capacity(length as usize); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::RouterKey.to_u8()); + buf.extend_from_slice(&[0, 0]); // zero (session_id field is zero for Router Key) + buf.extend_from_slice(&length.to_be_bytes()); + buf.push(self.flags); + buf.push(0); // zero + buf.extend_from_slice(&self.subject_key_identifier); + buf.extend_from_slice(&self.asn.to_u32().to_be_bytes()); + buf.extend_from_slice(&self.subject_public_key_info); + buf + } +} + +impl RtrEncode for RtrErrorReport { + fn encode(&self) -> Vec { + let error_text_bytes = self.error_text.as_bytes(); + let length = 16 + self.erroneous_pdu.len() + error_text_bytes.len(); + let mut buf = Vec::with_capacity(length); + buf.push(self.version.to_u8()); + buf.push(RtrPduType::ErrorReport.to_u8()); + buf.extend_from_slice(&self.error_code.to_u16().to_be_bytes()); + buf.extend_from_slice(&(length as u32).to_be_bytes()); + buf.extend_from_slice(&(self.erroneous_pdu.len() as u32).to_be_bytes()); + buf.extend_from_slice(&self.erroneous_pdu); + buf.extend_from_slice(&(error_text_bytes.len() as u32).to_be_bytes()); + buf.extend_from_slice(error_text_bytes); + buf + } +} + +impl RtrEncode for RtrPdu { + fn encode(&self) -> Vec { + match self { + RtrPdu::SerialNotify(p) => p.encode(), + RtrPdu::SerialQuery(p) => p.encode(), + RtrPdu::ResetQuery(p) => p.encode(), + RtrPdu::CacheResponse(p) => p.encode(), + RtrPdu::IPv4Prefix(p) => p.encode(), + RtrPdu::IPv6Prefix(p) => p.encode(), + RtrPdu::EndOfData(p) => p.encode(), + RtrPdu::CacheReset(p) => p.encode(), + RtrPdu::RouterKey(p) => p.encode(), + RtrPdu::ErrorReport(p) => p.encode(), + } + } +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_reset_query_roundtrip() { + let query = RtrResetQuery::new_v1(); + let bytes = query.encode(); + assert_eq!(bytes.len(), 8); + + let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); + assert_eq!(consumed, 8); + assert!(matches!(pdu, RtrPdu::ResetQuery(q) if q.version == RtrProtocolVersion::V1)); + } + + #[test] + fn test_reset_query_v0_roundtrip() { + let query = RtrResetQuery::new_v0(); + let bytes = query.encode(); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + assert!(matches!(pdu, RtrPdu::ResetQuery(q) if q.version == RtrProtocolVersion::V0)); + } + + #[test] + fn test_serial_query_roundtrip() { + let query = RtrSerialQuery::new(RtrProtocolVersion::V1, 12345, 67890); + let bytes = query.encode(); + assert_eq!(bytes.len(), 12); + + let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); + assert_eq!(consumed, 12); + match pdu { + RtrPdu::SerialQuery(q) => { + assert_eq!(q.session_id, 12345); + assert_eq!(q.serial_number, 67890); + } + _ => panic!("Expected SerialQuery"), + } + } + + #[test] + fn test_serial_notify_roundtrip() { + let notify = RtrSerialNotify { + version: RtrProtocolVersion::V1, + session_id: 100, + serial_number: 200, + }; + let bytes = notify.encode(); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::SerialNotify(n) => { + assert_eq!(n.session_id, 100); + assert_eq!(n.serial_number, 200); + } + _ => panic!("Expected SerialNotify"), + } + } + + #[test] + fn test_cache_response_roundtrip() { + let response = RtrCacheResponse { + version: RtrProtocolVersion::V1, + session_id: 42, + }; + let bytes = response.encode(); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::CacheResponse(r) => { + assert_eq!(r.session_id, 42); + } + _ => panic!("Expected CacheResponse"), + } + } + + #[test] + fn test_ipv4_prefix_roundtrip() { + let prefix = RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(192, 0, 2, 0), + asn: Asn::from(65001u32), + }; + let bytes = prefix.encode(); + assert_eq!(bytes.len(), 20); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::IPv4Prefix(p) => { + assert!(p.is_announcement()); + assert_eq!(p.prefix_length, 24); + assert_eq!(p.max_length, 24); + assert_eq!(p.prefix, Ipv4Addr::new(192, 0, 2, 0)); + assert_eq!(p.asn.to_u32(), 65001); + } + _ => panic!("Expected IPv4Prefix"), + } + } + + #[test] + fn test_ipv4_prefix_withdrawal() { + let prefix = RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 0, // withdrawal + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(192, 0, 2, 0), + asn: Asn::from(65001u32), + }; + let bytes = prefix.encode(); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::IPv4Prefix(p) => { + assert!(p.is_withdrawal()); + assert!(!p.is_announcement()); + } + _ => panic!("Expected IPv4Prefix"), + } + } + + #[test] + fn test_ipv6_prefix_roundtrip() { + let prefix = RtrIPv6Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 48, + max_length: 64, + prefix: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + asn: Asn::from(65002u32), + }; + let bytes = prefix.encode(); + assert_eq!(bytes.len(), 32); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::IPv6Prefix(p) => { + assert!(p.is_announcement()); + assert_eq!(p.prefix_length, 48); + assert_eq!(p.max_length, 64); + assert_eq!(p.prefix, Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0)); + assert_eq!(p.asn.to_u32(), 65002); + } + _ => panic!("Expected IPv6Prefix"), + } + } + + #[test] + fn test_end_of_data_v0_roundtrip() { + let eod = RtrEndOfData { + version: RtrProtocolVersion::V0, + session_id: 100, + serial_number: 200, + refresh_interval: None, + retry_interval: None, + expire_interval: None, + }; + let bytes = eod.encode(); + assert_eq!(bytes.len(), 12); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::EndOfData(e) => { + assert_eq!(e.version, RtrProtocolVersion::V0); + assert_eq!(e.session_id, 100); + assert_eq!(e.serial_number, 200); + assert_eq!(e.refresh_interval, None); + assert_eq!(e.retry_interval, None); + assert_eq!(e.expire_interval, None); + } + _ => panic!("Expected EndOfData"), + } + } + + #[test] + fn test_end_of_data_v1_roundtrip() { + let eod = RtrEndOfData { + version: RtrProtocolVersion::V1, + session_id: 100, + serial_number: 200, + refresh_interval: Some(1800), + retry_interval: Some(300), + expire_interval: Some(3600), + }; + let bytes = eod.encode(); + assert_eq!(bytes.len(), 24); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::EndOfData(e) => { + assert_eq!(e.version, RtrProtocolVersion::V1); + assert_eq!(e.refresh_interval, Some(1800)); + assert_eq!(e.retry_interval, Some(300)); + assert_eq!(e.expire_interval, Some(3600)); + } + _ => panic!("Expected EndOfData"), + } + } + + #[test] + fn test_end_of_data_v1_with_defaults() { + let eod = RtrEndOfData { + version: RtrProtocolVersion::V1, + session_id: 100, + serial_number: 200, + refresh_interval: None, // Will use default when encoding + retry_interval: None, + expire_interval: None, + }; + let bytes = eod.encode(); + assert_eq!(bytes.len(), 24); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::EndOfData(e) => { + // Since v1 encoding always includes timing, they'll be the defaults + assert_eq!(e.refresh_interval, Some(3600)); + assert_eq!(e.retry_interval, Some(600)); + assert_eq!(e.expire_interval, Some(7200)); + } + _ => panic!("Expected EndOfData"), + } + } + + #[test] + fn test_cache_reset_roundtrip() { + let reset = RtrCacheReset { + version: RtrProtocolVersion::V1, + }; + let bytes = reset.encode(); + assert_eq!(bytes.len(), 8); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + assert!(matches!(pdu, RtrPdu::CacheReset(_))); + } + + #[test] + fn test_router_key_roundtrip() { + let key = RtrRouterKey { + version: RtrProtocolVersion::V1, + flags: 1, + subject_key_identifier: [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + ], + asn: Asn::from(65003u32), + subject_public_key_info: vec![0xAB, 0xCD, 0xEF], + }; + let bytes = key.encode(); + assert_eq!(bytes.len(), 37); // 34 min + 3 SPKI + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::RouterKey(k) => { + assert!(k.is_announcement()); + assert_eq!(k.subject_key_identifier[0], 1); + assert_eq!(k.subject_key_identifier[19], 20); + assert_eq!(k.asn.to_u32(), 65003); + assert_eq!(k.subject_public_key_info, vec![0xAB, 0xCD, 0xEF]); + } + _ => panic!("Expected RouterKey"), + } + } + + #[test] + fn test_router_key_in_v0_error() { + // Manually construct a Router Key PDU with v0 version + let mut bytes = vec![ + 0, // version 0 + 9, // type 9 (Router Key) + 0, 0, // zero + 0, 0, 0, 34, // length = 34 (minimum) + 1, // flags + 0, // zero + ]; + bytes.extend_from_slice(&[0u8; 20]); // SKI + bytes.extend_from_slice(&[0, 0, 0, 1]); // ASN + + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::RouterKeyInV0))); + } + + #[test] + fn test_error_report_roundtrip() { + let error = RtrErrorReport { + version: RtrProtocolVersion::V1, + error_code: RtrErrorCode::UnsupportedProtocolVersion, + erroneous_pdu: vec![99, 2, 0, 0, 0, 0, 0, 8], // Some invalid PDU + error_text: "Test error".to_string(), + }; + let bytes = error.encode(); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::ErrorReport(e) => { + assert_eq!(e.error_code, RtrErrorCode::UnsupportedProtocolVersion); + assert_eq!(e.erroneous_pdu, vec![99, 2, 0, 0, 0, 0, 0, 8]); + assert_eq!(e.error_text, "Test error"); + } + _ => panic!("Expected ErrorReport"), + } + } + + #[test] + fn test_error_report_empty() { + let error = RtrErrorReport { + version: RtrProtocolVersion::V1, + error_code: RtrErrorCode::InternalError, + erroneous_pdu: vec![], + error_text: String::new(), + }; + let bytes = error.encode(); + assert_eq!(bytes.len(), 16); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::ErrorReport(e) => { + assert_eq!(e.error_code, RtrErrorCode::InternalError); + assert!(e.erroneous_pdu.is_empty()); + assert!(e.error_text.is_empty()); + } + _ => panic!("Expected ErrorReport"), + } + } + + #[test] + fn test_incomplete_pdu_error() { + let bytes = [1, 2, 0]; // Too short + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::IncompletePdu { .. }))); + } + + #[test] + fn test_invalid_pdu_type_error() { + let bytes = [1, 5, 0, 0, 0, 0, 0, 8]; // Type 5 doesn't exist + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidPduType(5)))); + } + + #[test] + fn test_invalid_protocol_version_error() { + let bytes = [99, 2, 0, 0, 0, 0, 0, 8]; // Version 99 doesn't exist + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidProtocolVersion(99)))); + } + + #[test] + fn test_invalid_length_error() { + // Reset Query with wrong length - need full buffer to match declared length + let bytes = [1, 2, 0, 0, 0, 0, 0, 10, 0, 0]; // Length says 10, should be 8 + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidLength { .. }))); + } + + #[test] + fn test_invalid_prefix_length_error() { + // IPv4 prefix with prefix_len > max_len + let mut bytes = vec![ + 1, // version + 4, // type (IPv4 Prefix) + 0, 0, // zero + 0, 0, 0, 20, // length + 1, // flags + 25, // prefix_length (25) + 24, // max_length (24) - INVALID: prefix_len > max_len + 0, // zero + ]; + bytes.extend_from_slice(&[192, 0, 2, 0]); // prefix + bytes.extend_from_slice(&[0, 0, 0, 1]); // ASN + + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidPrefixLength { .. }))); + } + + #[test] + fn test_invalid_max_length_error() { + // IPv4 prefix with max_len > 32 + let mut bytes = vec![ + 1, // version + 4, // type (IPv4 Prefix) + 0, 0, // zero + 0, 0, 0, 20, // length + 1, // flags + 24, // prefix_length + 33, // max_length (33) - INVALID: > 32 for IPv4 + 0, // zero + ]; + bytes.extend_from_slice(&[192, 0, 2, 0]); // prefix + bytes.extend_from_slice(&[0, 0, 0, 1]); // ASN + + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidPrefixLength { .. }))); + } + + #[test] + fn test_read_rtr_pdu_from_cursor() { + use std::io::Cursor; + + let query = RtrResetQuery::new_v1(); + let bytes = query.encode(); + let mut cursor = Cursor::new(bytes); + + let pdu = read_rtr_pdu(&mut cursor).unwrap(); + assert!(matches!(pdu, RtrPdu::ResetQuery(_))); + } + + #[test] + fn test_pdu_enum_encode() { + let pdu = RtrPdu::ResetQuery(RtrResetQuery::new_v1()); + let bytes = pdu.encode(); + assert_eq!(bytes.len(), 8); + + let (parsed, _) = parse_rtr_pdu(&bytes).unwrap(); + assert!(matches!(parsed, RtrPdu::ResetQuery(_))); + } + + #[test] + fn test_all_pdu_types_roundtrip() { + // Test that all PDU types can be encoded and decoded + let pdus: Vec = vec![ + RtrPdu::SerialNotify(RtrSerialNotify { + version: RtrProtocolVersion::V1, + session_id: 1, + serial_number: 100, + }), + RtrPdu::SerialQuery(RtrSerialQuery::new(RtrProtocolVersion::V1, 1, 100)), + RtrPdu::ResetQuery(RtrResetQuery::new_v1()), + RtrPdu::CacheResponse(RtrCacheResponse { + version: RtrProtocolVersion::V1, + session_id: 1, + }), + RtrPdu::IPv4Prefix(RtrIPv4Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 24, + max_length: 24, + prefix: Ipv4Addr::new(10, 0, 0, 0), + asn: Asn::from(65000u32), + }), + RtrPdu::IPv6Prefix(RtrIPv6Prefix { + version: RtrProtocolVersion::V1, + flags: 1, + prefix_length: 48, + max_length: 48, + prefix: Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0), + asn: Asn::from(65000u32), + }), + RtrPdu::EndOfData(RtrEndOfData { + version: RtrProtocolVersion::V1, + session_id: 1, + serial_number: 100, + refresh_interval: Some(3600), + retry_interval: Some(600), + expire_interval: Some(7200), + }), + RtrPdu::CacheReset(RtrCacheReset { + version: RtrProtocolVersion::V1, + }), + RtrPdu::RouterKey(RtrRouterKey { + version: RtrProtocolVersion::V1, + flags: 1, + subject_key_identifier: [0; 20], + asn: Asn::from(65000u32), + subject_public_key_info: vec![1, 2, 3, 4], + }), + RtrPdu::ErrorReport(RtrErrorReport { + version: RtrProtocolVersion::V1, + error_code: RtrErrorCode::NoDataAvailable, + erroneous_pdu: vec![], + error_text: "No data".to_string(), + }), + ]; + + for original in pdus { + let bytes = original.encode(); + let (parsed, consumed) = parse_rtr_pdu(&bytes).unwrap(); + assert_eq!(consumed, bytes.len()); + assert_eq!(parsed.pdu_type(), original.pdu_type()); + } + } + + #[test] + fn test_error_display() { + let err = RtrError::InvalidPduType(42); + assert!(err.to_string().contains("42")); + + let err = RtrError::IncompletePdu { + available: 4, + needed: 8, + }; + assert!(err.to_string().contains("4")); + assert!(err.to_string().contains("8")); + } + + #[test] + fn test_error_display_all_variants() { + // Test Display for all RtrError variants + let io_err = RtrError::IoError(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "connection reset", + )); + assert!(io_err.to_string().contains("I/O error")); + + let incomplete = RtrError::IncompletePdu { + available: 5, + needed: 10, + }; + assert!(incomplete.to_string().contains("Incomplete PDU")); + assert!(incomplete.to_string().contains("5")); + assert!(incomplete.to_string().contains("10")); + + let invalid_type = RtrError::InvalidPduType(99); + assert!(invalid_type.to_string().contains("Invalid PDU type")); + assert!(invalid_type.to_string().contains("99")); + + let invalid_version = RtrError::InvalidProtocolVersion(5); + assert!(invalid_version + .to_string() + .contains("Invalid protocol version")); + assert!(invalid_version.to_string().contains("5")); + + let invalid_error_code = RtrError::InvalidErrorCode(100); + assert!(invalid_error_code + .to_string() + .contains("Invalid error code")); + assert!(invalid_error_code.to_string().contains("100")); + + let invalid_length = RtrError::InvalidLength { + expected: 20, + actual: 15, + pdu_type: 4, + }; + assert!(invalid_length.to_string().contains("Invalid length")); + assert!(invalid_length.to_string().contains("20")); + assert!(invalid_length.to_string().contains("15")); + assert!(invalid_length.to_string().contains("4")); + + let invalid_prefix = RtrError::InvalidPrefixLength { + prefix_len: 25, + max_len: 24, + max_allowed: 32, + }; + assert!(invalid_prefix.to_string().contains("Invalid prefix length")); + assert!(invalid_prefix.to_string().contains("25")); + assert!(invalid_prefix.to_string().contains("24")); + assert!(invalid_prefix.to_string().contains("32")); + + let invalid_utf8 = RtrError::InvalidUtf8; + assert!(invalid_utf8.to_string().contains("Invalid UTF-8")); + + let router_key_v0 = RtrError::RouterKeyInV0; + assert!(router_key_v0.to_string().contains("Router Key PDU")); + assert!(router_key_v0.to_string().contains("v0")); + } + + #[test] + fn test_error_source() { + use std::error::Error; + + // IoError should have a source + let io_err = RtrError::IoError(std::io::Error::new( + std::io::ErrorKind::NotFound, + "file not found", + )); + assert!(io_err.source().is_some()); + + // Other errors should not have a source + let incomplete = RtrError::IncompletePdu { + available: 1, + needed: 2, + }; + assert!(incomplete.source().is_none()); + + let invalid_type = RtrError::InvalidPduType(5); + assert!(invalid_type.source().is_none()); + + let invalid_utf8 = RtrError::InvalidUtf8; + assert!(invalid_utf8.source().is_none()); + } + + #[test] + fn test_error_from_io_error() { + let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout"); + let rtr_err: RtrError = io_err.into(); + assert!(matches!(rtr_err, RtrError::IoError(_))); + } + + #[test] + fn test_read_rtr_pdu_short_length() { + use std::io::Cursor; + + // PDU with length less than header size + let bytes = vec![ + 1, // version + 2, // type (Reset Query) + 0, 0, // zero + 0, 0, 0, 4, // length = 4 (less than header) + ]; + let mut cursor = Cursor::new(bytes); + let result = read_rtr_pdu(&mut cursor); + assert!(matches!(result, Err(RtrError::InvalidLength { .. }))); + } + + #[test] + fn test_parse_invalid_error_code() { + // Error Report with invalid error code + let bytes = vec![ + 1, // version + 10, // type (Error Report) + 0, 100, // error code = 100 (invalid) + 0, 0, 0, 16, // length = 16 (minimum) + 0, 0, 0, 0, // encapsulated PDU length = 0 + 0, 0, 0, 0, // error text length = 0 + ]; + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidErrorCode(100)))); + } + + #[test] + fn test_parse_error_report_invalid_utf8() { + // Error Report with invalid UTF-8 in error text + let bytes = vec![ + 1, // version + 10, // type (Error Report) + 0, 0, // error code = 0 + 0, 0, 0, 20, // length = 20 + 0, 0, 0, 0, // encapsulated PDU length = 0 + 0, 0, 0, 4, // error text length = 4 + 0xFF, 0xFE, 0xFF, 0xFE, // invalid UTF-8 + ]; + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidUtf8))); + } + + #[test] + fn test_parse_error_report_truncated() { + // Error Report with encapsulated PDU length exceeding bounds + let bytes = vec![ + 1, // version + 10, // type (Error Report) + 0, 0, // error code = 0 + 0, 0, 0, 16, // length = 16 + 0, 0, 0, 100, // encapsulated PDU length = 100 (too large) + 0, 0, 0, 0, // error text length = 0 + ]; + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidLength { .. }))); + } + + #[test] + fn test_parse_router_key_empty_spki() { + // Router Key with no SPKI (minimum length) + let mut bytes = vec![ + 1, // version + 9, // type (Router Key) + 0, 0, // zero + 0, 0, 0, 34, // length = 34 (minimum) + 1, // flags + 0, // zero + ]; + bytes.extend_from_slice(&[0u8; 20]); // SKI + bytes.extend_from_slice(&[0, 0, 0, 1]); // ASN = 1 + + let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); + assert_eq!(consumed, 34); + match pdu { + RtrPdu::RouterKey(k) => { + assert!(k.subject_public_key_info.is_empty()); + } + _ => panic!("Expected RouterKey"), + } + } + + #[test] + fn test_parse_ipv6_invalid_max_length() { + // IPv6 prefix with max_len > 128 + let mut bytes = vec![ + 1, // version + 6, // type (IPv6 Prefix) + 0, 0, // zero + 0, 0, 0, 32, // length + 1, // flags + 64, // prefix_length + 129, // max_length (129) - INVALID: > 128 for IPv6 + 0, // zero + ]; + bytes.extend_from_slice(&[0u8; 16]); // prefix + bytes.extend_from_slice(&[0, 0, 0, 1]); // ASN + + let result = parse_rtr_pdu(&bytes); + assert!(matches!(result, Err(RtrError::InvalidPrefixLength { .. }))); + } + + #[test] + fn test_encode_all_pdu_types_v0() { + // Test encoding v0 PDUs + let notify = RtrSerialNotify { + version: RtrProtocolVersion::V0, + session_id: 100, + serial_number: 200, + }; + let bytes = notify.encode(); + assert_eq!(bytes[0], 0); // version 0 + + let response = RtrCacheResponse { + version: RtrProtocolVersion::V0, + session_id: 300, + }; + let bytes = response.encode(); + assert_eq!(bytes[0], 0); // version 0 + + let reset = RtrCacheReset { + version: RtrProtocolVersion::V0, + }; + let bytes = reset.encode(); + assert_eq!(bytes[0], 0); // version 0 + + let prefix4 = RtrIPv4Prefix { + version: RtrProtocolVersion::V0, + flags: 0, + prefix_length: 16, + max_length: 24, + prefix: Ipv4Addr::new(172, 16, 0, 0), + asn: Asn::from(64512u32), + }; + let bytes = prefix4.encode(); + assert_eq!(bytes[0], 0); // version 0 + assert_eq!(bytes.len(), 20); + + let prefix6 = RtrIPv6Prefix { + version: RtrProtocolVersion::V0, + flags: 1, + prefix_length: 32, + max_length: 48, + prefix: Ipv6Addr::new(0xfc00, 0, 0, 0, 0, 0, 0, 0), + asn: Asn::from(64513u32), + }; + let bytes = prefix6.encode(); + assert_eq!(bytes[0], 0); // version 0 + assert_eq!(bytes.len(), 32); + } + + #[test] + fn test_read_multiple_pdus() { + use std::io::Cursor; + + // Create buffer with two PDUs + let query1 = RtrResetQuery::new_v1(); + let query2 = RtrSerialQuery::new(RtrProtocolVersion::V1, 100, 200); + + let mut buffer = query1.encode(); + buffer.extend(query2.encode()); + + let mut cursor = Cursor::new(buffer); + + // Read first PDU + let pdu1 = read_rtr_pdu(&mut cursor).unwrap(); + assert!(matches!(pdu1, RtrPdu::ResetQuery(_))); + + // Read second PDU + let pdu2 = read_rtr_pdu(&mut cursor).unwrap(); + assert!(matches!(pdu2, RtrPdu::SerialQuery(_))); + } + + #[test] + fn test_parse_with_extra_bytes() { + // PDU followed by extra bytes - should only consume PDU length + let query = RtrResetQuery::new_v1(); + let mut bytes = query.encode(); + bytes.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]); // extra bytes + + let (pdu, consumed) = parse_rtr_pdu(&bytes).unwrap(); + assert!(matches!(pdu, RtrPdu::ResetQuery(_))); + assert_eq!(consumed, 8); // Only consumed the PDU, not extra bytes + } + + #[test] + fn test_error_report_with_pdu_and_text() { + // Full error report with both encapsulated PDU and error text + let error = RtrErrorReport { + version: RtrProtocolVersion::V1, + error_code: RtrErrorCode::CorruptData, + erroneous_pdu: vec![1, 2, 3, 4, 5, 6, 7, 8], // 8 bytes + error_text: "Something went wrong!".to_string(), // 21 bytes + }; + let bytes = error.encode(); + + let (pdu, _) = parse_rtr_pdu(&bytes).unwrap(); + match pdu { + RtrPdu::ErrorReport(e) => { + assert_eq!(e.error_code, RtrErrorCode::CorruptData); + assert_eq!(e.erroneous_pdu, vec![1, 2, 3, 4, 5, 6, 7, 8]); + assert_eq!(e.error_text, "Something went wrong!"); + } + _ => panic!("Expected ErrorReport"), + } + } +}